HyperAccel committed on
Commit 543258e · verified · 1 Parent(s): 70fc7c5

Upload tiny-random kimi_linear model

README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated. The accompanying checkpoint is a tiny, randomly initialized `kimi_linear` model (see `config.json` in this commit), intended for testing and integration work rather than meaningful text generation.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
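+ Loading this checkpoint requires `trust_remote_code=True` (the `auto_map` entry in `config.json` points to the custom `configuration_kimi_linear` / `modeling_kimi_linear` modules) plus the `fla-core` package (`pip install -U fla-core`); the modeling code also selects the `flash_attention_2` backend, so a CUDA setup is generally expected. The snippet below is only a minimal, untested sketch: the repository id is a placeholder, it assumes a tokenizer is available alongside the model, and the tiny random weights will not produce meaningful text.
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ repo_id = "<this-repo-id>"  # placeholder: replace with the actual Hub repository id
+ device = "cuda"             # the flash_attention_2 backend generally needs a CUDA device
+
+ # trust_remote_code is required because config.json maps the Auto classes to the
+ # custom configuration_kimi_linear / modeling_kimi_linear files in this repo.
+ model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True).to(device)
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
+
+ inputs = tokenizer("Hello, world!", return_tensors="pt").to(device)
+ # Random weights: the output is only useful for shape / integration testing.
+ outputs = model.generate(**inputs, max_new_tokens=8)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```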
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
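+ Although the card above is unfilled, the architecture is fully specified by the `config.json` uploaded in this commit: a 5-layer `KimiLinearForCausalLM` with hidden size 512, where config layers 1-3 use Kimi Delta Attention (KDA) and layer 4 uses full MLA attention, plus a small MoE block with 4 routed experts (2 active per token) and 1 shared expert. A short sketch (placeholder repository id, remote code trusted) for inspecting those fields:
+
+ ```python
+ from transformers import AutoConfig
+
+ config = AutoConfig.from_pretrained("<this-repo-id>", trust_remote_code=True)
+
+ # These values mirror the config.json uploaded in this commit.
+ print(config.num_hidden_layers)   # 5
+ print(config.hidden_size)         # 512
+ print(config.linear_attn_config)  # {"full_attn_layers": [4], "kda_layers": [1, 2, 3], ...}
+ print(config.num_experts, config.num_experts_per_token)  # 4 2
+ ```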
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,79 @@
1
+ {
2
+ "_attn_implementation_autoset": false,
3
+ "add_cross_attention": false,
4
+ "architectures": [
5
+ "KimiLinearForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_kimi_linear.KimiLinearConfig",
9
+ "AutoModelForCausalLM": "modeling_kimi_linear.KimiLinearForCausalLM"
10
+ },
11
+ "bos_token_id": 163584,
12
+ "cross_attention_hidden_size": null,
13
+ "decoder_start_token_id": null,
14
+ "dtype": "float32",
15
+ "eos_token_id": 163586,
16
+ "finetuning_task": null,
17
+ "first_k_dense_replace": 1,
18
+ "head_dim": 32,
19
+ "hidden_act": "silu",
20
+ "hidden_size": 512,
21
+ "initializer_range": 0.02,
22
+ "intermediate_size": 256,
23
+ "is_decoder": false,
24
+ "kv_lora_rank": 32,
25
+ "linear_attn_config": {
26
+ "full_attn_layers": [
27
+ 4
28
+ ],
29
+ "head_dim": 64,
30
+ "kda_layers": [
31
+ 1,
32
+ 2,
33
+ 3
34
+ ],
35
+ "num_heads": 8,
36
+ "short_conv_kernel_size": 4
37
+ },
38
+ "mla_use_nope": true,
39
+ "model_type": "kimi_linear",
40
+ "moe_intermediate_size": 256,
41
+ "moe_layer_freq": 1,
42
+ "moe_renormalize": true,
43
+ "moe_router_activation_func": "sigmoid",
44
+ "num_attention_heads": 8,
45
+ "num_expert_group": 1,
46
+ "num_experts": 4,
47
+ "num_experts_per_token": 2,
48
+ "num_hidden_layers": 5,
49
+ "num_key_value_heads": 8,
50
+ "num_nextn_predict_layers": 0,
51
+ "num_shared_experts": 1,
52
+ "pad_token_id": 163839,
53
+ "prefix": null,
54
+ "pruned_heads": {},
55
+ "q_lora_rank": null,
56
+ "qk_nope_head_dim": 32,
57
+ "qk_rope_head_dim": 16,
58
+ "rms_norm_eps": 1e-05,
59
+ "rope_parameters": {
60
+ "rope_theta": 10000.0,
61
+ "rope_type": "default"
62
+ },
63
+ "rope_theta": 10000.0,
64
+ "routed_scaling_factor": 2.446,
65
+ "sep_token_id": null,
66
+ "task_specific_params": null,
67
+ "tf_legacy_loss": false,
68
+ "tie_encoder_decoder": false,
69
+ "tie_word_embeddings": false,
70
+ "tokenizer_class": null,
71
+ "topk_group": 1,
72
+ "torchscript": false,
73
+ "transformers_version": "5.3.0",
74
+ "use_bfloat16": false,
75
+ "use_cache": true,
76
+ "use_grouped_topk": true,
77
+ "v_head_dim": 32,
78
+ "vocab_size": 163840
79
+ }
configuration_kimi_linear.py ADDED
@@ -0,0 +1,140 @@
1
+ # coding=utf-8
2
+ from typing import Optional
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+
6
+
7
+ class KimiLinearConfig(PretrainedConfig):
8
+ model_type = "kimi_linear"
9
+ keys_to_ignore_at_inference = ["past_key_values"]
10
+
11
+ def __init__(
12
+ self,
13
+ model_type="kimi_linear",
14
+ vocab_size=163840,
15
+ hidden_size=4096,
16
+ head_dim=None,
17
+ intermediate_size=11008,
18
+ num_hidden_layers=32,
19
+ num_attention_heads=32,
20
+ num_key_value_heads=None,
21
+ hidden_act="silu",
22
+ initializer_range=0.02,
23
+ rms_norm_eps=1e-6,
24
+ use_cache=True,
25
+ pad_token_id=0,
26
+ bos_token_id=1,
27
+ eos_token_id=2,
28
+ rope_theta=10000.0,
29
+ rope_scaling=None,
30
+ tie_word_embeddings=False,
31
+ moe_intermediate_size: Optional[int] = None,
32
+ moe_renormalize: bool = True,
33
+ moe_router_activation_func: str = "sigmoid",
34
+ num_experts: Optional[int] = None,
35
+ num_experts_per_token: Optional[int] = None,
36
+ num_shared_experts: int = 0,
37
+ routed_scaling_factor: float = 1.0,
38
+ first_k_dense_replace: int = 0,
39
+ moe_layer_freq: int = 1,
40
+ use_grouped_topk: bool = True,
41
+ num_expert_group: int = 1,
42
+ topk_group: int = 1,
43
+ q_lora_rank: Optional[int] = None,
44
+ kv_lora_rank: Optional[int] = None,
45
+ qk_nope_head_dim: Optional[int] = None,
46
+ qk_rope_head_dim: Optional[int] = None,
47
+ v_head_dim: Optional[int] = None,
48
+ mla_use_nope: Optional[bool] = False,
49
+ num_nextn_predict_layers: int = 0,
50
+ linear_attn_config: Optional[dict] = None,
51
+ **kwargs,
52
+ ):
53
+ self.model_type = model_type
54
+ self.vocab_size = vocab_size
55
+ self.hidden_size = hidden_size
56
+ self.head_dim = (
57
+ head_dim if head_dim is not None else hidden_size // num_attention_heads
58
+ )
59
+ self.intermediate_size = intermediate_size
60
+ self.num_hidden_layers = num_hidden_layers
61
+ self.num_attention_heads = num_attention_heads
62
+
63
+ # for backward compatibility
64
+ if num_key_value_heads is None:
65
+ num_key_value_heads = num_attention_heads
66
+
67
+ self.num_key_value_heads = num_key_value_heads
68
+ self.hidden_act = hidden_act
69
+ self.initializer_range = initializer_range
70
+ self.rms_norm_eps = rms_norm_eps
71
+ self.use_cache = use_cache
72
+ self.rope_theta = rope_theta
73
+ self.rope_scaling = rope_scaling
74
+
75
+ self.q_lora_rank = q_lora_rank
76
+ self.kv_lora_rank = kv_lora_rank
77
+ self.qk_nope_head_dim = qk_nope_head_dim
78
+ self.qk_rope_head_dim = qk_rope_head_dim
79
+ self.v_head_dim = v_head_dim
80
+ self.mla_use_nope = mla_use_nope
81
+ # moe config
82
+ self.num_experts = num_experts
83
+ self.num_experts_per_token = num_experts_per_token
84
+ self.moe_renormalize = moe_renormalize
85
+ self.num_shared_experts = num_shared_experts
86
+ self.routed_scaling_factor = routed_scaling_factor
87
+ self.moe_router_activation_func = moe_router_activation_func
88
+ assert self.moe_router_activation_func in ("softmax", "sigmoid")
89
+ self.moe_intermediate_size = moe_intermediate_size
90
+ self.first_k_dense_replace = first_k_dense_replace
91
+ self.moe_layer_freq = moe_layer_freq
92
+ self.use_grouped_topk = use_grouped_topk
93
+ self.num_expert_group = num_expert_group
94
+ self.topk_group = topk_group
95
+ self.num_nextn_predict_layers = num_nextn_predict_layers
96
+
97
+ if linear_attn_config is not None:
98
+ assert linear_attn_config["kda_layers"] is not None
99
+ assert linear_attn_config["full_attn_layers"] is not None
100
+ self.linear_attn_config = linear_attn_config
101
+
102
+ super().__init__(
103
+ pad_token_id=pad_token_id,
104
+ bos_token_id=bos_token_id,
105
+ eos_token_id=eos_token_id,
106
+ tie_word_embeddings=tie_word_embeddings,
107
+ **kwargs,
108
+ )
109
+
110
+ @property
111
+ def is_mla(self):
112
+ return (
113
+ self.q_lora_rank is not None
114
+ or self.kv_lora_rank is not None
115
+ or self.qk_nope_head_dim is not None
116
+ or self.qk_rope_head_dim is not None
117
+ or self.v_head_dim is not None
118
+ or self.mla_use_nope is True
119
+ )
120
+
121
+ @property
122
+ def is_moe(self):
123
+ return self.num_experts is not None
124
+
125
+ @property
126
+ def is_linear_attn(self) -> bool:
127
+ return not (
128
+ self.linear_attn_config is None
129
+ or (
130
+ isinstance(self.linear_attn_config, dict)
131
+ and self.linear_attn_config["kda_layers"] is not None
132
+ and len(self.linear_attn_config["kda_layers"]) == 0
133
+ )
134
+ )
135
+
136
+ def is_kda_layer(self, layer_idx: int):
137
+ return (
138
+ self.linear_attn_config is not None
139
+ and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
140
+ )
generation_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 163584,
4
+ "eos_token_id": 163586,
5
+ "output_attentions": false,
6
+ "output_hidden_states": false,
7
+ "pad_token_id": 163839,
8
+ "transformers_version": "5.3.0",
9
+ "use_cache": true
10
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2512c91207ee08a12f2921efb0cbbf54765e2d15ca509891ddac605d2ff5c62a
3
+ size 721425696
modeling_kimi_linear.py ADDED
@@ -0,0 +1,1104 @@
1
+ import math
2
+ from collections.abc import Callable
3
+ from typing import Any
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import transformers
8
+ from einops import rearrange, repeat
9
+ from packaging import version
10
+ from torch import nn
11
+ from transformers.activations import ACT2FN
12
+ from transformers.cache_utils import Cache
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.masking_utils import create_causal_mask
15
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
17
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
18
+ from transformers.processing_utils import Unpack
19
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
20
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
21
+ from transformers.utils.output_capturing import OutputRecorder
22
+
23
+ def check_model_inputs(fn):  # no-op decorator; applied to KimiLinearModel.forward below
24
+ return fn
25
+
26
+ try:
27
+ from fla.modules import FusedRMSNormGated, ShortConvolution
28
+ from fla.ops.kda import chunk_kda, fused_recurrent_kda
29
+ from fla.ops.kda.gate import fused_kda_gate
30
+ from fla.ops.utils.index import prepare_cu_seqlens_from_mask, prepare_lens_from_mask
31
+ from fla.utils import tensor_cache
32
+ except ImportError:
33
+ raise ImportError("Please run `pip install -U fla-core`")
34
+
35
+ from .configuration_kimi_linear import KimiLinearConfig
36
+
37
+ assert version.parse(transformers.__version__) >= version.parse("4.56.0"), \
38
+ "Please upgrade transformers to >= 4.56.0"
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+
43
+ def index_first_axis(x, indices):
44
+ other_shape = x.shape[1:]
45
+ second_dim = other_shape.numel()
46
+ return torch.gather(
47
+ rearrange(x, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim),
48
+ ).reshape(-1, *other_shape)
49
+
50
+
51
+ def index_put_first_axis(x, indices, first_axis_dim):
52
+ y = torch.zeros(first_axis_dim, *x.shape[1:], device=x.device, dtype=x.dtype)
53
+ # TODO [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
54
+ y[indices] = x
55
+ # y.scatter_(0, repeat(indices, 'z -> z d', d=x.shape[1]), x)
56
+ return y
57
+
58
+
59
+ @tensor_cache
60
+ def get_unpad_data(
61
+ attention_mask: torch.Tensor,
62
+ ) -> tuple[torch.Tensor, torch.Tensor, int]:
63
+ lens = prepare_lens_from_mask(attention_mask)
64
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
65
+ max_seqlen_in_batch = lens.max().item()
66
+ cu_seqlens = prepare_cu_seqlens_from_mask(attention_mask)
67
+ return indices, cu_seqlens, max_seqlen_in_batch
68
+
69
+
70
+ def unpad_input(
71
+ q: torch.Tensor,
72
+ states: tuple[torch.Tensor],
73
+ attention_mask: torch.Tensor,
74
+ q_len: int,
75
+ keepdim: bool = False,
76
+ ):
77
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask)
78
+ batch_size, seq_len, *_ = states[0].shape
79
+
80
+ state = tuple(
81
+ index_first_axis(rearrange(s, "b s ... -> (b s) ..."), indices_k)
82
+ for s in states
83
+ )
84
+
85
+ if q_len == seq_len:
86
+ q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k)
87
+ cu_seqlens_q = cu_seqlens_k
88
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
89
+ indices_q = indices_k
90
+ elif q_len == 1:
91
+ max_seqlen_in_batch_q = 1
92
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
93
+ indices_q = cu_seqlens_q[:-1]
94
+ q = q.squeeze(1)
95
+ else:
96
+ raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)")
97
+
98
+ if keepdim:
99
+ q = q.unsqueeze(0)
100
+ state = tuple(s.unsqueeze(0) for s in state)
101
+
102
+ return (
103
+ q,
104
+ state,
105
+ indices_q,
106
+ (cu_seqlens_q, cu_seqlens_k),
107
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
108
+ )
109
+
110
+
111
+ def pad_input(
112
+ hidden_states: torch.Tensor,
113
+ indices: torch.LongTensor,
114
+ batch_size: int,
115
+ seq_len: int,
116
+ ) -> torch.Tensor:
117
+ output = index_put_first_axis(hidden_states, indices, batch_size * seq_len)
118
+ return rearrange(output, "(b s) ... -> b s ...", b=batch_size)
119
+
120
+
121
+ class KimiDynamicCache:
122
+ """
123
+ Dynamic cache for Kimi model.
124
+ Inspired by Qwen3-Next
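+ Holds regular key/value caches for full-attention layers and (conv_states, recurrent_states) pairs for linear-attention (KDA) layers.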
125
+ """
126
+ is_compileable = False
127
+
128
+ def __init__(self, config: KimiLinearConfig):
129
+ super().__init__()
130
+ self.config = config
131
+
132
+ if config.linear_attn_config is not None:
133
+ self.layer_types = []
134
+ for i in range(config.num_hidden_layers):
135
+ if config.is_kda_layer(i):
136
+ self.layer_types.append("linear_attention")
137
+ else:
138
+ self.layer_types.append("full_attention")
139
+ else:
140
+ self.layer_types = ["full_attention"] * config.num_hidden_layers
141
+
142
+ self.transformer_layers = [
143
+ i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention"
144
+ ]
145
+
146
+ linear_layers = [i for i in range(
147
+ config.num_hidden_layers) if self.layer_types[i] == "linear_attention"]
148
+ self.last_linear_layer = linear_layers[-1] if linear_layers else -1
149
+
150
+ self.conv_states = [None for _ in range(config.num_hidden_layers)]
151
+ self.recurrent_states = [None for _ in range(config.num_hidden_layers)]
152
+ self.key_cache = [None for _ in range(config.num_hidden_layers)]
153
+ self.value_cache = [None for _ in range(config.num_hidden_layers)]
154
+
155
+ def __len__(self):
156
+ return len(self.layer_types)
157
+
158
+ def update(
159
+ self,
160
+ key_states: torch.Tensor,
161
+ value_states: torch.Tensor,
162
+ layer_idx: int,
163
+ cache_kwargs: dict[str, Any] | None = None,
164
+ ) -> tuple[torch.Tensor, torch.Tensor]:
165
+ if self.key_cache[layer_idx] is None:
166
+ self.key_cache[layer_idx] = key_states
167
+ self.value_cache[layer_idx] = value_states
168
+ else:
169
+ self.key_cache[layer_idx] = torch.cat(
170
+ [self.key_cache[layer_idx], key_states], dim=2)
171
+ self.value_cache[layer_idx] = torch.cat(
172
+ [self.value_cache[layer_idx], value_states], dim=2)
173
+
174
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
175
+
176
+ def reorder_cache(self, beam_idx: torch.LongTensor):
177
+ """Reorders the cache for beam search, given the selected beam indices."""
178
+ for layer_idx in range(len(self.key_cache)):
179
+ if self.key_cache[layer_idx] is not None:
180
+ device = self.key_cache[layer_idx].device
181
+ beam_idx = beam_idx.to(device)
182
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
183
+ 0, beam_idx)
184
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
185
+ 0, beam_idx)
186
+
187
+ if self.conv_states[layer_idx] is not None:
188
+ device = self.conv_states[layer_idx][0].device
189
+ beam_idx = beam_idx.to(device)
190
+ q_conv, k_conv, v_conv = self.conv_states[layer_idx]
191
+ self.conv_states[layer_idx] = (
192
+ q_conv.index_select(0, beam_idx),
193
+ k_conv.index_select(0, beam_idx),
194
+ v_conv.index_select(0, beam_idx),
195
+ )
196
+ self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(
197
+ 0, beam_idx)
198
+
199
+ def get_seq_length(self, layer_idx: int | None = 0) -> int:
200
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
201
+ # fall back to a full-attention layer, since only those hold a (non-empty) KV cache
202
+ layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
203
+ if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
204
+ return 0
205
+ return self.key_cache[layer_idx].shape[-2]
206
+
207
+ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
208
+ """
209
+ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
210
+ the given layer at `layer_idx`.
211
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
212
+ """
213
+ kv_offset = 0
214
+ query_length = cache_position.shape[0]
215
+ past_seen_tokens = self.get_seq_length(layer_idx)
216
+ kv_length = query_length + past_seen_tokens
217
+ return kv_length, kv_offset
218
+
219
+ @property
220
+ def has_previous_state(self):
221
+ """We have a previous state if the last linear (conv) layer was already updated."""
222
+ if self.last_linear_layer == -1:
223
+ return False
224
+ return self.conv_states[self.last_linear_layer] is not None
225
+
226
+
227
+ class KimiRMSNorm(nn.Module):
228
+ def __init__(self, hidden_size, eps=1e-6):
229
+ """
230
+ KimiRMSNorm is equivalent to T5LayerNorm
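+ i.e. y = weight * x / sqrt(mean(x**2, dim=-1) + eps), with the reduction computed in float32 (see forward below).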
231
+ """
232
+ super().__init__()
233
+ self.weight = nn.Parameter(torch.ones(hidden_size))
234
+ self.variance_epsilon = eps
235
+
236
+ def forward(self, hidden_states):
237
+ input_dtype = hidden_states.dtype
238
+ hidden_states = hidden_states.to(torch.float32)
239
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
240
+ hidden_states = hidden_states * \
241
+ torch.rsqrt(variance + self.variance_epsilon)
242
+ return self.weight * hidden_states.to(input_dtype)
243
+
244
+
245
+ ALL_LAYERNORM_LAYERS.append(KimiRMSNorm)
246
+
247
+
248
+ class KimiBlockSparseMLP(nn.Module):
249
+ def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
250
+ super().__init__()
251
+ self.config = config
252
+ self.ffn_dim = config.intermediate_size if intermediate_size is None else intermediate_size
253
+ self.hidden_dim = config.hidden_size if hidden_size is None else hidden_size
254
+
255
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # gate
256
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) # down
257
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # up
258
+
259
+ self.act_fn = ACT2FN[config.hidden_act]
260
+
261
+ def forward(self, hidden_states):
262
+ current_hidden_states = self.act_fn(
263
+ self.w1(hidden_states)) * self.w3(hidden_states)
264
+ current_hidden_states = self.w2(current_hidden_states)
265
+ return current_hidden_states
266
+
267
+
268
+ class KimiMLP(nn.Module):
269
+ def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
270
+ super().__init__()
271
+ self.config = config
272
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
273
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
274
+ self.gate_proj = nn.Linear(
275
+ self.hidden_size, self.intermediate_size, bias=False)
276
+ self.up_proj = nn.Linear(
277
+ self.hidden_size, self.intermediate_size, bias=False)
278
+ self.down_proj = nn.Linear(
279
+ self.intermediate_size, self.hidden_size, bias=False)
280
+ self.act_fn = ACT2FN[config.hidden_act]
281
+
282
+ def forward(self, x):
283
+ down_proj = self.down_proj(self.act_fn(
284
+ self.gate_proj(x)) * self.up_proj(x))
285
+ return down_proj
286
+
287
+
288
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
289
+ """
290
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
291
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
292
+ """
293
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
294
+ if n_rep == 1:
295
+ return hidden_states
296
+ hidden_states = hidden_states[:, :, None, :, :].expand(
297
+ batch, num_key_value_heads, n_rep, slen, head_dim)
298
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
299
+
300
+
301
+ def eager_attention_forward(
302
+ module: nn.Module,
303
+ query: torch.Tensor,
304
+ key: torch.Tensor,
305
+ value: torch.Tensor,
306
+ attention_mask: torch.Tensor | None,
307
+ scaling: float,
308
+ dropout: float = 0.0,
309
+ **kwargs: Unpack[TransformersKwargs],
310
+ ):
311
+ key_states = repeat_kv(key, module.num_key_value_groups)
312
+ value_states = repeat_kv(value, module.num_key_value_groups)
313
+
314
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
315
+ if attention_mask is not None:
316
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
317
+ attn_weights = attn_weights + causal_mask
318
+
319
+ attn_weights = nn.functional.softmax(
320
+ attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
321
+ attn_weights = nn.functional.dropout(
322
+ attn_weights, p=dropout, training=module.training)
323
+ attn_output = torch.matmul(attn_weights, value_states)
324
+ attn_output = attn_output.transpose(1, 2).contiguous()
325
+
326
+ return attn_output, attn_weights
327
+
328
+
329
+ class KimiMLAAttention(nn.Module):
330
+ """
331
+ Multi-head Latent Attention (MLA) adapted from DeepSeek-V3
332
+ """
333
+
334
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
335
+ nn.Module.__init__(self)
336
+ self.config = config
337
+ self.layer_idx = layer_idx
338
+ self.hidden_size = config.hidden_size
339
+ self.num_heads = config.num_attention_heads
340
+ self.num_key_value_heads = config.num_key_value_heads
341
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
342
+
343
+ self.rope_theta = config.rope_theta
344
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)
345
+
346
+ try:
347
+ self.q_lora_rank = config.q_lora_rank
348
+ self.qk_rope_head_dim = config.qk_rope_head_dim
349
+ self.kv_lora_rank = config.kv_lora_rank
350
+ self.v_head_dim = config.v_head_dim
351
+ self.qk_nope_head_dim = config.qk_nope_head_dim
352
+ self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
353
+ self.use_nope = config.mla_use_nope
354
+ self.scaling = self.q_head_dim ** (-0.5)
355
+ except Exception as e:
356
+ raise ValueError(
357
+ f"Kimi MLA config is missing or not properly formatted: {e}")
358
+
359
+ assert self.q_lora_rank is None
360
+ self.q_proj = nn.Linear(
361
+ self.hidden_size, self.num_heads * self.q_head_dim, bias=False,
362
+ )
363
+ self.kv_a_proj_with_mqa = nn.Linear(
364
+ self.hidden_size,
365
+ self.kv_lora_rank + self.qk_rope_head_dim,
366
+ bias=False,
367
+ )
368
+ self.kv_a_layernorm = KimiRMSNorm(self.kv_lora_rank)
369
+ self.kv_b_proj = nn.Linear(
370
+ self.kv_lora_rank,
371
+ self.num_heads
372
+ * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
373
+ bias=False,
374
+ )
375
+ self.o_proj = nn.Linear(
376
+ self.num_heads * self.v_head_dim,
377
+ self.hidden_size,
378
+ bias=False,
379
+ )
380
+ self.is_causal = True
381
+ assert self.use_nope
382
+
383
+ def forward(
384
+ self,
385
+ hidden_states: torch.Tensor,
386
+ attention_mask: torch.Tensor | None = None,
387
+ past_key_values: Cache | None = None,
388
+ **kwargs,
389
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
390
+ batch_size, seq_length = hidden_states.shape[:-1]
391
+ query_shape = (batch_size, seq_length, -1, self.q_head_dim)
392
+ key_shape = (batch_size, seq_length, -1,
393
+ self.qk_nope_head_dim + self.v_head_dim)
394
+
395
+ q_states = self.q_proj(hidden_states)
396
+ q_states = q_states.view(query_shape).transpose(1, 2)
397
+ q_pass, q_rot = torch.split(
398
+ q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
399
+
400
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
401
+ k_pass, k_rot = torch.split(
402
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
403
+
404
+ k_pass = self.kv_b_proj(self.kv_a_layernorm(
405
+ k_pass)).view(key_shape).transpose(1, 2)
406
+ k_pass, value_states = torch.split(
407
+ k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
408
+
409
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
410
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
411
+
412
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
413
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
414
+
415
+ if past_key_values is not None:
416
+ key_states, value_states = past_key_values.update(
417
+ key_states, value_states, self.layer_idx)
418
+
419
+ if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
420
+ value_states = F.pad(
421
+ value_states, [0, self.q_head_dim - self.v_head_dim])
422
+
423
+ attention_interface: Callable = eager_attention_forward
424
+ if self.config._attn_implementation != "eager":
425
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
426
+
427
+ attn_output, _ = attention_interface(
428
+ self,
429
+ query_states,
430
+ key_states,
431
+ value_states,
432
+ attention_mask,
433
+ dropout=0.0 if not self.training else self.attention_dropout,
434
+ scaling=self.scaling,
435
+ **kwargs,
436
+ )
437
+
438
+ if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
439
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
440
+
441
+ attn_output = attn_output.reshape(
442
+ batch_size, seq_length, -1).contiguous()
443
+ attn_output = self.o_proj(attn_output)
444
+ return attn_output
445
+
446
+
447
+ class KimiDeltaAttention(nn.Module):
448
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
449
+ super().__init__()
450
+ self.config = config
451
+ self.mode = "chunk"
452
+
453
+ self.hidden_size = config.hidden_size
454
+ self.conv_size = config.linear_attn_config["short_conv_kernel_size"]
455
+ self.head_dim = config.linear_attn_config["head_dim"]
456
+ self.num_heads = config.linear_attn_config["num_heads"]
457
+ self.head_k_dim = self.head_dim
458
+ self.num_k_heads = self.num_heads
459
+
460
+ self.layer_idx = layer_idx
461
+
462
+ assert self.mode in [
463
+ 'chunk', 'fused_recurrent'], f"Not supported mode `{self.mode}`."
464
+
465
+ projection_k_size = self.head_k_dim * self.num_k_heads
466
+ projection_size = self.head_dim * self.num_heads
467
+
468
+ self.q_proj = nn.Linear(
469
+ self.hidden_size, projection_k_size, bias=False)
470
+ self.k_proj = nn.Linear(
471
+ self.hidden_size, projection_k_size, bias=False)
472
+ self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False)
473
+
474
+ self.q_conv1d = ShortConvolution(
475
+ hidden_size=projection_k_size,
476
+ kernel_size=self.conv_size,
477
+ activation='silu',
478
+ )
479
+ self.k_conv1d = ShortConvolution(
480
+ hidden_size=projection_k_size,
481
+ kernel_size=self.conv_size,
482
+ activation='silu',
483
+ )
484
+ self.v_conv1d = ShortConvolution(
485
+ hidden_size=projection_size,
486
+ kernel_size=self.conv_size,
487
+ activation='silu',
488
+ )
489
+
490
+ self.A_log = torch.nn.Parameter(torch.log(torch.empty(
491
+ self.num_heads, dtype=torch.float32).uniform_(1, 16)).view(1, 1, -1, 1))
492
+
493
+ self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
494
+ self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
495
+
496
+ self.dt_bias = nn.Parameter(
497
+ torch.empty(projection_size, dtype=torch.float32))
498
+
499
+ self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
500
+
501
+ self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
502
+ self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
503
+
504
+ self.o_norm = FusedRMSNormGated(
505
+ self.head_dim, eps=config.rms_norm_eps, activation='sigmoid')
506
+ self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False)
507
+
508
+ def forward(
509
+ self,
510
+ hidden_states: torch.Tensor,
511
+ attention_mask: torch.Tensor | None = None,
512
+ cache_params: KimiDynamicCache | None = None,
513
+ **kwargs: Unpack[dict],
514
+ ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
515
+ if attention_mask is not None:
516
+ if attention_mask.dim() != 2:
517
+ attention_mask = kwargs.get("padding_mask")
518
+
519
+ if attention_mask is not None and attention_mask.dim() != 2:
520
+ raise ValueError(
521
+ "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] "
522
+ "(0 = padding). 3D masks are not supported here.",
523
+ )
524
+ use_cache = cache_params is not None
525
+ batch_size, q_len, _ = hidden_states.shape
526
+ mode = 'fused_recurrent' if q_len <= 64 else self.mode
527
+ if self.training:
528
+ assert mode == 'chunk', "Only chunk mode is supported in training."
529
+
530
+ cu_seqlens = kwargs.get('cu_seqlens')
531
+ indices = None
532
+ if attention_mask is not None:
533
+ indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
534
+ hidden_states = index_first_axis(
535
+ rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
536
+
537
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
538
+ recurrent_state = None
539
+ if cache_params is not None:
540
+ if cache_params.conv_states[self.layer_idx] is not None:
541
+ conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[
542
+ self.layer_idx]
543
+ recurrent_state = cache_params.recurrent_states[self.layer_idx]
544
+ q, conv_state_q = self.q_conv1d(
545
+ x=self.q_proj(hidden_states),
546
+ cache=conv_state_q,
547
+ output_final_state=use_cache,
548
+ cu_seqlens=cu_seqlens,
549
+ )
550
+ k, conv_state_k = self.k_conv1d(
551
+ x=self.k_proj(hidden_states),
552
+ cache=conv_state_k,
553
+ output_final_state=use_cache,
554
+ cu_seqlens=cu_seqlens,
555
+ )
556
+ v, conv_state_v = self.v_conv1d(
557
+ x=self.v_proj(hidden_states),
558
+ cache=conv_state_v,
559
+ output_final_state=use_cache,
560
+ cu_seqlens=cu_seqlens,
561
+ )
562
+ g = self.f_b_proj(self.f_a_proj(hidden_states))
563
+ g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
564
+ beta = self.b_proj(hidden_states).float().sigmoid()
565
+
566
+ q, k = map(lambda x: rearrange(
567
+ x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
568
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
569
+
570
+ if mode == 'chunk':
571
+ o, recurrent_state = chunk_kda(
572
+ q=q,
573
+ k=k,
574
+ v=v,
575
+ g=g,
576
+ beta=beta,
577
+ initial_state=recurrent_state,
578
+ output_final_state=True,
579
+ use_qk_l2norm_in_kernel=True,
580
+ cu_seqlens=cu_seqlens,
581
+ )
582
+ else:
583
+ o, recurrent_state = fused_recurrent_kda(
584
+ q=q,
585
+ k=k,
586
+ v=v,
587
+ g=g,
588
+ beta=beta,
589
+ initial_state=recurrent_state,
590
+ output_final_state=True,
591
+ use_qk_l2norm_in_kernel=True,
592
+ cu_seqlens=cu_seqlens,
593
+ )
594
+ if cache_params is not None:
595
+ cache_params.recurrent_states[self.layer_idx] = recurrent_state
596
+ cache_params.conv_states[self.layer_idx] = (
597
+ conv_state_q, conv_state_k, conv_state_v)
598
+
599
+ g = self.g_b_proj(self.g_a_proj(hidden_states))
600
+ g = rearrange(g, '... (h d) -> ... h d', d=self.head_dim)
601
+ o = self.o_norm(o, g)
602
+
603
+ o = rearrange(o, 'b t h d -> b t (h d)')
604
+ o = self.o_proj(o)
605
+ if attention_mask is not None:
606
+ o = pad_input(o.squeeze(0), indices, batch_size, q_len)
607
+
608
+ return o
609
+
610
+
611
+ class KimiMoEGate(nn.Module):
612
+ """
613
+ MoEGate adapted from Deepseek-V3.
614
+ Parameter correspondences:
615
+ num_experts -> n_routed_experts
616
+ num_experts_per_token -> num_experts_per_tok
617
+ num_expert_group -> n_group
618
+ moe_router_activation_func -> scoring_func
619
+ """
620
+
621
+ def __init__(self, config: KimiLinearConfig):
622
+ super().__init__()
623
+ self.config = config
624
+ self.top_k = config.num_experts_per_token
625
+ self.num_experts = config.num_experts
626
+ self.routed_scaling_factor = config.routed_scaling_factor
627
+ self.moe_router_activation_func = config.moe_router_activation_func
628
+ self.num_expert_group = getattr(config, "num_expert_group", 1)
629
+ self.topk_group = getattr(config, "topk_group", 1)
630
+
631
+ # topk selection algorithm
632
+ self.moe_renormalize = config.moe_renormalize
633
+ self.gating_dim = config.hidden_size
634
+ self.weight = nn.Parameter(
635
+ torch.empty((self.num_experts, self.gating_dim)),
636
+ )
637
+
638
+ self.e_score_correction_bias = nn.Parameter(
639
+ torch.empty(self.num_experts),
640
+ )
641
+ self.reset_parameters()
642
+
643
+ def reset_parameters(self) -> None:
644
+ import torch.nn.init as init
645
+
646
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
647
+
648
+ def forward(self, hidden_states):
649
+ bsz, seq_len, h = hidden_states.shape
650
+ # compute gating score
651
+ hidden_states = hidden_states.view(-1, h)
652
+ logits = F.linear(
653
+ hidden_states.type(torch.float32), self.weight.type(
654
+ torch.float32), None,
655
+ )
656
+ if self.moe_router_activation_func == "sigmoid":
657
+ scores = logits.sigmoid()
658
+ elif self.moe_router_activation_func == "softmax":
659
+ scores = logits.softmax(dim=1)
660
+ else:
661
+ raise NotImplementedError(
662
+ f"unsupported scoring function for MoE gating: {self.moe_router_activation_func}",
663
+ )
664
+
665
+ # select top-k experts
666
+ assert not self.training
667
+ scores_for_choice = scores.view(bsz * seq_len, -1)
668
+ scores_for_choice += self.e_score_correction_bias.unsqueeze(0)
669
+ group_scores = (
670
+ scores_for_choice.view(
671
+ bsz * seq_len, self.num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
672
+ ) # [n, num_expert_group]
673
+ group_idx = torch.topk(
674
+ group_scores, k=self.topk_group, dim=-1, sorted=False,
675
+ )[
676
+ 1
677
+ ] # [n, top_k_group]
678
+ group_mask = torch.zeros_like(group_scores) # [n, num_expert_group]
679
+ group_mask.scatter_(1, group_idx, 1) # [n, num_expert_group]
680
+ score_mask = (
681
+ group_mask.unsqueeze(-1)
682
+ .expand(
683
+ bsz * seq_len, self.num_expert_group, self.num_experts // self.num_expert_group,
684
+ )
685
+ .reshape(bsz * seq_len, -1)
686
+ ) # [n, e]
687
+ tmp_scores = scores_for_choice.masked_fill(
688
+ ~score_mask.bool(), 0.0) # [n, e]
689
+ _, topk_idx = torch.topk(
690
+ tmp_scores, k=self.top_k, dim=-1, sorted=False,
691
+ )
692
+ topk_weight = scores.gather(1, topk_idx)
693
+
694
+ # norm gate to sum 1
695
+ if self.top_k > 1 and self.moe_renormalize:
696
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
697
+ topk_weight = topk_weight / denominator
698
+ # must multiply the scaling factor
699
+ topk_weight = topk_weight * self.routed_scaling_factor
700
+
701
+ return topk_idx, topk_weight
702
+
703
+
704
+ class KimiSparseMoeBlock(nn.Module):
705
+ """
706
+ Adapted from Deepseek-V3's MOE implementation
707
+ The naming is consistent with Kimi's version.
708
+ """
709
+
710
+ def __init__(self, config: KimiLinearConfig):
711
+ super().__init__()
712
+ self.config = config
713
+ self.hidden_dim = config.hidden_size
714
+ self.num_experts = config.num_experts
715
+ self.top_k = config.num_experts_per_token
716
+ self.moe_renormalize = config.moe_renormalize
717
+
718
+ self.ep_size = 1
719
+ self.experts_per_rank = config.num_experts
720
+ self.ep_rank = 0
721
+ self.experts = nn.ModuleList(
722
+ [
723
+ KimiBlockSparseMLP(
724
+ config, intermediate_size=config.moe_intermediate_size,
725
+ )
726
+ for _ in range(config.num_experts)
727
+ ],
728
+ )
729
+ self.gate = KimiMoEGate(config)
730
+ if config.num_shared_experts is not None:
731
+ intermediate_size = config.moe_intermediate_size * config.num_shared_experts
732
+ self.shared_experts = KimiMLP(
733
+ config=config, intermediate_size=intermediate_size,
734
+ )
735
+
736
+ def forward(self, hidden_states):
737
+ identity = hidden_states
738
+ orig_shape = hidden_states.shape
739
+ topk_idx, topk_weight = self.gate(hidden_states)
740
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
741
+ if not self.training:
742
+ y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
743
+ else:
744
+ raise NotImplementedError("Training mode is not supported in KimiSparseMoeBlock")
745
+ if self.config.num_shared_experts is not None:
746
+ y = y + self.shared_experts(identity)
747
+ return y
748
+
749
+ @torch.no_grad()
750
+ def moe_infer(self, x, topk_ids, topk_weight):
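+ # Inference-time expert dispatch: count tokens per expert, sort the flattened
+ # (token, expert) assignments so each expert processes a contiguous slice, then
+ # scatter the expert outputs back and combine them with the routing weights.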
751
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
752
+ cnts.scatter_(1, topk_ids, 1)
753
+ tokens_per_expert = cnts.sum(dim=0)
754
+ idxs = topk_ids.view(-1).argsort()
755
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
756
+
757
+ tokens_per_expert = tokens_per_expert.cpu().numpy()
758
+
759
+ outputs = []
760
+ start_idx = 0
761
+ for i, num_tokens in enumerate(tokens_per_expert):
762
+ end_idx = start_idx + num_tokens
763
+ if num_tokens == 0:
764
+ continue
765
+ expert = self.experts[i + self.ep_rank * self.experts_per_rank]
766
+ tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
767
+ expert_out = expert(tokens_for_this_expert)
768
+ outputs.append(expert_out)
769
+ start_idx = end_idx
770
+
771
+ outs = torch.cat(outputs, dim=0) if len(
772
+ outputs) else sorted_tokens.new_empty(0)
773
+
774
+ new_x = torch.empty_like(outs)
775
+ new_x[idxs] = outs
776
+ final_out = (
777
+ new_x.view(*topk_ids.shape, -1)
778
+ .type(topk_weight.dtype)
779
+ .mul_(topk_weight.unsqueeze(dim=-1))
780
+ .sum(dim=1)
781
+ .type(new_x.dtype)
782
+ )
783
+ return final_out
784
+
785
+
786
+ class KimiDecoderLayer(nn.Module):
787
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
788
+ super().__init__()
789
+ self.hidden_size = config.hidden_size
790
+ self.config = config
791
+ if config.is_kda_layer(layer_idx):
792
+ self.is_linear_attn = True
793
+ self.self_attn = KimiDeltaAttention(
794
+ config=config, layer_idx=layer_idx)
795
+ elif config.is_mla:
796
+ self.is_linear_attn = False
797
+ self.self_attn = KimiMLAAttention(
798
+ config=config, layer_idx=layer_idx)
799
+ else:
800
+ raise NotImplementedError
801
+ if (
802
+ config.num_experts is not None
803
+ and layer_idx >= config.first_k_dense_replace
804
+ and layer_idx % getattr(config, "moe_layer_freq", 1) == 0
805
+ ):
806
+ self.block_sparse_moe = KimiSparseMoeBlock(config)
807
+ else:
808
+ self.mlp = KimiMLP(config)
809
+ self.input_layernorm = KimiRMSNorm(
810
+ config.hidden_size, eps=config.rms_norm_eps)
811
+ self.post_attention_layernorm = KimiRMSNorm(
812
+ config.hidden_size, eps=config.rms_norm_eps)
813
+
814
+ def forward(
815
+ self,
816
+ hidden_states: torch.Tensor,
817
+ attention_mask: torch.Tensor | None = None,
818
+ position_ids: torch.LongTensor | None = None,
819
+ past_key_values: tuple[torch.Tensor] | None = None,
820
+ output_attentions: bool | None = False,
821
+ use_cache: bool | None = False,
822
+ **kwargs: Unpack[FlashAttentionKwargs],
823
+ ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
824
+ """
825
+ Args:
826
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
827
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
828
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
829
+ output_attentions (`bool`, *optional*):
830
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
831
+ returned tensors for more detail.
832
+ use_cache (`bool`, *optional*):
833
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
834
+ (see `past_key_values`).
835
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
836
+ """
837
+
838
+ residual = hidden_states
839
+
840
+ hidden_states = self.input_layernorm(hidden_states)
841
+
842
+ # Self Attention
843
+ if self.is_linear_attn is False:
844
+ hidden_states = self.self_attn(
845
+ hidden_states=hidden_states,
846
+ attention_mask=attention_mask,
847
+ position_ids=position_ids,
848
+ past_key_values=past_key_values,
849
+ output_attentions=output_attentions,
850
+ use_cache=use_cache,
851
+ **kwargs,
852
+ )
853
+ else:
854
+ hidden_states = self.self_attn(
855
+ hidden_states=hidden_states,
856
+ attention_mask=attention_mask,
857
+ cache_params=past_key_values,
858
+ output_attentions=output_attentions,
859
+ use_cache=use_cache,
860
+ **kwargs,
861
+ )
862
+ hidden_states = residual + hidden_states
863
+
864
+ # Fully Connected
865
+ residual = hidden_states
866
+ hidden_states = self.post_attention_layernorm(hidden_states)
867
+ if hasattr(self, "block_sparse_moe"):
868
+ hidden_states = self.block_sparse_moe(hidden_states)
869
+ else:
870
+ hidden_states = self.mlp(hidden_states)
871
+ hidden_states = residual + hidden_states
872
+
873
+ return hidden_states
874
+
875
+
876
+ class KimiPreTrainedModel(PreTrainedModel):
877
+ config_class = KimiLinearConfig
878
+ base_model_prefix = "model"
879
+ supports_gradient_checkpointing = True
880
+ _no_split_modules = ["KimiDecoderLayer"]
881
+ _skip_keys_device_placement = "past_key_values"
882
+ _supports_flash_attn_2 = True
883
+ _can_record_outputs = {
884
+ "router_logits": OutputRecorder(KimiBlockSparseMLP, index=1),
885
+ "hidden_states": KimiDecoderLayer,
886
+ "attentions": KimiMLAAttention,
887
+ }
888
+ _is_stateful = True
889
+
890
+ def _init_weights(self, module):
891
+ std = self.config.initializer_range
892
+ if isinstance(module, nn.Linear):
893
+ module.weight.data.normal_(mean=0.0, std=std)
894
+ if module.bias is not None:
895
+ module.bias.data.zero_()
896
+ elif isinstance(module, nn.Embedding):
897
+ module.weight.data.normal_(mean=0.0, std=std)
898
+ if module.padding_idx is not None:
899
+ module.weight.data[module.padding_idx].zero_()
900
+
901
+
902
+ class KimiLinearModel(KimiPreTrainedModel):
903
+ def __init__(self, config: KimiLinearConfig):
904
+ super().__init__(config)
905
+ self.padding_idx = config.pad_token_id
906
+ self.vocab_size = config.vocab_size
907
+
908
+ self.embed_tokens = nn.Embedding(
909
+ config.vocab_size, config.hidden_size, self.padding_idx)
910
+ self.layers = nn.ModuleList([KimiDecoderLayer(
911
+ config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
912
+ self.norm = KimiRMSNorm(
913
+ config.hidden_size, eps=config.rms_norm_eps)
914
+
915
+ if getattr(config, "_attn_implementation", None) is not None:
916
+ if config._attn_implementation != "flash_attention_2":
917
+ logger.warning_once(
918
+ f"Ignoring the provided attention implementation {config._attn_implementation}")
919
+ logger.warning_once("Using flash_attention_2 backend instead.")
920
+ config._attn_implementation = "flash_attention_2"
921
+ else:
922
+ config._attn_implementation = "flash_attention_2"
923
+
924
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
925
+ self.gradient_checkpointing = False
926
+ # Initialize weights and apply final processing
927
+ self.post_init()
928
+
929
+ def _update_linear_attn_mask(self, attention_mask, cache_position):
930
+ """
931
+ NOTE: Left-padding is used for linear attention mask.
932
+ No need for zeroing states when
933
+ 1. Cached forward
934
+ 2. Attending to all inputs
935
+ """
936
+ linear_attn_mask = attention_mask
937
+ if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
938
+ linear_attn_mask = None
939
+ return linear_attn_mask
940
+
941
+ @check_model_inputs
942
+ @auto_docstring
943
+ def forward(
944
+ self,
945
+ input_ids: torch.LongTensor = None,
946
+ attention_mask: torch.Tensor | None = None,
947
+ position_ids: torch.LongTensor | None = None,
948
+ past_key_values: Cache | None = None,
949
+ inputs_embeds: torch.FloatTensor | None = None,
950
+ cache_position: torch.LongTensor | None = None,
951
+ use_cache: bool | None = None,
952
+ **kwargs: Unpack[TransformersKwargs],
953
+ ) -> tuple | BaseModelOutputWithPast:
954
+
955
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
956
+
957
+ if (input_ids is None) and (inputs_embeds is None):
958
+ raise ValueError(
959
+ "You must specify either input_ids or inputs_embeds")
960
+
961
+ # Get inputs_embeds
962
+ if inputs_embeds is None:
963
+ inputs_embeds = self.embed_tokens(input_ids)
964
+
965
+ if use_cache and past_key_values is None:
966
+ past_key_values = KimiDynamicCache(config=self.config)
967
+
968
+ if cache_position is None:
969
+ past_seen_tokens = past_key_values.get_seq_length(
970
+ ) if past_key_values is not None else 0
971
+ cache_position: torch.Tensor = torch.arange(
972
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device,
973
+ )
974
+
975
+ if position_ids is None:
976
+ position_ids = cache_position.unsqueeze(0)
977
+
978
+ causal_mask = create_causal_mask(
979
+ config=self.config,
980
+ input_embeds=inputs_embeds,
981
+ attention_mask=attention_mask,
982
+ cache_position=cache_position,
983
+ past_key_values=past_key_values,
984
+ position_ids=position_ids,
985
+ )
986
+ linear_attn_mask = self._update_linear_attn_mask(
987
+ attention_mask, cache_position)
988
+
989
+ hidden_states = inputs_embeds
990
+ if past_key_values is not None:
991
+ assert isinstance(past_key_values, KimiDynamicCache)
992
+
993
+ for decoder_layer in self.layers:
994
+ layer_mask = linear_attn_mask if decoder_layer.is_linear_attn else causal_mask
995
+
996
+ hidden_states = decoder_layer(
997
+ hidden_states,
998
+ attention_mask=layer_mask,
999
+ past_key_values=past_key_values,
1000
+ cache_position=cache_position,
1001
+ **kwargs,
1002
+ )
1003
+
1004
+ hidden_states = self.norm(hidden_states)
1005
+
1006
+ return BaseModelOutputWithPast(
1007
+ last_hidden_state=hidden_states,
1008
+ past_key_values=past_key_values,
1009
+ )
1010
+
1011
+
1012
+ class KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin):
1013
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
1014
+
1015
+ def __init__(self, config):
1016
+ super().__init__(config)
1017
+ self.model = KimiLinearModel(config)
1018
+ self.vocab_size = config.vocab_size
1019
+ self.lm_head = nn.Linear(
1020
+ config.hidden_size, config.vocab_size, bias=False)
1021
+
1022
+ # Initialize weights and apply final processing
1023
+ self.post_init()
1024
+
1025
+ @can_return_tuple
1026
+ @auto_docstring
1027
+ def forward(
1028
+ self,
1029
+ input_ids: torch.LongTensor = None,
1030
+ attention_mask: torch.Tensor | None = None,
1031
+ position_ids: torch.LongTensor | None = None,
1032
+ past_key_values: list[torch.FloatTensor] | None = None,
1033
+ inputs_embeds: torch.FloatTensor | None = None,
1034
+ labels: torch.LongTensor | None = None,
1035
+ use_cache: bool | None = None,
1036
+ output_attentions: bool | None = None,
1037
+ output_hidden_states: bool | None = None,
1038
+ generation_mode: bool | None = None,
1039
+ return_dict: bool | None = None,
1040
+ cache_position: torch.LongTensor | None = None,
1041
+ **kwargs: Unpack[TransformersKwargs],
1042
+ ) -> tuple | CausalLMOutputWithPast:
1043
+ r"""
1044
+ Args:
1045
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1046
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1047
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1048
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1049
+
1050
+ Returns:
1051
+
1052
+ Example:
1053
+
1054
+ ```python
1055
+ >>> from transformers import AutoTokenizer, KimiLinearForCausalLM
1056
+
1057
+ >>> model = KimiLinearForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1058
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1059
+
1060
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1061
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1062
+
1063
+ >>> # Generate
1064
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1065
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1066
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1067
+ ```"""
1068
+
1069
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1070
+ output_hidden_states = (
1071
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1072
+ )
1073
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1074
+
1075
+ outputs = self.model(
1076
+ input_ids=input_ids,
1077
+ attention_mask=attention_mask,
1078
+ position_ids=position_ids,
1079
+ past_key_values=past_key_values,
1080
+ inputs_embeds=inputs_embeds,
1081
+ use_cache=use_cache,
1082
+ output_attentions=output_attentions,
1083
+ output_hidden_states=output_hidden_states,
1084
+ return_dict=return_dict,
1085
+ cache_position=cache_position,
1086
+ )
1087
+
1088
+ logits = outputs[0]
1089
+ if generation_mode:
1090
+ logits = logits[:, -1:]
1091
+ logits = self.lm_head(logits)
1092
+
1093
+ loss = None
1094
+ if labels is not None:
1095
+ loss = self.loss_function(
1096
+ logits, labels, self.vocab_size, **kwargs)
1097
+
1098
+ return CausalLMOutputWithPast(
1099
+ loss=loss,
1100
+ logits=logits,
1101
+ past_key_values=outputs.past_key_values,
1102
+ hidden_states=outputs.hidden_states,
1103
+ attentions=outputs.attentions,
1104
+ )