aiqtech and zhiyuan8 committed on
Commit
e3b1a57
·
verified ·
0 Parent(s):

Duplicate from moonshotai/Kimi-Linear-48B-A3B-Instruct


Co-authored-by: LiZhiyuan <zhiyuan8@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ figures/arch.png filter=lfs diff=lfs merge=lfs -text
+ figures/perf_speed.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,113 @@
+ ---
+ license: mit
+ pipeline_tag: text-generation
+ library_name: transformers
+ ---
+
+ <div align="center">
+ <a href="https://huggingface.co/papers/2510.26692"><img width="80%" src="figures/banner.png"></a>
+ </div>
+
+ <div align="center">
+ <a href="https://huggingface.co/papers/2510.26692"><img src="figures/logo.png" height="16" width="16" style="display: inline-block; vertical-align: middle; margin: 2px;"><b style="display: inline-block;"> Tech Report</b></a> |
+ <a href="https://github.com/MoonshotAI/Kimi-Linear"><img src="figures/github.png" height="16" width="16" style="display: inline-block; vertical-align: middle; margin: 2px;"><b style="display: inline-block;"> Code</b></a> |
+ <a href="https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct"><img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" height="16" width="16" style="display: inline-block; vertical-align: middle; margin: 2px;"><b style="display: inline-block;"> HuggingFace</b></a>
+ </div>
+
+ <div align="center">
+ <img width="90%" src="figures/perf_speed.png">
+ <p><em><b>(a)</b> On MMLU-Pro (4k context length), Kimi Linear scores 51.0 at a speed comparable to full attention. On RULER (128k context length), it is Pareto-optimal, reaching 84.3 with a 3.98x speedup. <b>(b)</b> Kimi Linear achieves 6.3x faster TPOT than MLA, offering significant speedups at long sequence lengths (1M tokens).</em></p>
+ </div>
+
+ ## Overview
+
+ Kimi Linear is a hybrid linear attention architecture that outperforms traditional full attention methods across various contexts, including short, long, and reinforcement learning (RL) scaling regimes.
+ At its core is Kimi Delta Attention (KDA), a refined version of [Gated DeltaNet](https://arxiv.org/abs/2412.06464) that introduces a more efficient gating mechanism to optimize the use of finite-state RNN memory.
+
+ Kimi Linear achieves superior performance and hardware efficiency, especially for long-context tasks. It reduces the need for large KV caches by up to 75% (in this model, only the 7 global-attention layers out of 27 keep a growing KV cache) and boosts decoding throughput by up to $6\times$ for contexts as long as 1M tokens.
+
+ We open-source the KDA kernel in [FLA](https://github.com/fla-org/flash-linear-attention/tree/main/fla/ops/kda), and release two model checkpoints (Base and Instruct), both trained on 5.7T tokens.
+
+ | **Model** | **#Total Params** | **#Activated Params** | **Context Length** | **Download Link** |
+ | :------------------: | :---------------: | :-------------------: | :----------------: | :------------------------------------------------------------------------------: |
+ | Kimi-Linear-Base | 48B | 3B | 1M | [🤗 Hugging Face](https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Base) |
+ | Kimi-Linear-Instruct | 48B | 3B | 1M | [🤗 Hugging Face](https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct) |
+
+ ## Key Features
+
+ - **Kimi Delta Attention (KDA):** A linear attention mechanism that refines the gated delta rule with fine-grained gating.
+ - **Hybrid Architecture:** A 3:1 ratio of KDA to global MLA layers reduces memory usage while matching or surpassing full-attention quality (see the config sketch below).
+ - **Superior Performance:** Outperforms full attention on a wide range of tasks, including long-context and RL-style benchmarks, in fair comparisons at the 1.4T-token training scale.
+ - **High Throughput:** Achieves up to 6&times; faster decoding and significantly reduces time per output token (TPOT).
+
+ <div align="center">
+ <img width="60%" src="figures/arch.png">
+ </div>
+
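+ The per-layer split between KDA and global-attention (MLA) blocks is recorded in `config.json` under `linear_attn_config`. As a minimal sketch (it only reads fields shipped in this repo's config; the layer indices there are 1-based), you can inspect the hybrid layout like this:
+
+ ```py
+ from transformers import AutoConfig
+
+ config = AutoConfig.from_pretrained(
+     "moonshotai/Kimi-Linear-48B-A3B-Instruct", trust_remote_code=True
+ )
+
+ kda_layers = set(config.linear_attn_config["kda_layers"])
+ # Every layer that is not a KDA layer uses global MLA attention (3:1 ratio overall).
+ for layer in range(1, config.num_hidden_layers + 1):
+     kind = "KDA" if layer in kda_layers else "MLA (full attention)"
+     print(f"layer {layer:2d}: {kind}")
+ ```
+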
+ ## Usage
+
+ ### Inference with Hugging Face Transformers
+
+ To use the Kimi Linear model, we recommend the following environment:
+
+ * `python` >= 3.10
+ * `torch` >= 2.6
+ * `fla-core` >= 0.4.0
+
+ ```shell
+ pip install -U fla-core
+ ```
+
+ Example Code:
+ ```py
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_name = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype="auto",
+     device_map="auto",
+     trust_remote_code=True
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ messages = [
+     {"role": "system", "content": "You are a helpful assistant provided by Moonshot-AI."},
+     {"role": "user", "content": "Is 123 a prime?"}
+ ]
+ input_ids = tokenizer.apply_chat_template(
+     messages,
+     add_generation_prompt=True,
+     return_tensors="pt"
+ ).to(model.device)
+ generated_ids = model.generate(inputs=input_ids, max_new_tokens=500)
+ response = tokenizer.batch_decode(generated_ids)[0]
+ print(response)
+ ```
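+
+ If you prefer to stream tokens as they are generated, a minimal variant of the example above (using transformers' built-in `TextStreamer`; it reuses `model`, `tokenizer`, and `input_ids` from the snippet) is:
+
+ ```py
+ from transformers import TextStreamer
+
+ # Prints decoded tokens to stdout incrementally instead of waiting for generate() to finish.
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+ _ = model.generate(inputs=input_ids, max_new_tokens=500, streamer=streamer)
+ ```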
+
+ ### Deployment
+
+ For deployment, you can use the latest vLLM to create an OpenAI-compatible API endpoint.
+
+ ```sh
+ vllm serve moonshotai/Kimi-Linear-48B-A3B-Instruct \
+   --port 8000 \
+   --tensor-parallel-size 4 \
+   --max-model-len 1048576 \
+   --trust-remote-code
+ ```
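+
+ Once the server is running, any OpenAI-compatible client can talk to it. A small sketch using the `openai` Python package (the base URL and model name follow the `vllm serve` command above; vLLM ignores the API key unless one was configured):
+
+ ```py
+ from openai import OpenAI
+
+ client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+ response = client.chat.completions.create(
+     model="moonshotai/Kimi-Linear-48B-A3B-Instruct",
+     messages=[{"role": "user", "content": "Is 123 a prime?"}],
+     max_tokens=200,
+ )
+ print(response.choices[0].message.content)
+ ```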
+
+ ### Citation
+
+ If you find our work useful, please cite:
+ ```bibtex
+ @misc{team2025kimi,
+   title = {Kimi Linear: An Expressive, Efficient Attention Architecture},
+   author = {Zhang, Yu and Lin, Zongyu and Yao, Xingcheng and Hu, Jiaxi and Meng, Fanqing and Liu, Chengyin and Men, Xin and Yang, Songlin and Li, Zhiyuan and Li, Wentao and Lu, Enzhe and Liu, Weizhou and Chen, Yanru and Xu, Weixin and Yu, Longhui and Wang, Yejie and Fan, Yu and Zhong, Longguang and Yuan, Enming and Zhang, Dehao and Zhang, Yizhi and T. Liu, Y. and Wang, Haiming and Fang, Shengjun and He, Weiran and Liu, Shaowei and Li, Yiwei and Su, Jianlin and Qiu, Jiezhong and Pang, Bo and Yan, Junjie and Jiang, Zhejun and Huang, Weixiao and Yin, Bohong and You, Jiacheng and Wei, Chu and Wang, Zhengtao and Hong, Chao and Chen, Yutian and Chen, Guanduo and Wang, Yucheng and Zheng, Huabin and Wang, Feng and Liu, Yibo and Dong, Mengnan and Zhang, Zheng and Pan, Siyuan and Wu, Wenhao and Wu, Yuhao and Guan, Longyu and Tao, Jiawen and Fu, Guohong and Xu, Xinran and Wang, Yuzhi and Lai, Guokun and Wu, Yuxin and Zhou, Xinyu and Yang, Zhilin and Du, Yulun},
+   year = {2025},
+   eprint = {2510.26692},
+   archivePrefix = {arXiv},
+   primaryClass = {cs.CL}
+ }
+ ```
chat_template.jinja ADDED
@@ -0,0 +1,48 @@
1
+ {% macro render_content(msg) -%}
2
+ {%- set c = msg.get('content') -%}
3
+ {%- if c is string -%}
4
+ {{ c }}
5
+ {%- elif c is not none -%}
6
+ {% for content in c -%}
7
+ {% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
8
+ <|media_start|>image<|media_content|><|media_pad|><|media_end|>
9
+ {% else -%}
10
+ {{ content['text'] }}
11
+ {%- endif -%}
12
+ {%- endfor -%}
13
+ {%- endif -%}
14
+ {%- endmacro %}
15
+
16
+
17
+ {%- if tools -%}
18
+ <|im_system|>tool_declare<|im_middle|>{{ tools | tojson(separators=(',', ':')) }}<|im_end|>
19
+ {%- endif -%}
20
+ {% for message in messages %}
21
+ {%- set role_name = message.get('name') or message['role'] -%}
22
+ {%- if message['role'] == 'user' -%}
23
+ <|im_user|>{{role_name}}<|im_middle|>
24
+ {%- elif message['role'] == 'assistant' -%}
25
+ <|im_assistant|>{{role_name}}<|im_middle|>
26
+ {%- else -%}
27
+ <|im_system|>{{role_name}}<|im_middle|>
28
+ {%- endif -%}
29
+
30
+ {%- if message['role'] == 'assistant' and message.get('tool_calls') -%}
31
+ {{render_content(message)}}<|tool_calls_section_begin|>
32
+ {%- for tool_call in message['tool_calls'] -%}
33
+ {%- set formatted_id = tool_call['id'] -%}
34
+ <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>
35
+ {%- endfor -%}
36
+ <|tool_calls_section_end|>
37
+ {%- elif message['role'] == 'tool' -%}
38
+ {%- set tool_call_id = message.tool_call_id -%}
39
+ ## Return of {{ tool_call_id }}
40
+ {{render_content(message)}}
41
+ {%- elif message['content'] is not none -%}
42
+ {{render_content(message)}}
43
+ {%- endif -%}
44
+ <|im_end|>
45
+ {%- endfor -%}
46
+ {%- if add_generation_prompt -%}
47
+ <|im_assistant|>assistant<|im_middle|>
48
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,86 @@
1
+ {
2
+ "architectures": [
3
+ "KimiLinearForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_kimi.KimiLinearConfig",
7
+ "AutoModel": "modeling_kimi.KimiLinearModel",
8
+ "AutoModelForCausalLM": "modeling_kimi.KimiLinearForCausalLM"
9
+ },
10
+ "bos_token_id": 163584,
11
+ "dtype": "bfloat16",
12
+ "eos_token_id": 163586,
13
+ "first_k_dense_replace": 1,
14
+ "head_dim": 72,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 2304,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 9216,
19
+ "kv_lora_rank": 512,
20
+ "linear_attn_config": {
21
+ "full_attn_layers": [
22
+ 4,
23
+ 8,
24
+ 12,
25
+ 16,
26
+ 20,
27
+ 24,
28
+ 27
29
+ ],
30
+ "head_dim": 128,
31
+ "kda_layers": [
32
+ 1,
33
+ 2,
34
+ 3,
35
+ 5,
36
+ 6,
37
+ 7,
38
+ 9,
39
+ 10,
40
+ 11,
41
+ 13,
42
+ 14,
43
+ 15,
44
+ 17,
45
+ 18,
46
+ 19,
47
+ 21,
48
+ 22,
49
+ 23,
50
+ 25,
51
+ 26
52
+ ],
53
+ "num_heads": 32,
54
+ "short_conv_kernel_size": 4
55
+ },
56
+ "mla_use_nope": true,
57
+ "model_max_length": 1048576,
58
+ "model_type": "kimi_linear",
59
+ "moe_intermediate_size": 1024,
60
+ "moe_layer_freq": 1,
61
+ "moe_renormalize": true,
62
+ "moe_router_activation_func": "sigmoid",
63
+ "num_attention_heads": 32,
64
+ "num_expert_group": 1,
65
+ "num_experts": 256,
66
+ "num_experts_per_token": 8,
67
+ "num_hidden_layers": 27,
68
+ "num_key_value_heads": 32,
69
+ "num_nextn_predict_layers": 0,
70
+ "num_shared_experts": 1,
71
+ "pad_token_id": 163839,
72
+ "q_lora_rank": null,
73
+ "qk_nope_head_dim": 128,
74
+ "qk_rope_head_dim": 64,
75
+ "rms_norm_eps": 1e-05,
76
+ "rope_scaling": null,
77
+ "rope_theta": 10000.0,
78
+ "routed_scaling_factor": 2.446,
79
+ "tie_word_embeddings": false,
80
+ "topk_group": 1,
81
+ "transformers_version": "4.57.1",
82
+ "use_cache": true,
83
+ "use_grouped_topk": true,
84
+ "v_head_dim": 128,
85
+ "vocab_size": 163840
86
+ }
configuration_kimi.py ADDED
@@ -0,0 +1,140 @@
1
+ # coding=utf-8
2
+ from typing import Optional
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+
6
+
7
+ class KimiLinearConfig(PretrainedConfig):
8
+ model_type = "kimi_linear"
9
+ keys_to_ignore_at_inference = ["past_key_values"]
10
+
11
+ def __init__(
12
+ self,
13
+ model_type="kimi_linear",
14
+ vocab_size=163840,
15
+ hidden_size=4096,
16
+ head_dim=None,
17
+ intermediate_size=11008,
18
+ num_hidden_layers=32,
19
+ num_attention_heads=32,
20
+ num_key_value_heads=None,
21
+ hidden_act="silu",
22
+ initializer_range=0.02,
23
+ rms_norm_eps=1e-6,
24
+ use_cache=True,
25
+ pad_token_id=0,
26
+ bos_token_id=1,
27
+ eos_token_id=2,
28
+ rope_theta=10000.0,
29
+ rope_scaling=None,
30
+ tie_word_embeddings=False,
31
+ moe_intermediate_size: Optional[int] = None,
32
+ moe_renormalize: bool = True,
33
+ moe_router_activation_func: str = "sigmoid",
34
+ num_experts: Optional[int] = None,
35
+ num_experts_per_token: Optional[int] = None,
36
+ num_shared_experts: int = 0,
37
+ routed_scaling_factor: float = 1.0,
38
+ first_k_dense_replace: int = 0,
39
+ moe_layer_freq: int = 1,
40
+ use_grouped_topk: bool = True,
41
+ num_expert_group: int = 1,
42
+ topk_group: int = 1,
43
+ q_lora_rank: Optional[int] = None,
44
+ kv_lora_rank: Optional[int] = None,
45
+ qk_nope_head_dim: Optional[int] = None,
46
+ qk_rope_head_dim: Optional[int] = None,
47
+ v_head_dim: Optional[int] = None,
48
+ mla_use_nope: Optional[bool] = False,
49
+ num_nextn_predict_layers: int = 0,
50
+ linear_attn_config: Optional[dict] = None,
51
+ **kwargs,
52
+ ):
53
+ self.model_type = model_type
54
+ self.vocab_size = vocab_size
55
+ self.hidden_size = hidden_size
56
+ self.head_dim = (
57
+ head_dim if head_dim is not None else hidden_size // num_attention_heads
58
+ )
59
+ self.intermediate_size = intermediate_size
60
+ self.num_hidden_layers = num_hidden_layers
61
+ self.num_attention_heads = num_attention_heads
62
+
63
+ # for backward compatibility
64
+ if num_key_value_heads is None:
65
+ num_key_value_heads = num_attention_heads
66
+
67
+ self.num_key_value_heads = num_key_value_heads
68
+ self.hidden_act = hidden_act
69
+ self.initializer_range = initializer_range
70
+ self.rms_norm_eps = rms_norm_eps
71
+ self.use_cache = use_cache
72
+ self.rope_theta = rope_theta
73
+ self.rope_scaling = rope_scaling
74
+
75
+ self.q_lora_rank = q_lora_rank
76
+ self.kv_lora_rank = kv_lora_rank
77
+ self.qk_nope_head_dim = qk_nope_head_dim
78
+ self.qk_rope_head_dim = qk_rope_head_dim
79
+ self.v_head_dim = v_head_dim
80
+ self.mla_use_nope = mla_use_nope
81
+ # moe config
82
+ self.num_experts = num_experts
83
+ self.num_experts_per_token = num_experts_per_token
84
+ self.moe_renormalize = moe_renormalize
85
+ self.num_shared_experts = num_shared_experts
86
+ self.routed_scaling_factor = routed_scaling_factor
87
+ self.moe_router_activation_func = moe_router_activation_func
88
+ assert self.moe_router_activation_func in ("softmax", "sigmoid")
89
+ self.moe_intermediate_size = moe_intermediate_size
90
+ self.first_k_dense_replace = first_k_dense_replace
91
+ self.moe_layer_freq = moe_layer_freq
92
+ self.use_grouped_topk = use_grouped_topk
93
+ self.num_expert_group = num_expert_group
94
+ self.topk_group = topk_group
95
+ self.num_nextn_predict_layers = num_nextn_predict_layers
96
+
97
+ if linear_attn_config is not None:
98
+ assert linear_attn_config["kda_layers"] is not None
99
+ assert linear_attn_config["full_attn_layers"] is not None
100
+ self.linear_attn_config = linear_attn_config
101
+
102
+ super().__init__(
103
+ pad_token_id=pad_token_id,
104
+ bos_token_id=bos_token_id,
105
+ eos_token_id=eos_token_id,
106
+ tie_word_embeddings=tie_word_embeddings,
107
+ **kwargs,
108
+ )
109
+
110
+ @property
111
+ def is_mla(self):
112
+ return (
113
+ self.q_lora_rank is not None
114
+ or self.kv_lora_rank is not None
115
+ or self.qk_nope_head_dim is not None
116
+ or self.qk_rope_head_dim is not None
117
+ or self.v_head_dim is not None
118
+ or self.mla_use_nope is True
119
+ )
120
+
121
+ @property
122
+ def is_moe(self):
123
+ return self.num_experts is not None
124
+
125
+ @property
126
+ def is_linear_attn(self) -> bool:
127
+ return not (
128
+ self.linear_attn_config is None
129
+ or (
130
+ isinstance(self.linear_attn_config, dict)
131
+ and self.linear_attn_config["kda_layers"] is not None
132
+ and len(self.linear_attn_config["kda_layers"]) == 0
133
+ )
134
+ )
135
+
136
+ def is_kda_layer(self, layer_idx: int):
137
+ return (
138
+ self.linear_attn_config is not None
139
+ and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
140
+ )
figures/arch.png ADDED

Git LFS Details

  • SHA256: 132ae021fa4661ed39e7be784d46f05f22b82aabb9afd2bab8dbdc0a5a61cba0
  • Pointer size: 131 Bytes
  • Size of remote file: 238 kB
figures/banner.png ADDED
figures/github.png ADDED
figures/logo.png ADDED
figures/perf_speed.png ADDED

Git LFS Details

  • SHA256: f8951e618db41ae57fa0cec4845d7b275dffbd7f9db12c6496bfea536c625aea
  • Pointer size: 131 Bytes
  • Size of remote file: 160 kB
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 163584,
4
+ "eos_token_id": 163586,
5
+ "pad_token_id": 163839,
6
+ "transformers_version": "4.57.1"
7
+ }
model-00001-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c5c908aa3b86b6486080b577cb7aa8dbe9ca7cb18789653768017e602b61a7f
3
+ size 4999482712
model-00002-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6fcb34e9ebe2434f32761c06ef17a465157308e6e583eb7eb70cc25e57cd2cb0
3
+ size 4999923264
model-00003-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f35d2a95dd1e3170fd642d0db4d0d07933985ef59041494652092cc27893e231
3
+ size 4997138040
model-00004-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eda18b6226777bb9a07584dfa64986ac4f28a26cee3203f16ffb14deef9ef48b
3
+ size 4997148016
model-00005-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:652e8a43d493105176807d256af0a5c56e45c6d783e6c8221832918f3425c0a0
3
+ size 4999923296
model-00006-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77f5dc551436c934f0991eee0c319e0f33689ea3c10cb8cb8f48acc32238526f
3
+ size 4997138040
model-00007-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3980c7efccdc6a27633eb8909afc381cb780cccaa7be0347d1645496ea3eb5a2
3
+ size 4997148128
model-00008-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dab91b8eaed9874c75de99a4a08669f520fa3e2c8977175333db552504a1c5d3
3
+ size 4999924384
model-00009-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13f6ae84d557682ec4a0fc8b6090d4f89cdd26e5e216445cc9d77a65c7f4c90b
3
+ size 4997139104
model-00010-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d58ab0a201e26ff429b9d18678a76f3a3284ad977719e055f94d892133ee247b
3
+ size 4997149016
model-00011-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:142aa90317af104b2d9f5a6ae4dc661f4a7f7c152f83d6c2477de8037be92201
3
+ size 4999924408
model-00012-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb1d4fe2d94a04898eb8300b5144923e5540a75091ee5f4c8b67936a69d91780
3
+ size 4997139104
model-00013-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12fb6f2dea889d460f33f7fbb55f76d7beb468698cf56707f4d77a9ab69461d3
3
+ size 4997148992
model-00014-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf178169e0dfdc721492f1a98a5be2ef5f66fd8569039f5a77819641a5a1b32d
3
+ size 4999924440
model-00015-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e65460934e8794faeadd6d8cbeffd23fcdbf07d9c61ca92ef97afc95d0ccdaa
3
+ size 4997139104
model-00016-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05cc77846f94d50dc09180f5844f01aee38e489b1fd833d8c7aec6a62214ef03
3
+ size 4997148960
model-00017-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eaf238a2f1a971ef311c445309de323992da59287d55759cf2a4a3a85ca6a1cc
3
+ size 4999924472
model-00018-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:315cdb6964a975522cdb755cf5eb76b46478346b015113e241f81127ad9e6fd4
3
+ size 4997139104
model-00019-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbdc2c77e41baa76a2c2b3ced0a59fe7587e95ca3d1acc75247b88a80dee3041
3
+ size 4999934384
model-00020-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f1e4a9194d045e01c90ed2697939bcedd533b6aa1f1b97b0ae0a5932e5a4bc7
3
+ size 3280687152
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_kimi.py ADDED
@@ -0,0 +1,1028 @@
1
+ import math
2
+ from collections.abc import Callable
3
+ from typing import Any, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import transformers
8
+ from einops import rearrange
9
+ from packaging import version
10
+ from torch import nn
11
+ from transformers.activations import ACT2FN
12
+ from transformers.cache_utils import Cache
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.masking_utils import create_causal_mask
15
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
17
+ CausalLMOutputWithPast)
18
+ from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
19
+ PreTrainedModel)
20
+ from transformers.processing_utils import Unpack
21
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
22
+ from transformers.utils import (TransformersKwargs, auto_docstring,
23
+ can_return_tuple, logging)
24
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
25
+
26
+ try:
27
+ from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
28
+ from fla.modules import FusedRMSNormGated, ShortConvolution
29
+ from fla.ops.kda import chunk_kda, fused_recurrent_kda
30
+ from fla.ops.kda.gate import fused_kda_gate
31
+ except ImportError:
32
+ raise ImportError("Please run `pip install -U fla-core`")
33
+
34
+ from .configuration_kimi import KimiLinearConfig
35
+
36
+ assert version.parse(transformers.__version__) >= version.parse("4.56.0"), \
37
+ "Please upgrade transformers to >= 4.56.0"
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ class KimiDynamicCache:
43
+ """
44
+ Dynamic cache for Kimi model.
45
+ Inspired by Qwen3-Next
46
+ """
47
+ is_compileable = False
48
+
49
+ def __init__(self, config: KimiLinearConfig):
50
+ super().__init__()
51
+ self.config = config
52
+
53
+ if config.linear_attn_config is not None:
54
+ self.layer_types = []
55
+ for i in range(config.num_hidden_layers):
56
+ if config.is_kda_layer(i):
57
+ self.layer_types.append("linear_attention")
58
+ else:
59
+ self.layer_types.append("full_attention")
60
+ else:
61
+ self.layer_types = ["full_attention"] * config.num_hidden_layers
62
+
63
+ self.transformer_layers = [
64
+ i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention"
65
+ ]
66
+
67
+ linear_layers = [i for i in range(
68
+ config.num_hidden_layers) if self.layer_types[i] == "linear_attention"]
69
+ self.last_linear_layer = linear_layers[-1] if linear_layers else -1
70
+
71
+ self.conv_states = [None for _ in range(config.num_hidden_layers)]
72
+ self.recurrent_states = [None for _ in range(config.num_hidden_layers)]
73
+ self.key_cache = [None for _ in range(config.num_hidden_layers)]
74
+ self.value_cache = [None for _ in range(config.num_hidden_layers)]
75
+
76
+ def __len__(self):
77
+ return len(self.layer_types)
78
+
79
+ def update(
80
+ self,
81
+ key_states: torch.Tensor,
82
+ value_states: torch.Tensor,
83
+ layer_idx: int,
84
+ cache_kwargs: Optional[dict[str, Any]] = None,
85
+ ) -> tuple[torch.Tensor, torch.Tensor]:
86
+ if self.key_cache[layer_idx] is None:
87
+ self.key_cache[layer_idx] = key_states
88
+ self.value_cache[layer_idx] = value_states
89
+ else:
90
+ self.key_cache[layer_idx] = torch.cat(
91
+ [self.key_cache[layer_idx], key_states], dim=2)
92
+ self.value_cache[layer_idx] = torch.cat(
93
+ [self.value_cache[layer_idx], value_states], dim=2)
94
+
95
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
96
+
97
+ def reorder_cache(self, beam_idx: torch.LongTensor):
98
+ """Reorders the cache for beam search, given the selected beam indices."""
99
+ for layer_idx in range(len(self.key_cache)):
100
+ if self.key_cache[layer_idx] is not None:
101
+ device = self.key_cache[layer_idx].device
102
+ beam_idx = beam_idx.to(device)
103
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
104
+ 0, beam_idx)
105
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
106
+ 0, beam_idx)
107
+
108
+ if self.conv_states[layer_idx] is not None:
109
+ device = self.conv_states[layer_idx][0].device
110
+ beam_idx = beam_idx.to(device)
111
+ q_conv, k_conv, v_conv = self.conv_states[layer_idx]
112
+ self.conv_states[layer_idx] = (
113
+ q_conv.index_select(0, beam_idx),
114
+ k_conv.index_select(0, beam_idx),
115
+ v_conv.index_select(0, beam_idx)
116
+ )
117
+ self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(
118
+ 0, beam_idx)
119
+
120
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
121
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
122
+ # take any layer that contains cache and not empty tensor
123
+ layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
124
+ if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
125
+ return 0
126
+ return self.key_cache[layer_idx].shape[-2]
127
+
128
+ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
129
+ """
130
+ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
131
+ the given layer at `layer_idx`.
132
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
133
+ """
134
+ kv_offset = 0
135
+ query_length = cache_position.shape[0]
136
+ past_seen_tokens = self.get_seq_length(layer_idx)
137
+ kv_length = query_length + past_seen_tokens
138
+ return kv_length, kv_offset
139
+
140
+ @property
141
+ def has_previous_state(self):
142
+ """We have a previous state if the last linear (conv) layer was already updated."""
143
+ if self.last_linear_layer == -1:
144
+ return False
145
+ return self.conv_states[self.last_linear_layer] is not None
146
+
147
+
148
+ class KimiRMSNorm(nn.Module):
149
+ def __init__(self, hidden_size, eps=1e-6):
150
+ """
151
+ KimiRMSNorm is equivalent to T5LayerNorm
152
+ """
153
+ super().__init__()
154
+ self.weight = nn.Parameter(torch.ones(hidden_size))
155
+ self.variance_epsilon = eps
156
+
157
+ def forward(self, hidden_states):
158
+ input_dtype = hidden_states.dtype
159
+ hidden_states = hidden_states.to(torch.float32)
160
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
161
+ hidden_states = hidden_states * \
162
+ torch.rsqrt(variance + self.variance_epsilon)
163
+ return self.weight * hidden_states.to(input_dtype)
164
+
165
+
166
+ ALL_LAYERNORM_LAYERS.append(KimiRMSNorm)
167
+
168
+
169
+ class KimiBlockSparseMLP(nn.Module):
170
+ def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
171
+ super().__init__()
172
+ self.config = config
173
+ self.ffn_dim = config.intermediate_size if intermediate_size is None else intermediate_size
174
+ self.hidden_dim = config.hidden_size if hidden_size is None else hidden_size
175
+
176
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # gate
177
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) # down
178
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # up
179
+
180
+ self.act_fn = ACT2FN[config.hidden_act]
181
+
182
+ def forward(self, hidden_states):
183
+ current_hidden_states = self.act_fn(
184
+ self.w1(hidden_states)) * self.w3(hidden_states)
185
+ current_hidden_states = self.w2(current_hidden_states)
186
+ return current_hidden_states
187
+
188
+
189
+ class KimiMLP(nn.Module):
190
+ def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
191
+ super().__init__()
192
+ self.config = config
193
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
194
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
195
+ self.gate_proj = nn.Linear(
196
+ self.hidden_size, self.intermediate_size, bias=False)
197
+ self.up_proj = nn.Linear(
198
+ self.hidden_size, self.intermediate_size, bias=False)
199
+ self.down_proj = nn.Linear(
200
+ self.intermediate_size, self.hidden_size, bias=False)
201
+ self.act_fn = ACT2FN[config.hidden_act]
202
+
203
+ def forward(self, x):
204
+ down_proj = self.down_proj(self.act_fn(
205
+ self.gate_proj(x)) * self.up_proj(x))
206
+ return down_proj
207
+
208
+
209
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
210
+ """
211
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
212
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
213
+ """
214
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
215
+ if n_rep == 1:
216
+ return hidden_states
217
+ hidden_states = hidden_states[:, :, None, :, :].expand(
218
+ batch, num_key_value_heads, n_rep, slen, head_dim)
219
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
220
+
221
+
222
+ def eager_attention_forward(
223
+ module: nn.Module,
224
+ query: torch.Tensor,
225
+ key: torch.Tensor,
226
+ value: torch.Tensor,
227
+ attention_mask: Optional[torch.Tensor],
228
+ scaling: float,
229
+ dropout: float = 0.0,
230
+ **kwargs: Unpack[TransformersKwargs],
231
+ ):
232
+ key_states = repeat_kv(key, module.num_key_value_groups)
233
+ value_states = repeat_kv(value, module.num_key_value_groups)
234
+
235
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
236
+ if attention_mask is not None:
237
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
238
+ attn_weights = attn_weights + causal_mask
239
+
240
+ attn_weights = nn.functional.softmax(
241
+ attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
242
+ attn_weights = nn.functional.dropout(
243
+ attn_weights, p=dropout, training=module.training)
244
+ attn_output = torch.matmul(attn_weights, value_states)
245
+ attn_output = attn_output.transpose(1, 2).contiguous()
246
+
247
+ return attn_output, attn_weights
248
+
249
+
250
+ class KimiMLAAttention(nn.Module):
251
+ """
252
+ Multi-Latent Attention adapted from deepseek-v3
253
+ """
254
+
255
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
256
+ nn.Module.__init__(self)
257
+ self.config = config
258
+ self.layer_idx = layer_idx
259
+ self.hidden_size = config.hidden_size
260
+ self.num_heads = config.num_attention_heads
261
+ self.num_key_value_heads = config.num_key_value_heads
262
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
263
+
264
+ self.rope_theta = config.rope_theta
265
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)
266
+
267
+ try:
268
+ self.q_lora_rank = config.q_lora_rank
269
+ self.qk_rope_head_dim = config.qk_rope_head_dim
270
+ self.kv_lora_rank = config.kv_lora_rank
271
+ self.v_head_dim = config.v_head_dim
272
+ self.qk_nope_head_dim = config.qk_nope_head_dim
273
+ self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
274
+ self.use_nope = config.mla_use_nope
275
+ self.scaling = self.q_head_dim ** (-0.5)
276
+ except Exception as e:
277
+ raise ValueError(
278
+ f"Kimi MLA config is not found or not properly formatted: {e}")
279
+
280
+ assert self.q_lora_rank is None
281
+ self.q_proj = nn.Linear(
282
+ self.hidden_size, self.num_heads * self.q_head_dim, bias=False,
283
+ )
284
+ self.kv_a_proj_with_mqa = nn.Linear(
285
+ self.hidden_size,
286
+ self.kv_lora_rank + self.qk_rope_head_dim,
287
+ bias=False,
288
+ )
289
+ self.kv_a_layernorm = KimiRMSNorm(self.kv_lora_rank)
290
+ self.kv_b_proj = nn.Linear(
291
+ self.kv_lora_rank,
292
+ self.num_heads
293
+ * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
294
+ bias=False,
295
+ )
296
+ self.o_proj = nn.Linear(
297
+ self.num_heads * self.v_head_dim,
298
+ self.hidden_size,
299
+ bias=False,
300
+ )
301
+ self.is_causal = True
302
+ assert self.use_nope
303
+
304
+ def forward(
305
+ self,
306
+ hidden_states: torch.Tensor,
307
+ attention_mask: Optional[torch.Tensor] = None,
308
+ past_key_values: Optional[Cache] = None,
309
+ **kwargs,
310
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
311
+ batch_size, seq_length = hidden_states.shape[:-1]
312
+ query_shape = (batch_size, seq_length, -1, self.q_head_dim)
313
+ key_shape = (batch_size, seq_length, -1,
314
+ self.qk_nope_head_dim + self.v_head_dim)
315
+
316
+ q_states = self.q_proj(hidden_states)
317
+ q_states = q_states.view(query_shape).transpose(1, 2)
318
+ q_pass, q_rot = torch.split(
319
+ q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
320
+
321
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
322
+ k_pass, k_rot = torch.split(
323
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
324
+
325
+ k_pass = self.kv_b_proj(self.kv_a_layernorm(
326
+ k_pass)).view(key_shape).transpose(1, 2)
327
+ k_pass, value_states = torch.split(
328
+ k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
329
+
330
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
331
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
332
+
333
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
334
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
335
+
336
+ if past_key_values is not None:
337
+ key_states, value_states = past_key_values.update(
338
+ key_states, value_states, self.layer_idx)
339
+
340
+ if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
341
+ value_states = F.pad(
342
+ value_states, [0, self.q_head_dim - self.v_head_dim])
343
+
344
+ attention_interface: Callable = eager_attention_forward
345
+ if self.config._attn_implementation != "eager":
346
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
347
+
348
+ attn_output, _ = attention_interface(
349
+ self,
350
+ query_states,
351
+ key_states,
352
+ value_states,
353
+ attention_mask,
354
+ dropout=0.0 if not self.training else self.attention_dropout,
355
+ scaling=self.scaling,
356
+ **kwargs,
357
+ )
358
+
359
+ if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
360
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
361
+
362
+ attn_output = attn_output.reshape(
363
+ batch_size, seq_length, -1).contiguous()
364
+ attn_output = self.o_proj(attn_output)
365
+ return attn_output
366
+
367
+
368
+ class KimiDeltaAttention(nn.Module):
369
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
370
+ super().__init__()
371
+ self.config = config
372
+ self.mode = "chunk"
373
+
374
+ self.hidden_size = config.hidden_size
375
+ self.conv_size = config.linear_attn_config["short_conv_kernel_size"]
376
+ self.head_dim = config.linear_attn_config["head_dim"]
377
+ self.num_heads = config.linear_attn_config["num_heads"]
378
+ self.head_k_dim = self.head_dim
379
+ self.num_k_heads = self.num_heads
380
+
381
+ self.layer_idx = layer_idx
382
+
383
+ assert self.mode in [
384
+ 'chunk', 'fused_recurrent'], f"Unsupported mode `{self.mode}`."
385
+
386
+ projection_k_size = self.head_k_dim * self.num_k_heads
387
+ projection_size = self.head_dim * self.num_heads
388
+
389
+ self.q_proj = nn.Linear(
390
+ self.hidden_size, projection_k_size, bias=False)
391
+ self.k_proj = nn.Linear(
392
+ self.hidden_size, projection_k_size, bias=False)
393
+ self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False)
394
+
395
+ self.q_conv1d = ShortConvolution(
396
+ hidden_size=projection_k_size,
397
+ kernel_size=self.conv_size,
398
+ activation='silu',
399
+ )
400
+ self.k_conv1d = ShortConvolution(
401
+ hidden_size=projection_k_size,
402
+ kernel_size=self.conv_size,
403
+ activation='silu'
404
+ )
405
+ self.v_conv1d = ShortConvolution(
406
+ hidden_size=projection_size,
407
+ kernel_size=self.conv_size,
408
+ activation='silu'
409
+ )
410
+
411
+ self.A_log = torch.nn.Parameter(torch.log(torch.empty(
412
+ self.num_heads, dtype=torch.float32).uniform_(1, 16)).view(1, 1, -1, 1))
413
+
414
+ self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
415
+ self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
416
+
417
+ self.dt_bias = nn.Parameter(
418
+ torch.empty(projection_size, dtype=torch.float32))
419
+
420
+ self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
421
+
422
+ self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
423
+ self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
424
+
425
+ self.o_norm = FusedRMSNormGated(
426
+ self.head_dim, eps=config.rms_norm_eps, activation='sigmoid')
427
+ self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False)
428
+
429
+ def forward(
430
+ self,
431
+ hidden_states: torch.Tensor,
432
+ attention_mask: Optional[torch.Tensor] = None,
433
+ cache_params: Optional[KimiDynamicCache] = None,
434
+ **kwargs: Unpack[dict]
435
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
436
+ if attention_mask is not None:
437
+ if attention_mask.dim() != 2:
438
+ attention_mask = kwargs.get("padding_mask", None)
439
+
440
+ if attention_mask is not None and attention_mask.dim() != 2:
441
+ raise ValueError(
442
+ "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] "
443
+ "(0 = padding). 3D masks are not supported here."
444
+ )
445
+ use_cache = cache_params is not None
446
+ batch_size, q_len, _ = hidden_states.shape
447
+ mode = 'fused_recurrent' if q_len <= 64 else self.mode
448
+ if self.training:
449
+ assert mode == 'chunk', "Only chunk mode is supported in training."
450
+
451
+ cu_seqlens = kwargs.get('cu_seqlens', None)
452
+ indices = None
453
+ if attention_mask is not None:
454
+ indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
455
+ hidden_states = index_first_axis(
456
+ rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
457
+
458
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
459
+ recurrent_state = None
460
+ if cache_params is not None:
461
+ if cache_params.conv_states[self.layer_idx] is not None:
462
+ conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[
463
+ self.layer_idx]
464
+ recurrent_state = cache_params.recurrent_states[self.layer_idx]
465
+ q, conv_state_q = self.q_conv1d(
466
+ x=self.q_proj(hidden_states),
467
+ cache=conv_state_q,
468
+ output_final_state=use_cache,
469
+ cu_seqlens=cu_seqlens
470
+ )
471
+ k, conv_state_k = self.k_conv1d(
472
+ x=self.k_proj(hidden_states),
473
+ cache=conv_state_k,
474
+ output_final_state=use_cache,
475
+ cu_seqlens=cu_seqlens
476
+ )
477
+ v, conv_state_v = self.v_conv1d(
478
+ x=self.v_proj(hidden_states),
479
+ cache=conv_state_v,
480
+ output_final_state=use_cache,
481
+ cu_seqlens=cu_seqlens
482
+ )
483
+ g = self.f_b_proj(self.f_a_proj(hidden_states))
484
+ g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
485
+ beta = self.b_proj(hidden_states).float().sigmoid()
486
+
487
+ q, k = map(lambda x: rearrange(
488
+ x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
489
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
490
+
491
+ if mode == 'chunk':
492
+ o, recurrent_state = chunk_kda(
493
+ q=q,
494
+ k=k,
495
+ v=v,
496
+ g=g,
497
+ beta=beta,
498
+ initial_state=recurrent_state,
499
+ output_final_state=True,
500
+ use_qk_l2norm_in_kernel=True,
501
+ cu_seqlens=cu_seqlens,
502
+ )
503
+ else:
504
+ o, recurrent_state = fused_recurrent_kda(
505
+ q=q,
506
+ k=k,
507
+ v=v,
508
+ g=g,
509
+ beta=beta,
510
+ initial_state=recurrent_state,
511
+ output_final_state=True,
512
+ use_qk_l2norm_in_kernel=True,
513
+ cu_seqlens=cu_seqlens,
514
+ )
515
+ if cache_params is not None:
516
+ cache_params.recurrent_states[self.layer_idx] = recurrent_state
517
+ cache_params.conv_states[self.layer_idx] = (
518
+ conv_state_q, conv_state_k, conv_state_v)
519
+
520
+ g = self.g_b_proj(self.g_a_proj(hidden_states))
521
+ g = rearrange(g, '... (h d) -> ... h d', d=self.head_dim)
522
+ o = self.o_norm(o, g)
523
+
524
+ o = rearrange(o, 'b t h d -> b t (h d)')
525
+ o = self.o_proj(o)
526
+ if attention_mask is not None:
527
+ o = pad_input(o.squeeze(0), indices, batch_size, q_len)
528
+
529
+ return o
530
+
531
+
532
+ class KimiMoEGate(nn.Module):
533
+ """
534
+ MoEGate adapted from Deepseek-V3.
535
+ Parameter correspondences:
536
+ num_experts -> n_routed_experts
537
+ num_experts_per_token -> num_experts_per_tok
538
+ num_expert_group -> n_group
539
+ moe_router_activation_func -> scoring_func
540
+ """
541
+
542
+ def __init__(self, config: KimiLinearConfig):
543
+ super().__init__()
544
+ self.config = config
545
+ self.top_k = config.num_experts_per_token
546
+ self.num_experts = config.num_experts
547
+ self.routed_scaling_factor = config.routed_scaling_factor
548
+ self.moe_router_activation_func = config.moe_router_activation_func
549
+ self.num_expert_group = getattr(config, "num_expert_group", 1)
550
+ self.topk_group = getattr(config, "topk_group", 1)
551
+
552
+ # topk selection algorithm
553
+ self.moe_renormalize = config.moe_renormalize
554
+ self.gating_dim = config.hidden_size
555
+ self.weight = nn.Parameter(
556
+ torch.empty((self.num_experts, self.gating_dim))
557
+ )
558
+
559
+ self.e_score_correction_bias = nn.Parameter(
560
+ torch.empty((self.num_experts))
561
+ )
562
+ self.reset_parameters()
563
+
564
+ def reset_parameters(self) -> None:
565
+ import torch.nn.init as init
566
+
567
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
568
+
569
+ def forward(self, hidden_states):
570
+ bsz, seq_len, h = hidden_states.shape
571
+ # compute gating score
572
+ hidden_states = hidden_states.view(-1, h)
573
+ logits = F.linear(
574
+ hidden_states.type(torch.float32), self.weight.type(
575
+ torch.float32), None
576
+ )
577
+ if self.moe_router_activation_func == "sigmoid":
578
+ scores = logits.sigmoid()
579
+ elif self.moe_router_activation_func == "softmax":
580
+ scores = logits.softmax(dim=1)
581
+ else:
582
+ raise NotImplementedError(
583
+ f"Unsupported scoring function for MoE gating: {self.moe_router_activation_func}"
584
+ )
585
+
586
+ # select top-k experts
587
+ assert not self.training
588
+ scores_for_choice = scores.view(bsz * seq_len, -1)
589
+ scores_for_choice += self.e_score_correction_bias.unsqueeze(0)
590
+ group_scores = (
591
+ scores_for_choice.view(
592
+ bsz * seq_len, self.num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
593
+ ) # [n, num_expert_group]
594
+ group_idx = torch.topk(
595
+ group_scores, k=self.topk_group, dim=-1, sorted=False
596
+ )[
597
+ 1
598
+ ] # [n, top_k_group]
599
+ group_mask = torch.zeros_like(group_scores) # [n, num_expert_group]
600
+ group_mask.scatter_(1, group_idx, 1) # [n, num_expert_group]
601
+ score_mask = (
602
+ group_mask.unsqueeze(-1)
603
+ .expand(
604
+ bsz * seq_len, self.num_expert_group, self.num_experts // self.num_expert_group
605
+ )
606
+ .reshape(bsz * seq_len, -1)
607
+ ) # [n, e]
608
+ tmp_scores = scores_for_choice.masked_fill(
609
+ ~score_mask.bool(), 0.0) # [n, e]
610
+ _, topk_idx = torch.topk(
611
+ tmp_scores, k=self.top_k, dim=-1, sorted=False
612
+ )
613
+ topk_weight = scores.gather(1, topk_idx)
614
+
615
+ # norm gate to sum 1
616
+ if self.top_k > 1 and self.moe_renormalize:
617
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
618
+ topk_weight = topk_weight / denominator
619
+ # must multiply the scaling factor
620
+ topk_weight = topk_weight * self.routed_scaling_factor
621
+
622
+ return topk_idx, topk_weight
623
+
624
+
625
+ class KimiSparseMoeBlock(nn.Module):
626
+ """
627
+ Adapted from Deepseek-V3's MOE implementation
628
+ The namings are consistent with Kimi's version.
629
+ """
630
+
631
+ def __init__(self, config: KimiLinearConfig):
632
+ super().__init__()
633
+ self.config = config
634
+ self.hidden_dim = config.hidden_size
635
+ self.num_experts = config.num_experts
636
+ self.top_k = config.num_experts_per_token
637
+ self.moe_renormalize = config.moe_renormalize
638
+
639
+ self.ep_size = 1
640
+ self.experts_per_rank = config.num_experts
641
+ self.ep_rank = 0
642
+ self.experts = nn.ModuleList(
643
+ [
644
+ KimiBlockSparseMLP(
645
+ config, intermediate_size=config.moe_intermediate_size
646
+ )
647
+ for _ in range(config.num_experts)
648
+ ]
649
+ )
650
+ self.gate = KimiMoEGate(config)
651
+ if config.num_shared_experts is not None:
652
+ intermediate_size = config.moe_intermediate_size * config.num_shared_experts
653
+ self.shared_experts = KimiMLP(
654
+ config=config, intermediate_size=intermediate_size
655
+ )
656
+
657
+ def forward(self, hidden_states):
658
+ identity = hidden_states
659
+ orig_shape = hidden_states.shape
660
+ topk_idx, topk_weight = self.gate(hidden_states)
661
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
662
+ flat_topk_idx = topk_idx.view(-1)
663
+ if not self.training:
664
+ y = self.moe_infer(hidden_states, topk_idx,
665
+ topk_weight).view(*orig_shape)
666
+ else:
667
+ raise NotImplementedError(
668
+ "Training mode is not supported in KimiSparseMoeBlock")
669
+ if self.config.num_shared_experts is not None:
670
+ y = y + self.shared_experts(identity)
671
+ return y
672
+
673
+ @torch.no_grad()
674
+ def moe_infer(self, x, topk_ids, topk_weight):
675
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
676
+ cnts.scatter_(1, topk_ids, 1)
677
+ tokens_per_expert = cnts.sum(dim=0)
678
+ idxs = topk_ids.view(-1).argsort()
679
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
680
+
681
+ tokens_per_expert = tokens_per_expert.cpu().numpy()
682
+
683
+ outputs = []
684
+ start_idx = 0
685
+ for i, num_tokens in enumerate(tokens_per_expert):
686
+ end_idx = start_idx + num_tokens
687
+ if num_tokens == 0:
688
+ continue
689
+ expert = self.experts[i + self.ep_rank * self.experts_per_rank]
690
+ tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
691
+ expert_out = expert(tokens_for_this_expert)
692
+ outputs.append(expert_out)
693
+ start_idx = end_idx
694
+
695
+ outs = torch.cat(outputs, dim=0) if len(
696
+ outputs) else sorted_tokens.new_empty(0)
697
+
698
+ new_x = torch.empty_like(outs)
699
+ new_x[idxs] = outs
700
+ final_out = (
701
+ new_x.view(*topk_ids.shape, -1)
702
+ .type(topk_weight.dtype)
703
+ .mul_(topk_weight.unsqueeze(dim=-1))
704
+ .sum(dim=1)
705
+ .type(new_x.dtype)
706
+ )
707
+ return final_out
708
+
709
+
710
+ class KimiDecoderLayer(nn.Module):
711
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
712
+ super().__init__()
713
+ self.hidden_size = config.hidden_size
714
+ self.config = config
715
+ if config.is_kda_layer(layer_idx):
716
+ self.is_linear_attn = True
717
+ self.self_attn = KimiDeltaAttention(
718
+ config=config, layer_idx=layer_idx)
719
+ elif config.is_mla:
720
+ self.is_linear_attn = False
721
+ self.self_attn = KimiMLAAttention(
722
+ config=config, layer_idx=layer_idx)
723
+ else:
724
+ raise NotImplementedError
725
+ if (
726
+ config.num_experts is not None
727
+ and layer_idx >= config.first_k_dense_replace
728
+ and layer_idx % getattr(config, "moe_layer_freq", 1) == 0
729
+ ):
730
+ self.block_sparse_moe = KimiSparseMoeBlock(config)
731
+ else:
732
+ self.mlp = KimiMLP(config)
733
+ self.input_layernorm = KimiRMSNorm(
734
+ config.hidden_size, eps=config.rms_norm_eps)
735
+ self.post_attention_layernorm = KimiRMSNorm(
736
+ config.hidden_size, eps=config.rms_norm_eps)
737
+
738
+ def forward(
739
+ self,
740
+ hidden_states: torch.Tensor,
741
+ attention_mask: Optional[torch.Tensor] = None,
742
+ position_ids: Optional[torch.LongTensor] = None,
743
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
744
+ output_attentions: Optional[bool] = False,
745
+ use_cache: Optional[bool] = False,
746
+ **kwargs: Unpack[FlashAttentionKwargs],
747
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
748
+ """
749
+ Args:
750
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
751
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
752
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
753
+ output_attentions (`bool`, *optional*):
754
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
755
+ returned tensors for more detail.
756
+ use_cache (`bool`, *optional*):
757
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
758
+ (see `past_key_values`).
759
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
760
+ """
761
+
762
+ residual = hidden_states
763
+
764
+ hidden_states = self.input_layernorm(hidden_states)
765
+
766
+ # Self Attention
767
+ if self.is_linear_attn is False:
768
+ hidden_states = self.self_attn(
769
+ hidden_states=hidden_states,
770
+ attention_mask=attention_mask,
771
+ position_ids=position_ids,
772
+ past_key_values=past_key_values,
773
+ output_attentions=output_attentions,
774
+ use_cache=use_cache,
775
+ **kwargs,
776
+ )
777
+ else:
778
+ hidden_states = self.self_attn(
779
+ hidden_states=hidden_states,
780
+ attention_mask=attention_mask,
781
+ cache_params=past_key_values,
782
+ output_attentions=output_attentions,
783
+ use_cache=use_cache,
784
+ **kwargs,
785
+ )
786
+ hidden_states = residual + hidden_states
787
+
788
+ # Fully Connected
789
+ residual = hidden_states
790
+ hidden_states = self.post_attention_layernorm(hidden_states)
791
+ if hasattr(self, "block_sparse_moe"):
792
+ hidden_states = self.block_sparse_moe(hidden_states)
793
+ else:
794
+ hidden_states = self.mlp(hidden_states)
795
+ hidden_states = residual + hidden_states
796
+
797
+ return hidden_states
798
+
799
+
800
+ class KimiPreTrainedModel(PreTrainedModel):
801
+ config_class = KimiLinearConfig
802
+ base_model_prefix = "model"
803
+ supports_gradient_checkpointing = True
804
+ _no_split_modules = ["KimiDecoderLayer"]
805
+ _skip_keys_device_placement = "past_key_values"
806
+ _supports_flash_attn_2 = True
807
+ _can_record_outputs = {
808
+ "router_logits": OutputRecorder(KimiBlockSparseMLP, index=1),
809
+ "hidden_states": KimiDecoderLayer,
810
+ "attentions": KimiMLAAttention,
811
+ }
812
+ _is_stateful = True
813
+
814
+ def _init_weights(self, module):
815
+ std = self.config.initializer_range
816
+ if isinstance(module, nn.Linear):
817
+ module.weight.data.normal_(mean=0.0, std=std)
818
+ if module.bias is not None:
819
+ module.bias.data.zero_()
820
+ elif isinstance(module, nn.Embedding):
821
+ module.weight.data.normal_(mean=0.0, std=std)
822
+ if module.padding_idx is not None:
823
+ module.weight.data[module.padding_idx].zero_()
824
+
825
+
826
+ class KimiLinearModel(KimiPreTrainedModel):
827
+ def __init__(self, config: KimiLinearConfig):
828
+ super().__init__(config)
829
+ self.padding_idx = config.pad_token_id
830
+ self.vocab_size = config.vocab_size
831
+
832
+ self.embed_tokens = nn.Embedding(
833
+ config.vocab_size, config.hidden_size, self.padding_idx)
834
+ self.layers = nn.ModuleList([KimiDecoderLayer(
835
+ config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
836
+ self.norm = KimiRMSNorm(
837
+ config.hidden_size, eps=config.rms_norm_eps)
838
+
839
+ if getattr(config, "_attn_implementation", None) is not None:
840
+ if config._attn_implementation != "flash_attention_2":
841
+ logger.warning_once(
842
+ f"Ignoring the provided attention implementation {config._attn_implementation}")
843
+ logger.warning_once("Using flash_attention_2 backend instead.")
844
+ config._attn_implementation = "flash_attention_2"
845
+ else:
846
+ config._attn_implementation = "flash_attention_2"
847
+
848
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
849
+ self.gradient_checkpointing = False
850
+ # Initialize weights and apply final processing
851
+ self.post_init()
852
+
853
+ def _update_linear_attn_mask(self, attention_mask, cache_position):
854
+ """
855
+ NOTE: Left-padding is assumed when building the linear-attention mask.
856
+ There is no need to zero out states when
857
+ 1. running a cached (decoding) forward pass, or
858
+ 2. attending to all inputs (no padding present).
859
+ """
860
+ linear_attn_mask = attention_mask
861
+ if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
862
+ linear_attn_mask = None
863
+ return linear_attn_mask
864
+
865
+ @check_model_inputs
866
+ @auto_docstring
867
+ def forward(
868
+ self,
869
+ input_ids: torch.LongTensor = None,
870
+ attention_mask: Optional[torch.Tensor] = None,
871
+ position_ids: Optional[torch.LongTensor] = None,
872
+ past_key_values: Optional[Cache] = None,
873
+ inputs_embeds: Optional[torch.FloatTensor] = None,
874
+ cache_position: Optional[torch.LongTensor] = None,
875
+ use_cache: Optional[bool] = None,
876
+ **kwargs: Unpack[TransformersKwargs],
877
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
878
+
879
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
880
+
881
+ if (input_ids is None) ^ (inputs_embeds is not None):
882
+ raise ValueError(
883
+ "You must specify exactly one of input_ids or inputs_embeds")
884
+
885
+ # Get inputs_embeds
886
+ if inputs_embeds is None:
887
+ inputs_embeds = self.embed_tokens(input_ids)
888
+
889
+ if use_cache and past_key_values is None:
890
+ past_key_values = KimiDynamicCache(config=self.config)
891
+
892
+ if cache_position is None:
893
+ past_seen_tokens = past_key_values.get_seq_length(
894
+ ) if past_key_values is not None else 0
895
+ cache_position: torch.Tensor = torch.arange(
896
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
897
+ )
898
+
899
+ if position_ids is None:
900
+ position_ids = cache_position.unsqueeze(0)
901
+
902
+ causal_mask = create_causal_mask(
903
+ config=self.config,
904
+ input_embeds=inputs_embeds,
905
+ attention_mask=attention_mask,
906
+ cache_position=cache_position,
907
+ past_key_values=past_key_values,
908
+ position_ids=position_ids,
909
+ )
910
+ linear_attn_mask = self._update_linear_attn_mask(
911
+ attention_mask, cache_position)
912
+
913
+ hidden_states = inputs_embeds
914
+ if past_key_values is not None:
915
+ assert isinstance(past_key_values, KimiDynamicCache)
916
+
917
+ for decoder_layer in self.layers:
918
+ layer_mask = linear_attn_mask if decoder_layer.is_linear_attn else causal_mask
919
+
920
+ hidden_states = decoder_layer(
921
+ hidden_states,
922
+ attention_mask=layer_mask,
923
+ past_key_values=past_key_values,
924
+ cache_position=cache_position,
925
+ **kwargs,
926
+ )
927
+
928
+ hidden_states = self.norm(hidden_states)
929
+
930
+ return BaseModelOutputWithPast(
931
+ last_hidden_state=hidden_states,
932
+ past_key_values=past_key_values,
933
+ )
934
+
935
+
936
+ class KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin):
937
+ _tied_weights_keys = ["lm_head.weight"]
938
+
939
+ def __init__(self, config):
940
+ super().__init__(config)
941
+ self.model = KimiLinearModel(config)
942
+ self.vocab_size = config.vocab_size
943
+ self.lm_head = nn.Linear(
944
+ config.hidden_size, config.vocab_size, bias=False)
945
+
946
+ # Initialize weights and apply final processing
947
+ self.post_init()
948
+
949
+ @can_return_tuple
950
+ @auto_docstring
951
+ def forward(
952
+ self,
953
+ input_ids: torch.LongTensor = None,
954
+ attention_mask: Optional[torch.Tensor] = None,
955
+ position_ids: Optional[torch.LongTensor] = None,
956
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
957
+ inputs_embeds: Optional[torch.FloatTensor] = None,
958
+ labels: Optional[torch.LongTensor] = None,
959
+ use_cache: Optional[bool] = None,
960
+ output_attentions: Optional[bool] = None,
961
+ output_hidden_states: Optional[bool] = None,
962
+ generation_mode: Optional[bool] = None,
963
+ return_dict: Optional[bool] = None,
964
+ cache_position: Optional[torch.LongTensor] = None,
965
+ **kwargs: Unpack[TransformersKwargs],
966
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
967
+ r"""
968
+ Args:
969
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
970
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
971
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
972
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
973
+
974
+ Returns:
975
+
976
+ Example:
977
+
978
+ ```python
979
+ >>> from transformers import AutoTokenizer, KimiLinearForCausalLM
980
+
981
+ >>> model = KimiLinearForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
982
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
983
+
984
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
985
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
986
+
987
+ >>> # Generate
988
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
989
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
990
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
991
+ ```"""
992
+
993
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
994
+ output_hidden_states = (
995
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
996
+ )
997
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
998
+
999
+ outputs = self.model(
1000
+ input_ids=input_ids,
1001
+ attention_mask=attention_mask,
1002
+ position_ids=position_ids,
1003
+ past_key_values=past_key_values,
1004
+ inputs_embeds=inputs_embeds,
1005
+ use_cache=use_cache,
1006
+ output_attentions=output_attentions,
1007
+ output_hidden_states=output_hidden_states,
1008
+ return_dict=return_dict,
1009
+ cache_position=cache_position,
1010
+ )
1011
+
1012
+ hidden_states = outputs.last_hidden_state
1013
+ if generation_mode:
1014
+ hidden_states = hidden_states[:, -1:]
1015
+ logits = self.lm_head(hidden_states)
1016
+
1017
+ loss = None
1018
+ if labels is not None:
1019
+ loss = self.loss_function(
1020
+ logits, labels, self.vocab_size, **kwargs)
1021
+
1022
+ return CausalLMOutputWithPast(
1023
+ loss=loss,
1024
+ logits=logits,
1025
+ past_key_values=outputs.past_key_values,
1026
+ hidden_states=outputs.hidden_states,
1027
+ attentions=outputs.attentions,
1028
+ )
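For quick orientation, here is a minimal usage sketch for the classes added above. It is not part of the commit; the repository id, dtype, and generation settings are illustrative assumptions, and loading the custom `KimiLinearForCausalLM` code requires `trust_remote_code=True`.

```python
# Minimal sketch (assumption: the repo's auto_map routes AutoModel/AutoTokenizer
# to the custom classes added in this commit).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",       # keep the checkpoint's dtype
    device_map="auto",
    trust_remote_code=True,   # required: KimiLinearForCausalLM is custom code
)

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
generated = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```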
special_tokens_map.json ADDED
@@ -0,0 +1,260 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "[extra_id_0]",
4
+ "[extra_id_1]",
5
+ "[extra_id_2]",
6
+ "[extra_id_3]",
7
+ "[start_header_id]",
8
+ "[end_header_id]",
9
+ "[extra_id_4]",
10
+ "[EOT]",
11
+ "[extra_id_5]",
12
+ "[extra_id_6]",
13
+ "[extra_id_7]",
14
+ "[extra_id_8]",
15
+ "[extra_id_9]",
16
+ "[extra_id_10]",
17
+ "[extra_id_11]",
18
+ "[extra_id_12]",
19
+ "[extra_id_13]",
20
+ "[extra_id_14]",
21
+ "[extra_id_15]",
22
+ "[extra_id_16]",
23
+ "[extra_id_17]",
24
+ "[extra_id_18]",
25
+ "[extra_id_19]",
26
+ "[extra_id_20]",
27
+ "[extra_id_21]",
28
+ "[extra_id_22]",
29
+ "[extra_id_23]",
30
+ "[extra_id_24]",
31
+ "[extra_id_25]",
32
+ "[extra_id_26]",
33
+ "[extra_id_27]",
34
+ "[extra_id_28]",
35
+ "[extra_id_29]",
36
+ "[extra_id_30]",
37
+ "[extra_id_31]",
38
+ "[extra_id_32]",
39
+ "[extra_id_33]",
40
+ "[extra_id_34]",
41
+ "[extra_id_35]",
42
+ "[extra_id_36]",
43
+ "[extra_id_37]",
44
+ "[extra_id_38]",
45
+ "[extra_id_39]",
46
+ "[extra_id_40]",
47
+ "[extra_id_41]",
48
+ "[extra_id_42]",
49
+ "[extra_id_43]",
50
+ "[extra_id_44]",
51
+ "[extra_id_45]",
52
+ "[extra_id_46]",
53
+ "[extra_id_47]",
54
+ "[extra_id_48]",
55
+ "[extra_id_49]",
56
+ "[extra_id_50]",
57
+ "[extra_id_51]",
58
+ "[extra_id_52]",
59
+ "[extra_id_53]",
60
+ "[extra_id_54]",
61
+ "[extra_id_55]",
62
+ "[extra_id_56]",
63
+ "[extra_id_57]",
64
+ "[extra_id_58]",
65
+ "[extra_id_59]",
66
+ "[extra_id_60]",
67
+ "[extra_id_61]",
68
+ "[extra_id_62]",
69
+ "[extra_id_63]",
70
+ "[extra_id_64]",
71
+ "[extra_id_65]",
72
+ "[extra_id_66]",
73
+ "[extra_id_67]",
74
+ "[extra_id_68]",
75
+ "[extra_id_69]",
76
+ "[extra_id_70]",
77
+ "[extra_id_71]",
78
+ "[extra_id_72]",
79
+ "[extra_id_73]",
80
+ "[extra_id_74]",
81
+ "[extra_id_75]",
82
+ "[extra_id_76]",
83
+ "[extra_id_77]",
84
+ "[extra_id_78]",
85
+ "[extra_id_79]",
86
+ "[extra_id_80]",
87
+ "[extra_id_81]",
88
+ "[extra_id_82]",
89
+ "[extra_id_83]",
90
+ "[extra_id_84]",
91
+ "[extra_id_85]",
92
+ "[extra_id_86]",
93
+ "[extra_id_87]",
94
+ "[extra_id_88]",
95
+ "[extra_id_89]",
96
+ "[extra_id_90]",
97
+ "[extra_id_91]",
98
+ "[extra_id_92]",
99
+ "[extra_id_93]",
100
+ "[extra_id_94]",
101
+ "[extra_id_95]",
102
+ "[extra_id_96]",
103
+ "[extra_id_97]",
104
+ "[extra_id_98]",
105
+ "[extra_id_99]",
106
+ "[extra_id_100]",
107
+ "[extra_id_101]",
108
+ "[extra_id_102]",
109
+ "[extra_id_103]",
110
+ "[extra_id_104]",
111
+ "[extra_id_105]",
112
+ "[extra_id_106]",
113
+ "[extra_id_107]",
114
+ "[extra_id_108]",
115
+ "[extra_id_109]",
116
+ "[extra_id_110]",
117
+ "[extra_id_111]",
118
+ "[extra_id_112]",
119
+ "[extra_id_113]",
120
+ "[extra_id_114]",
121
+ "[extra_id_115]",
122
+ "[extra_id_116]",
123
+ "[extra_id_117]",
124
+ "[extra_id_118]",
125
+ "[extra_id_119]",
126
+ "[extra_id_120]",
127
+ "[extra_id_121]",
128
+ "[extra_id_122]",
129
+ "[extra_id_123]",
130
+ "[extra_id_124]",
131
+ "[extra_id_125]",
132
+ "[extra_id_126]",
133
+ "[extra_id_127]",
134
+ "[extra_id_128]",
135
+ "[extra_id_129]",
136
+ "[extra_id_130]",
137
+ "[extra_id_131]",
138
+ "[extra_id_132]",
139
+ "[extra_id_133]",
140
+ "[extra_id_134]",
141
+ "[extra_id_135]",
142
+ "[extra_id_136]",
143
+ "[extra_id_137]",
144
+ "[extra_id_138]",
145
+ "[extra_id_139]",
146
+ "[extra_id_140]",
147
+ "[extra_id_141]",
148
+ "[extra_id_142]",
149
+ "[extra_id_143]",
150
+ "[extra_id_144]",
151
+ "[extra_id_145]",
152
+ "[extra_id_146]",
153
+ "[extra_id_147]",
154
+ "[extra_id_148]",
155
+ "[extra_id_149]",
156
+ "[extra_id_150]",
157
+ "[extra_id_151]",
158
+ "[extra_id_152]",
159
+ "[extra_id_153]",
160
+ "[extra_id_154]",
161
+ "[extra_id_155]",
162
+ "[extra_id_156]",
163
+ "[extra_id_157]",
164
+ "[extra_id_158]",
165
+ "[extra_id_159]",
166
+ "[extra_id_160]",
167
+ "[extra_id_161]",
168
+ "[extra_id_162]",
169
+ "[extra_id_163]",
170
+ "[extra_id_164]",
171
+ "[extra_id_165]",
172
+ "[extra_id_166]",
173
+ "[extra_id_167]",
174
+ "[extra_id_168]",
175
+ "[extra_id_169]",
176
+ "[extra_id_170]",
177
+ "[extra_id_171]",
178
+ "[extra_id_172]",
179
+ "[extra_id_173]",
180
+ "[extra_id_174]",
181
+ "[extra_id_175]",
182
+ "[extra_id_176]",
183
+ "[extra_id_177]",
184
+ "[extra_id_178]",
185
+ "[extra_id_179]",
186
+ "[extra_id_180]",
187
+ "[extra_id_181]",
188
+ "[extra_id_182]",
189
+ "[extra_id_183]",
190
+ "[extra_id_184]",
191
+ "[extra_id_185]",
192
+ "[extra_id_186]",
193
+ "[extra_id_187]",
194
+ "[extra_id_188]",
195
+ "[extra_id_189]",
196
+ "[extra_id_190]",
197
+ "[extra_id_191]",
198
+ "[extra_id_192]",
199
+ "[extra_id_193]",
200
+ "[extra_id_194]",
201
+ "[extra_id_195]",
202
+ "[extra_id_196]",
203
+ "[extra_id_197]",
204
+ "[extra_id_198]",
205
+ "[extra_id_199]",
206
+ "[extra_id_200]",
207
+ "[extra_id_201]",
208
+ "[extra_id_202]",
209
+ "[extra_id_203]",
210
+ "[extra_id_204]",
211
+ "[extra_id_205]",
212
+ "[extra_id_206]",
213
+ "[extra_id_207]",
214
+ "[extra_id_208]",
215
+ "[extra_id_209]",
216
+ "[extra_id_210]",
217
+ "[extra_id_211]",
218
+ "[extra_id_212]",
219
+ "[extra_id_213]",
220
+ "[extra_id_214]",
221
+ "[extra_id_215]",
222
+ "[extra_id_216]",
223
+ "[extra_id_217]",
224
+ "[extra_id_218]",
225
+ "[extra_id_219]",
226
+ "[extra_id_220]",
227
+ "[extra_id_221]",
228
+ "[extra_id_222]",
229
+ "[extra_id_223]",
230
+ "[extra_id_224]",
231
+ "[extra_id_225]",
232
+ "[extra_id_226]",
233
+ "[extra_id_227]",
234
+ "[extra_id_228]",
235
+ "[extra_id_229]",
236
+ "[extra_id_230]",
237
+ "[extra_id_231]",
238
+ "[extra_id_232]",
239
+ "[extra_id_233]",
240
+ "[extra_id_234]",
241
+ "[extra_id_235]",
242
+ "[extra_id_236]",
243
+ "[extra_id_237]",
244
+ "[extra_id_238]",
245
+ "[extra_id_239]",
246
+ "[extra_id_240]",
247
+ "[extra_id_241]",
248
+ "[extra_id_242]",
249
+ "[extra_id_243]",
250
+ "[extra_id_244]",
251
+ "[extra_id_245]",
252
+ "[extra_id_246]",
253
+ "[extra_id_247]",
254
+ "[extra_id_248]"
255
+ ],
256
+ "bos_token": "[BOS]",
257
+ "eos_token": "[EOS]",
258
+ "pad_token": "[extra_id_250]",
259
+ "unk_token": "[extra_id_249]"
260
+ }
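As a sanity check, the role assignments in `special_tokens_map.json` can be inspected on a loaded tokenizer; this illustration reuses the `tokenizer` object from the sketch above and makes no claims beyond what the file states.

```python
# Illustration only: the file above reserves the [extra_id_*] placeholders (plus header
# and [EOT] markers) as additional special tokens and names the bos/eos/pad/unk roles.
print(tokenizer.special_tokens_map["bos_token"])   # "[BOS]"
print(tokenizer.special_tokens_map["eos_token"])   # "[EOS]"
print(len(tokenizer.additional_special_tokens))    # number of extra special tokens registered
```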
tiktoken.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6c497a7469b33ced9c38afb1ad6e47f03f5e5dc05f15930799210ec050c5103
3
+ size 2795286
tokenization_kimi.py ADDED
@@ -0,0 +1,347 @@
1
+ import os
2
+ import tiktoken
3
+
4
+ from logging import getLogger
5
+ from pathlib import Path
6
+ from typing import (
7
+ cast,
8
+ Tuple,
9
+ Dict,
10
+ Iterator,
11
+ List,
12
+ Union,
13
+ Optional,
14
+ )
15
+ from shutil import copyfile
16
+ from tiktoken.load import load_tiktoken_bpe
17
+ from tokenizers import AddedToken, pre_tokenizers, Regex
18
+ from transformers.tokenization_utils import PreTrainedTokenizer
19
+ from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
20
+ from typing import Any
21
+
22
+
23
+ logger = getLogger(__name__)
24
+ VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
25
+
26
+
27
+ class TikTokenTokenizer(PreTrainedTokenizer):
28
+ """
29
+ Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.
30
+
31
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
32
+ this superclass for more information regarding those methods.
33
+
34
+ Args:
35
+ vocab_file (`str`):
36
+ The path to the Tiktoken model file.
37
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[BOS]"`):
38
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
39
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[EOS]"`):
40
+ The end of sequence token.
41
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*):
42
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
43
+ token instead. The second to last item in special_tokens.
44
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*):
45
+ The token used for padding, for example when batching sequences of different lengths.
46
+ additional_special_tokens (list of `str`, *optional*):
47
+ A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
48
+ skipped when decoding if `skip_special_tokens` is set to `True`.
49
+ """
50
+
51
+ vocab_files_names = VOCAB_FILES_NAMES
52
+
53
+ model_input_names = ["input_ids", "attention_mask"]
54
+
55
+ special_tokens: Dict[str, int]
56
+
57
+ num_reserved_special_tokens = 256
58
+
59
+ pat_str = "|".join(
60
+ [
61
+ r"""[\p{Han}]+""",
62
+ r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
63
+ r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
64
+ r"""\p{N}{1,3}""",
65
+ r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
66
+ r"""\s*[\r\n]+""",
67
+ r"""\s+(?!\S)""",
68
+ r"""\s+""",
69
+ ]
70
+ )
71
+
72
+ def __init__(
73
+ self,
74
+ vocab_file,
75
+ bos_token: Union[str, AddedToken]="[BOS]",
76
+ eos_token: Union[str, AddedToken]="[EOS]",
77
+ unk_token: Union[str, AddedToken, None]=None,
78
+ pad_token: Union[str, AddedToken, None]=None,
79
+ additional_special_tokens: List[str]=None,
80
+ added_tokens_decoder: Optional[dict] = None,
81
+ **kwargs,
82
+ ):
83
+ assert os.path.isfile(vocab_file), vocab_file
84
+
85
+ if additional_special_tokens is None:
86
+ additional_special_tokens = [
87
+ "<|im_end|>",
88
+ "<|im_user|>",
89
+ "<|im_assistant|>",
90
+ "<|start_header_id|>",
91
+ "<|end_header_id|>",
92
+ "[EOT]",
93
+ "<|im_system|>",
94
+ "<|im_middle|>",
95
+ ]
96
+
97
+ special_tokens_mapping = {
98
+ i: added_tokens_decoder[i].content for i in (added_tokens_decoder or {})
99
+ }
100
+
101
+ self.vocab_file = vocab_file
102
+ mergeable_ranks = load_tiktoken_bpe(vocab_file)
103
+ num_base_tokens = len(mergeable_ranks)
104
+ self.special_tokens = {
105
+ special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
106
+ for i in range(
107
+ num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2
108
+ )
109
+ }
110
+
111
+
112
+
113
+ self.model = tiktoken.Encoding(
114
+ name=Path(vocab_file).name,
115
+ pat_str=self.pat_str,
116
+ mergeable_ranks=mergeable_ranks,
117
+ special_tokens=self.special_tokens,
118
+ )
119
+ logger.info(f"Reloaded tiktoken model from {vocab_file}")
120
+
121
+ self.n_words: int = self.model.n_vocab
122
+ # BOS / EOS token IDs
123
+ self.bos_id: int = self.special_tokens[str(bos_token)]
124
+ self.eos_id: int = self.special_tokens[str(eos_token)]
125
+ logger.info(
126
+ f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
127
+ )
128
+
129
+ self.pad_id: int = self.special_tokens[str(pad_token)]
130
+ self.unk_id: int = self.special_tokens[str(unk_token)]
131
+
132
+ self.byte_encoder = bytes_to_unicode()
133
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
134
+
135
+ self.decoder = {}
136
+ for i in range(self.n_words):
137
+ # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
138
+ decoding = ''.join([
139
+ self.byte_encoder[ord(char)] for char in
140
+ self.model.decode_single_token_bytes(i).decode('latin-1')
141
+ ])
142
+ self.decoder[i] = decoding
143
+
144
+ self.encoder = {}
145
+ for i in range(self.n_words):
146
+ if i in self.decoder:
147
+ self.encoder[self.decoder[i]] = i
148
+
149
+ super().__init__(
150
+ bos_token=bos_token,
151
+ eos_token=eos_token,
152
+ unk_token=unk_token,
153
+ pad_token=pad_token,
154
+ additional_special_tokens=additional_special_tokens,
155
+ **kwargs,
156
+ )
157
+ self.all_special_ids_set = set(self.all_special_ids)
158
+
159
+ def encode(
160
+ self,
161
+ text: str,
162
+ allow_special_tokens: bool = True,
163
+ **kwargs
164
+ ) -> List[int]:
165
+ """
166
+ Encodes a string into a list of token IDs.
167
+
168
+ Args:
169
+ text (str): The input string to be encoded.
170
+
171
+ Returns:
172
+ list[int]: A list of token IDs.
173
+ """
174
+ # If there are other args, we should call super().encode because there is a lot of code
175
+ # to handle those args. super().encode will eventually call _tokenize and _convert_token_to_id.
176
+ # NOTE: our encode method is not compatible with the super().encode method,
177
+ # e.g. split_special_tokens' default is True in our encode method.
178
+ if len(kwargs) > 0:
179
+ logger.warning( f"Calling super().encode with {kwargs}" )
180
+ return super().encode(text, **kwargs)
181
+
182
+ assert type(text) is str
183
+
184
+ # The tiktoken tokenizer can handle <=400k chars without
185
+ # pyo3_runtime.PanicException.
186
+ TIKTOKEN_MAX_ENCODE_CHARS = 400_000
187
+
188
+ # https://github.com/openai/tiktoken/issues/195
189
+ # Here we iterate over subsequences and split if we exceed the limit
190
+ # of max consecutive non-whitespace or whitespace characters.
191
+ MAX_NO_WHITESPACES_CHARS = 25_000
192
+
193
+ texts = self.pre_tokenizer_process(text)
194
+
195
+ all_substrs = []
196
+ for text in texts:
197
+ substrs = (
198
+ substr
199
+ for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
200
+ for substr in self._split_whitespaces_or_nonwhitespaces(
201
+ text[i: i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
202
+ )
203
+ )
204
+ all_substrs.extend(substrs)
205
+
206
+ t: List[int] = []
207
+ for substr in all_substrs:
208
+ if allow_special_tokens:
209
+ t.extend(
210
+ # we should consider special token as a common token
211
+ self.model.encode(
212
+ substr,
213
+ allowed_special="all",
214
+ )
215
+ )
216
+ else:
217
+ t.extend(
218
+ # we should consider special token as a common token
219
+ self.model.encode(
220
+ substr,
221
+ disallowed_special=(),
222
+ )
223
+ )
224
+
225
+ return t
226
+
227
+ def decode(
228
+ self,
229
+ token_ids: Union[int, List[int]],
230
+ **kwargs
231
+ ) -> str:
232
+ """
233
+ Decodes a list of token IDs into a string.
234
+
235
+ Args:
236
+ token_ids (List[int]): The list of token IDs to be decoded.
237
+
238
+ Returns:
239
+ str: The decoded string.
240
+ """
241
+ # If there are other args, we should call super().decode because there is a lot of code
242
+ # to handle those args. super().decode will eventually call convert_tokens_to_string and _convert_id_to_token.
243
+ if len(kwargs) > 0:
244
+ return super().decode(token_ids, **kwargs)
245
+
246
+ if type(token_ids) is int:
247
+ token_ids = [token_ids]
248
+
249
+ return self.model.decode(cast(List[int], token_ids))
250
+
251
+ @staticmethod
252
+ def _split_whitespaces_or_nonwhitespaces(
253
+ s: str, max_consecutive_slice_len: int
254
+ ) -> Iterator[str]:
255
+ """
256
+ Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
257
+ consecutive whitespaces or consecutive non-whitespaces.
258
+ """
259
+ current_slice_len = 0
260
+ current_slice_is_space = s[0].isspace() if len(s) > 0 else False
261
+ slice_start = 0
262
+
263
+ for i in range(len(s)):
264
+ is_now_space = s[i].isspace()
265
+
266
+ if current_slice_is_space ^ is_now_space:
267
+ current_slice_len = 1
268
+ current_slice_is_space = is_now_space
269
+ else:
270
+ current_slice_len += 1
271
+ if current_slice_len > max_consecutive_slice_len:
272
+ yield s[slice_start:i]
273
+ slice_start = i
274
+ current_slice_len = 1
275
+ yield s[slice_start:]
276
+
277
+ def pre_tokenizer_process(self, text: str) -> List[str]:
278
+ """
279
+ Pre-tokenizes the input text into a list of text chunks.
280
+ This method is used to split the input text into smaller chunks for internal processing.
281
+ """
282
+ return [text]
283
+
284
+
285
+ """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
286
+ @property
287
+ def vocab_size(self) -> int:
288
+ return self.n_words
289
+
290
+ def get_vocab(self) -> Dict[str, int]:
291
+ return self.encoder
292
+
293
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
294
+ return [
295
+ self.decoder[t]
296
+ for t in self.encode(text)
297
+ ]
298
+
299
+ def _convert_token_to_id(self, token: str) -> int:
300
+ return self.encoder.get(token, self.unk_id)
301
+
302
+ def _convert_id_to_token(self, index: int) -> str:
303
+ return self.decoder.get(index)
304
+
305
+ @staticmethod
306
+ def clean_up_tokenization(out_string: str) -> str:
307
+ return out_string
308
+
309
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
310
+ text = ''.join(tokens)
311
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', 'replace')
312
+ return text
313
+
314
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
315
+ if not os.path.isdir(save_directory):
316
+ raise ValueError(f"vocabulary path ({save_directory}) should be a directory")
317
+ out_vocab_file = os.path.join(
318
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
319
+ )
320
+
321
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
322
+ copyfile(self.vocab_file, out_vocab_file)
323
+
324
+ return (out_vocab_file,)
325
+
326
+
327
+
328
+ def apply_chat_template(
329
+ self, conversation, tools: Optional[list[dict]] = None,
330
+ tokenize: bool = False,
331
+ add_generation_prompt: bool = True,
332
+ **kwargs
333
+ ):
334
+ tools = deep_sort_dict(tools)
335
+ return super().apply_chat_template(conversation,
336
+ tools=tools,
337
+ tokenize=tokenize,
338
+ add_generation_prompt=add_generation_prompt,
339
+ **kwargs)
340
+
341
+
342
+ def deep_sort_dict(obj: Any) -> Any:
343
+ if isinstance(obj, dict):
344
+ return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
345
+ if isinstance(obj, list):
346
+ return [deep_sort_dict(item) for item in obj]
347
+ return obj
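Below is a minimal, hedged sketch of exercising `TikTokenTokenizer` directly. The local path is an assumption; loading goes through `AutoTokenizer` because `tokenizer_config.json` (next file) maps `AutoTokenizer` to `tokenization_kimi.TikTokenTokenizer`.

```python
# Sketch only (assumption: the repo files, including tiktoken.model, are in ./kimi-linear).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./kimi-linear", trust_remote_code=True)

ids = tok.encode("Hello, Kimi Linear!")   # tiktoken BPE over the pat_str defined above
print(ids)
print(tok.decode(ids))                    # decodes straight through tiktoken

# By default encode() maps special-token text to single special IDs (allow_special_tokens=True);
# with allow_special_tokens=False the same text is tokenized as ordinary byte-level pieces.
print(tok.encode("<|im_end|>"))
print(tok.encode("<|im_end|>", allow_special_tokens=False))
```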
tokenizer_config.json ADDED
@@ -0,0 +1,164 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "163584": {
4
+ "content": "[BOS]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "163585": {
12
+ "content": "[EOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "163586": {
20
+ "content": "<|im_end|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "163587": {
28
+ "content": "<|im_user|>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "163588": {
36
+ "content": "<|im_assistant|>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "163590": {
44
+ "content": "<|start_header_id|>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "163591": {
52
+ "content": "<|end_header_id|>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "163593": {
60
+ "content": "[EOT]",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "163594": {
68
+ "content": "<|im_system|>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "163595": {
76
+ "content": "<|tool_calls_section_begin|>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "163596": {
84
+ "content": "<|tool_calls_section_end|>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "163597": {
92
+ "content": "<|tool_call_begin|>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "163598": {
100
+ "content": "<|tool_call_argument_begin|>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "163599": {
108
+ "content": "<|tool_call_end|>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "163601": {
116
+ "content": "<|im_middle|>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "163838": {
124
+ "content": "[UNK]",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "163839": {
132
+ "content": "[PAD]",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ }
139
+ },
140
+ "additional_special_tokens": [
141
+ "<|im_end|>",
142
+ "<|im_user|>",
143
+ "<|im_assistant|>",
144
+ "<|start_header_id|>",
145
+ "<|end_header_id|>",
146
+ "[EOT]",
147
+ "<|im_system|>",
148
+ "<|im_middle|>"
149
+ ],
150
+ "bos_token": "[BOS]",
151
+ "clean_up_tokenization_spaces": false,
152
+ "eos_token": "[EOS]",
153
+ "extra_special_tokens": {},
154
+ "model_max_length": 1000000000000000019884624838656,
155
+ "pad_token": "[PAD]",
156
+ "tokenizer_class": "TikTokenTokenizer",
157
+ "unk_token": "[UNK]",
158
+ "auto_map": {
159
+ "AutoTokenizer": [
160
+ "tokenization_kimi.TikTokenTokenizer",
161
+ null
162
+ ]
163
+ }
164
+ }
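Finally, a hedged sketch of chat formatting, continuing from `tok` above: the special tokens registered in this config (e.g. `<|im_user|>`, `<|im_middle|>`, `<|im_end|>`) are consumed by `apply_chat_template`, assuming the repository also ships the chat template that method renders.

```python
# Assumption: a chat template is available to apply_chat_template (shipped elsewhere in the repo).
messages = [
    {"role": "system", "content": "You are Kimi, a helpful assistant."},
    {"role": "user", "content": "Hello! What can you do?"},
]
prompt_ids = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,  # default in the TikTokenTokenizer override
    tokenize=True,
    return_tensors="pt",
)
print(prompt_ids.shape)          # (1, prompt_length)
```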