Duplicate from moonshotai/Kimi-Linear-48B-A3B-Instruct
Browse filesCo-authored-by: LiZhiyuan <zhiyuan8@users.noreply.huggingface.co>
- .gitattributes +37 -0
- README.md +113 -0
- chat_template.jinja +48 -0
- config.json +86 -0
- configuration_kimi.py +140 -0
- figures/arch.png +3 -0
- figures/banner.png +0 -0
- figures/github.png +0 -0
- figures/logo.png +0 -0
- figures/perf_speed.png +3 -0
- generation_config.json +7 -0
- model-00001-of-00020.safetensors +3 -0
- model-00002-of-00020.safetensors +3 -0
- model-00003-of-00020.safetensors +3 -0
- model-00004-of-00020.safetensors +3 -0
- model-00005-of-00020.safetensors +3 -0
- model-00006-of-00020.safetensors +3 -0
- model-00007-of-00020.safetensors +3 -0
- model-00008-of-00020.safetensors +3 -0
- model-00009-of-00020.safetensors +3 -0
- model-00010-of-00020.safetensors +3 -0
- model-00011-of-00020.safetensors +3 -0
- model-00012-of-00020.safetensors +3 -0
- model-00013-of-00020.safetensors +3 -0
- model-00014-of-00020.safetensors +3 -0
- model-00015-of-00020.safetensors +3 -0
- model-00016-of-00020.safetensors +3 -0
- model-00017-of-00020.safetensors +3 -0
- model-00018-of-00020.safetensors +3 -0
- model-00019-of-00020.safetensors +3 -0
- model-00020-of-00020.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_kimi.py +1028 -0
- special_tokens_map.json +260 -0
- tiktoken.model +3 -0
- tokenization_kimi.py +347 -0
- tokenizer_config.json +164 -0
.gitattributes
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
figures/arch.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
figures/perf_speed.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
pipeline_tag: text-generation
|
| 4 |
+
library_name: transformers
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
<div align="center">
|
| 8 |
+
<a href="https://huggingface.co/papers/2510.26692"><img width="80%" src="figures/banner.png"></a>
|
| 9 |
+
</div>
|
| 10 |
+
|
| 11 |
+
<div align="center">
|
| 12 |
+
<a href="https://huggingface.co/papers/2510.26692" ><img src="figures/logo.png" height="16" width="16" style="display: inline-block; vertical-align: middle; margin: 2px;"><b style="display: inline-block;"> Tech Report</b></a> |
|
| 13 |
+
<a href="https://github.com/MoonshotAI/Kimi-Linear"><img src="figures/github.png" height="16" width="16" style="display: inline-block; vertical-align: middle; margin: 2px;"><b style="display: inline-block;"> Code</b></a> |
|
| 14 |
+
<a href="https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct"><img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" height="16" width="16" style="display: inline-block; vertical-align: middle; margin: 2px;"><b style="display: inline-block;"> HuggingFace</b></a>
|
| 15 |
+
</div>
|
| 16 |
+
|
| 17 |
+
<div align="center">
|
| 18 |
+
<img width="90%" src="figures/perf_speed.png">
|
| 19 |
+
<p><em><b>(a)</b> On MMLU-Pro (4k context length), Kimi Linear achieves 51.0 performance with similar speed as full attention. On RULER (128k context length), it shows Pareto-optimal performance (84.3) and 3.98x speedup. <b>(b)</b> Kimi Linear achieves 6.3x faster TPOT compared to MLA, offering significant speedups at long sequence lengths (1M tokens).</em></p>
|
| 20 |
+
</div>
|
| 21 |
+
|
| 22 |
+
## Overview
|
| 23 |
+
|
| 24 |
+
Kimi Linear is a hybrid linear attention architecture that outperforms traditional full attention methods across various contexts, including short, long, and reinforcement learning (RL) scaling regimes.
|
| 25 |
+
At its core is Kimi Delta Attention (KDA)—a refined version of [Gated DeltaNet](https://arxiv.org/abs/2412.06464) that introduces a more efficient gating mechanism to optimize the use of finite-state RNN memory.
|
| 26 |
+
|
| 27 |
+
Kimi Linear achieves superior performance and hardware efficiency, especially for long-context tasks. It reduces the need for large KV caches by up to 75% and boosts decoding throughput by up to $6\times$ for contexts as long as 1M tokens.
|
| 28 |
+
|
| 29 |
+
We open-source the KDA kernel in [FLA](https://github.com/fla-org/flash-linear-attention/tree/main/fla/ops/kda), and release two versions model checkpoints trained with 5.7T tokens.
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
| **Model** | **#Total Params** | **#Activated Params** | **Context Length** | **Download Link** |
|
| 33 |
+
| :------------------: | :---------------: | :-------------------: | :----------------: | :------------------------------------------------------------------------------: |
|
| 34 |
+
| Kimi-Linear-Base | 48B | 3B | 1M | [🤗 Hugging Face](https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Base) |
|
| 35 |
+
| Kimi-Linear-Instruct | 48B | 3B | 1M | [🤗 Hugging Face](https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct) |
|
| 36 |
+
|
| 37 |
+
## Key Features
|
| 38 |
+
|
| 39 |
+
- **Kimi Delta Attention (KDA):** A linear attention mechanism that refines the gated delta rule with finegrained gating.
|
| 40 |
+
- **Hybrid Architecture:** A 3:1 KDA-to-global MLA ratio reduces memory usage while maintaining or surpassing the quality of full attention.
|
| 41 |
+
- **Superior Performance:** Outperforms full attention in a variety of tasks, including long-context and RL-style benchmarks on 1.4T token training runs with fair comparisons.
|
| 42 |
+
- **High Throughput:** Achieves up to 6× faster decoding and significantly reduces time per output token (TPOT).
|
| 43 |
+
|
| 44 |
+
<div align="center">
|
| 45 |
+
<img width="60%" src="figures/arch.png">
|
| 46 |
+
</div>
|
| 47 |
+
|
| 48 |
+
## Usage
|
| 49 |
+
|
| 50 |
+
### Inference with Hugging Face Transformers
|
| 51 |
+
|
| 52 |
+
To use the Kimi Linear model, we recommend the following environment:
|
| 53 |
+
|
| 54 |
+
* `python` >= 3.10
|
| 55 |
+
* `torch` >= 2.6
|
| 56 |
+
* `fla-core` >= 0.4.0
|
| 57 |
+
|
| 58 |
+
```shell
|
| 59 |
+
pip install -U fla-core
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
Example Code:
|
| 63 |
+
```py
|
| 64 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 65 |
+
|
| 66 |
+
model_name = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
|
| 67 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 68 |
+
model_name,
|
| 69 |
+
torch_dtype="auto",
|
| 70 |
+
device_map="auto",
|
| 71 |
+
trust_remote_code=True
|
| 72 |
+
)
|
| 73 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 74 |
+
|
| 75 |
+
messages = [
|
| 76 |
+
{"role": "system", "content": "You are a helpful assistant provided by Moonshot-AI."},
|
| 77 |
+
{"role": "user", "content": "Is 123 a prime?"}
|
| 78 |
+
]
|
| 79 |
+
input_ids = tokenizer.apply_chat_template(
|
| 80 |
+
messages,
|
| 81 |
+
add_generation_prompt=True,
|
| 82 |
+
return_tensors="pt"
|
| 83 |
+
).to(model.device)
|
| 84 |
+
generated_ids = model.generate(inputs=input_ids, max_new_tokens=500)
|
| 85 |
+
response = tokenizer.batch_decode(generated_ids)[0]
|
| 86 |
+
print(response)
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
### Deployment
|
| 90 |
+
|
| 91 |
+
For deployment, you can use the latest vllm to create an OpenAI-compatible API endpoint.
|
| 92 |
+
|
| 93 |
+
```sh
|
| 94 |
+
vllm serve moonshotai/Kimi-Linear-48B-A3B-Instruct \
|
| 95 |
+
--port 8000 \
|
| 96 |
+
--tensor-parallel-size 4 \
|
| 97 |
+
--max-model-len 1048576 \
|
| 98 |
+
--trust-remote-code
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
### Citation
|
| 102 |
+
|
| 103 |
+
If you found our work useful, please cite
|
| 104 |
+
```bibtex
|
| 105 |
+
@misc{team2025kimi,
|
| 106 |
+
title = {Kimi Linear: An Expressive, Efficient Attention Architecture},
|
| 107 |
+
author = {Zhang, Yu and Lin, Zongyu and Yao, Xingcheng and Hu, Jiaxi and Meng, Fanqing and Liu, Chengyin and Men, Xin and Yang, Songlin and Li, Zhiyuan and Li, Wentao and Lu, Enzhe and Liu, Weizhou and Chen, Yanru and Xu, Weixin and Yu, Longhui and Wang, Yejie and Fan, Yu and Zhong, Longguang and Yuan, Enming and Zhang, Dehao and Zhang, Yizhi and T. Liu, Y. and Wang, Haiming and Fang, Shengjun and He, Weiran and Liu, Shaowei and Li, Yiwei and Su, Jianlin and Qiu, Jiezhong and Pang, Bo and Yan, Junjie and Jiang, Zhejun and Huang, Weixiao and Yin, Bohong and You, Jiacheng and Wei, Chu and Wang, Zhengtao and Hong, Chao and Chen, Yutian and Chen, Guanduo and Wang, Yucheng and Zheng, Huabin and Wang, Feng and Liu, Yibo and Dong, Mengnan and Zhang, Zheng and Pan, Siyuan and Wu, Wenhao and Wu, Yuhao and Guan, Longyu and Tao, Jiawen and Fu, Guohong and Xu, Xinran and Wang, Yuzhi and Lai, Guokun and Wu, Yuxin and Zhou, Xinyu and Yang, Zhilin and Du, Yulun},
|
| 108 |
+
year = {2025},
|
| 109 |
+
eprint = {2510.26692},
|
| 110 |
+
archivePrefix = {arXiv},
|
| 111 |
+
primaryClass = {cs.CL}
|
| 112 |
+
}
|
| 113 |
+
```
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{% macro render_content(msg) -%}
|
| 2 |
+
{%- set c = msg.get('content') -%}
|
| 3 |
+
{%- if c is string -%}
|
| 4 |
+
{{ c }}
|
| 5 |
+
{%- elif c is not none -%}
|
| 6 |
+
{% for content in c -%}
|
| 7 |
+
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
|
| 8 |
+
<|media_start|>image<|media_content|><|media_pad|><|media_end|>
|
| 9 |
+
{% else -%}
|
| 10 |
+
{{ content['text'] }}
|
| 11 |
+
{%- endif -%}
|
| 12 |
+
{%- endfor -%}
|
| 13 |
+
{%- endif -%}
|
| 14 |
+
{%- endmacro %}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
{%- if tools -%}
|
| 18 |
+
<|im_system|>tool_declare<|im_middle|>{{ tools | tojson(separators=(',', ':')) }}<|im_end|>
|
| 19 |
+
{%- endif -%}
|
| 20 |
+
{% for message in messages %}
|
| 21 |
+
{%- set role_name = message.get('name') or message['role'] -%}
|
| 22 |
+
{%- if message['role'] == 'user' -%}
|
| 23 |
+
<|im_user|>{{role_name}}<|im_middle|>
|
| 24 |
+
{%- elif message['role'] == 'assistant' -%}
|
| 25 |
+
<|im_assistant|>{{role_name}}<|im_middle|>
|
| 26 |
+
{%- else -%}
|
| 27 |
+
<|im_system|>{{role_name}}<|im_middle|>
|
| 28 |
+
{%- endif -%}
|
| 29 |
+
|
| 30 |
+
{%- if message['role'] == 'assistant' and message.get('tool_calls') -%}
|
| 31 |
+
{{render_content(message)}}<|tool_calls_section_begin|>
|
| 32 |
+
{%- for tool_call in message['tool_calls'] -%}
|
| 33 |
+
{%- set formatted_id = tool_call['id'] -%}
|
| 34 |
+
<|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>
|
| 35 |
+
{%- endfor -%}
|
| 36 |
+
<|tool_calls_section_end|>
|
| 37 |
+
{%- elif message['role'] == 'tool' -%}
|
| 38 |
+
{%- set tool_call_id = message.tool_call_id -%}
|
| 39 |
+
## Return of {{ tool_call_id }}
|
| 40 |
+
{{render_content(message)}}
|
| 41 |
+
{%- elif message['content'] is not none -%}
|
| 42 |
+
{{render_content(message)}}
|
| 43 |
+
{%- endif -%}
|
| 44 |
+
<|im_end|>
|
| 45 |
+
{%- endfor -%}
|
| 46 |
+
{%- if add_generation_prompt -%}
|
| 47 |
+
<|im_assistant|>assistant<|im_middle|>
|
| 48 |
+
{%- endif -%}
|
config.json
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"KimiLinearForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_kimi.KimiLinearConfig",
|
| 7 |
+
"AutoModel": "modeling_kimi.KimiLinearModel",
|
| 8 |
+
"AutoModelForCausalLM": "modeling_kimi.KimiLinearForCausalLM"
|
| 9 |
+
},
|
| 10 |
+
"bos_token_id": 163584,
|
| 11 |
+
"dtype": "bfloat16",
|
| 12 |
+
"eos_token_id": 163586,
|
| 13 |
+
"first_k_dense_replace": 1,
|
| 14 |
+
"head_dim": 72,
|
| 15 |
+
"hidden_act": "silu",
|
| 16 |
+
"hidden_size": 2304,
|
| 17 |
+
"initializer_range": 0.02,
|
| 18 |
+
"intermediate_size": 9216,
|
| 19 |
+
"kv_lora_rank": 512,
|
| 20 |
+
"linear_attn_config": {
|
| 21 |
+
"full_attn_layers": [
|
| 22 |
+
4,
|
| 23 |
+
8,
|
| 24 |
+
12,
|
| 25 |
+
16,
|
| 26 |
+
20,
|
| 27 |
+
24,
|
| 28 |
+
27
|
| 29 |
+
],
|
| 30 |
+
"head_dim": 128,
|
| 31 |
+
"kda_layers": [
|
| 32 |
+
1,
|
| 33 |
+
2,
|
| 34 |
+
3,
|
| 35 |
+
5,
|
| 36 |
+
6,
|
| 37 |
+
7,
|
| 38 |
+
9,
|
| 39 |
+
10,
|
| 40 |
+
11,
|
| 41 |
+
13,
|
| 42 |
+
14,
|
| 43 |
+
15,
|
| 44 |
+
17,
|
| 45 |
+
18,
|
| 46 |
+
19,
|
| 47 |
+
21,
|
| 48 |
+
22,
|
| 49 |
+
23,
|
| 50 |
+
25,
|
| 51 |
+
26
|
| 52 |
+
],
|
| 53 |
+
"num_heads": 32,
|
| 54 |
+
"short_conv_kernel_size": 4
|
| 55 |
+
},
|
| 56 |
+
"mla_use_nope": true,
|
| 57 |
+
"model_max_length": 1048576,
|
| 58 |
+
"model_type": "kimi_linear",
|
| 59 |
+
"moe_intermediate_size": 1024,
|
| 60 |
+
"moe_layer_freq": 1,
|
| 61 |
+
"moe_renormalize": true,
|
| 62 |
+
"moe_router_activation_func": "sigmoid",
|
| 63 |
+
"num_attention_heads": 32,
|
| 64 |
+
"num_expert_group": 1,
|
| 65 |
+
"num_experts": 256,
|
| 66 |
+
"num_experts_per_token": 8,
|
| 67 |
+
"num_hidden_layers": 27,
|
| 68 |
+
"num_key_value_heads": 32,
|
| 69 |
+
"num_nextn_predict_layers": 0,
|
| 70 |
+
"num_shared_experts": 1,
|
| 71 |
+
"pad_token_id": 163839,
|
| 72 |
+
"q_lora_rank": null,
|
| 73 |
+
"qk_nope_head_dim": 128,
|
| 74 |
+
"qk_rope_head_dim": 64,
|
| 75 |
+
"rms_norm_eps": 1e-05,
|
| 76 |
+
"rope_scaling": null,
|
| 77 |
+
"rope_theta": 10000.0,
|
| 78 |
+
"routed_scaling_factor": 2.446,
|
| 79 |
+
"tie_word_embeddings": false,
|
| 80 |
+
"topk_group": 1,
|
| 81 |
+
"transformers_version": "4.57.1",
|
| 82 |
+
"use_cache": true,
|
| 83 |
+
"use_grouped_topk": true,
|
| 84 |
+
"v_head_dim": 128,
|
| 85 |
+
"vocab_size": 163840
|
| 86 |
+
}
|
configuration_kimi.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class KimiLinearConfig(PretrainedConfig):
|
| 8 |
+
model_type = "kimi_linear"
|
| 9 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 10 |
+
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
model_type="kimi_linear",
|
| 14 |
+
vocab_size=163840,
|
| 15 |
+
hidden_size=4096,
|
| 16 |
+
head_dim=None,
|
| 17 |
+
intermediate_size=11008,
|
| 18 |
+
num_hidden_layers=32,
|
| 19 |
+
num_attention_heads=32,
|
| 20 |
+
num_key_value_heads=None,
|
| 21 |
+
hidden_act="silu",
|
| 22 |
+
initializer_range=0.02,
|
| 23 |
+
rms_norm_eps=1e-6,
|
| 24 |
+
use_cache=True,
|
| 25 |
+
pad_token_id=0,
|
| 26 |
+
bos_token_id=1,
|
| 27 |
+
eos_token_id=2,
|
| 28 |
+
rope_theta=10000.0,
|
| 29 |
+
rope_scaling=None,
|
| 30 |
+
tie_word_embeddings=False,
|
| 31 |
+
moe_intermediate_size: Optional[int] = None,
|
| 32 |
+
moe_renormalize: bool = True,
|
| 33 |
+
moe_router_activation_func: str = "sigmoid",
|
| 34 |
+
num_experts: Optional[int] = None,
|
| 35 |
+
num_experts_per_token: Optional[int] = None,
|
| 36 |
+
num_shared_experts: int = 0,
|
| 37 |
+
routed_scaling_factor: float = 1.0,
|
| 38 |
+
first_k_dense_replace: int = 0,
|
| 39 |
+
moe_layer_freq: int = 1,
|
| 40 |
+
use_grouped_topk: bool = True,
|
| 41 |
+
num_expert_group: int = 1,
|
| 42 |
+
topk_group: int = 1,
|
| 43 |
+
q_lora_rank: Optional[int] = None,
|
| 44 |
+
kv_lora_rank: Optional[int] = None,
|
| 45 |
+
qk_nope_head_dim: Optional[int] = None,
|
| 46 |
+
qk_rope_head_dim: Optional[int] = None,
|
| 47 |
+
v_head_dim: Optional[int] = None,
|
| 48 |
+
mla_use_nope: Optional[bool] = False,
|
| 49 |
+
num_nextn_predict_layers: int = 0,
|
| 50 |
+
linear_attn_config: Optional[dict] = None,
|
| 51 |
+
**kwargs,
|
| 52 |
+
):
|
| 53 |
+
self.model_type = model_type
|
| 54 |
+
self.vocab_size = vocab_size
|
| 55 |
+
self.hidden_size = hidden_size
|
| 56 |
+
self.head_dim = (
|
| 57 |
+
head_dim if head_dim is not None else hidden_size // num_attention_heads
|
| 58 |
+
)
|
| 59 |
+
self.intermediate_size = intermediate_size
|
| 60 |
+
self.num_hidden_layers = num_hidden_layers
|
| 61 |
+
self.num_attention_heads = num_attention_heads
|
| 62 |
+
|
| 63 |
+
# for backward compatibility
|
| 64 |
+
if num_key_value_heads is None:
|
| 65 |
+
num_key_value_heads = num_attention_heads
|
| 66 |
+
|
| 67 |
+
self.num_key_value_heads = num_key_value_heads
|
| 68 |
+
self.hidden_act = hidden_act
|
| 69 |
+
self.initializer_range = initializer_range
|
| 70 |
+
self.rms_norm_eps = rms_norm_eps
|
| 71 |
+
self.use_cache = use_cache
|
| 72 |
+
self.rope_theta = rope_theta
|
| 73 |
+
self.rope_scaling = rope_scaling
|
| 74 |
+
|
| 75 |
+
self.q_lora_rank = q_lora_rank
|
| 76 |
+
self.kv_lora_rank = kv_lora_rank
|
| 77 |
+
self.qk_nope_head_dim = qk_nope_head_dim
|
| 78 |
+
self.qk_rope_head_dim = qk_rope_head_dim
|
| 79 |
+
self.v_head_dim = v_head_dim
|
| 80 |
+
self.mla_use_nope = mla_use_nope
|
| 81 |
+
# moe config
|
| 82 |
+
self.num_experts = num_experts
|
| 83 |
+
self.num_experts_per_token = num_experts_per_token
|
| 84 |
+
self.moe_renormalize = moe_renormalize
|
| 85 |
+
self.num_shared_experts = num_shared_experts
|
| 86 |
+
self.routed_scaling_factor = routed_scaling_factor
|
| 87 |
+
self.moe_router_activation_func = moe_router_activation_func
|
| 88 |
+
assert self.moe_router_activation_func in ("softmax", "sigmoid")
|
| 89 |
+
self.moe_intermediate_size = moe_intermediate_size
|
| 90 |
+
self.first_k_dense_replace = first_k_dense_replace
|
| 91 |
+
self.moe_layer_freq = moe_layer_freq
|
| 92 |
+
self.use_grouped_topk = use_grouped_topk
|
| 93 |
+
self.num_expert_group = num_expert_group
|
| 94 |
+
self.topk_group = topk_group
|
| 95 |
+
self.num_nextn_predict_layers = num_nextn_predict_layers
|
| 96 |
+
|
| 97 |
+
if linear_attn_config is not None:
|
| 98 |
+
assert linear_attn_config["kda_layers"] is not None
|
| 99 |
+
assert linear_attn_config["full_attn_layers"] is not None
|
| 100 |
+
self.linear_attn_config = linear_attn_config
|
| 101 |
+
|
| 102 |
+
super().__init__(
|
| 103 |
+
pad_token_id=pad_token_id,
|
| 104 |
+
bos_token_id=bos_token_id,
|
| 105 |
+
eos_token_id=eos_token_id,
|
| 106 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 107 |
+
**kwargs,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def is_mla(self):
|
| 112 |
+
return (
|
| 113 |
+
self.q_lora_rank is not None
|
| 114 |
+
or self.kv_lora_rank is not None
|
| 115 |
+
or self.qk_nope_head_dim is not None
|
| 116 |
+
or self.qk_rope_head_dim is not None
|
| 117 |
+
or self.v_head_dim is not None
|
| 118 |
+
or self.mla_use_nope is True
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
@property
|
| 122 |
+
def is_moe(self):
|
| 123 |
+
return self.num_experts is not None
|
| 124 |
+
|
| 125 |
+
@property
|
| 126 |
+
def is_linear_attn(self) -> bool:
|
| 127 |
+
return not (
|
| 128 |
+
self.linear_attn_config is None
|
| 129 |
+
or (
|
| 130 |
+
isinstance(self.linear_attn_config, dict)
|
| 131 |
+
and self.linear_attn_config["kda_layers"] is not None
|
| 132 |
+
and len(self.linear_attn_config["kda_layers"]) == 0
|
| 133 |
+
)
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
def is_kda_layer(self, layer_idx: int):
|
| 137 |
+
return (
|
| 138 |
+
self.linear_attn_config is not None
|
| 139 |
+
and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
|
| 140 |
+
)
|
figures/arch.png
ADDED
|
Git LFS Details
|
figures/banner.png
ADDED
|
figures/github.png
ADDED
|
figures/logo.png
ADDED
|
figures/perf_speed.png
ADDED
|
Git LFS Details
|
generation_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 163584,
|
| 4 |
+
"eos_token_id": 163586,
|
| 5 |
+
"pad_token_id": 163839,
|
| 6 |
+
"transformers_version": "4.57.1"
|
| 7 |
+
}
|
model-00001-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9c5c908aa3b86b6486080b577cb7aa8dbe9ca7cb18789653768017e602b61a7f
|
| 3 |
+
size 4999482712
|
model-00002-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6fcb34e9ebe2434f32761c06ef17a465157308e6e583eb7eb70cc25e57cd2cb0
|
| 3 |
+
size 4999923264
|
model-00003-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f35d2a95dd1e3170fd642d0db4d0d07933985ef59041494652092cc27893e231
|
| 3 |
+
size 4997138040
|
model-00004-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eda18b6226777bb9a07584dfa64986ac4f28a26cee3203f16ffb14deef9ef48b
|
| 3 |
+
size 4997148016
|
model-00005-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:652e8a43d493105176807d256af0a5c56e45c6d783e6c8221832918f3425c0a0
|
| 3 |
+
size 4999923296
|
model-00006-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:77f5dc551436c934f0991eee0c319e0f33689ea3c10cb8cb8f48acc32238526f
|
| 3 |
+
size 4997138040
|
model-00007-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3980c7efccdc6a27633eb8909afc381cb780cccaa7be0347d1645496ea3eb5a2
|
| 3 |
+
size 4997148128
|
model-00008-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dab91b8eaed9874c75de99a4a08669f520fa3e2c8977175333db552504a1c5d3
|
| 3 |
+
size 4999924384
|
model-00009-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:13f6ae84d557682ec4a0fc8b6090d4f89cdd26e5e216445cc9d77a65c7f4c90b
|
| 3 |
+
size 4997139104
|
model-00010-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d58ab0a201e26ff429b9d18678a76f3a3284ad977719e055f94d892133ee247b
|
| 3 |
+
size 4997149016
|
model-00011-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:142aa90317af104b2d9f5a6ae4dc661f4a7f7c152f83d6c2477de8037be92201
|
| 3 |
+
size 4999924408
|
model-00012-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cb1d4fe2d94a04898eb8300b5144923e5540a75091ee5f4c8b67936a69d91780
|
| 3 |
+
size 4997139104
|
model-00013-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:12fb6f2dea889d460f33f7fbb55f76d7beb468698cf56707f4d77a9ab69461d3
|
| 3 |
+
size 4997148992
|
model-00014-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf178169e0dfdc721492f1a98a5be2ef5f66fd8569039f5a77819641a5a1b32d
|
| 3 |
+
size 4999924440
|
model-00015-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9e65460934e8794faeadd6d8cbeffd23fcdbf07d9c61ca92ef97afc95d0ccdaa
|
| 3 |
+
size 4997139104
|
model-00016-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:05cc77846f94d50dc09180f5844f01aee38e489b1fd833d8c7aec6a62214ef03
|
| 3 |
+
size 4997148960
|
model-00017-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eaf238a2f1a971ef311c445309de323992da59287d55759cf2a4a3a85ca6a1cc
|
| 3 |
+
size 4999924472
|
model-00018-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:315cdb6964a975522cdb755cf5eb76b46478346b015113e241f81127ad9e6fd4
|
| 3 |
+
size 4997139104
|
model-00019-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cbdc2c77e41baa76a2c2b3ced0a59fe7587e95ca3d1acc75247b88a80dee3041
|
| 3 |
+
size 4999934384
|
model-00020-of-00020.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f1e4a9194d045e01c90ed2697939bcedd533b6aa1f1b97b0ae0a5932e5a4bc7
|
| 3 |
+
size 3280687152
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_kimi.py
ADDED
|
@@ -0,0 +1,1028 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from collections.abc import Callable
|
| 3 |
+
from typing import Any, List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import transformers
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from packaging import version
|
| 10 |
+
from torch import nn
|
| 11 |
+
from transformers.activations import ACT2FN
|
| 12 |
+
from transformers.cache_utils import Cache
|
| 13 |
+
from transformers.generation import GenerationMixin
|
| 14 |
+
from transformers.masking_utils import create_causal_mask
|
| 15 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 16 |
+
from transformers.modeling_outputs import (BaseModelOutputWithPast,
|
| 17 |
+
CausalLMOutputWithPast)
|
| 18 |
+
from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
|
| 19 |
+
PreTrainedModel)
|
| 20 |
+
from transformers.processing_utils import Unpack
|
| 21 |
+
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
| 22 |
+
from transformers.utils import (TransformersKwargs, auto_docstring,
|
| 23 |
+
can_return_tuple, logging)
|
| 24 |
+
from transformers.utils.generic import OutputRecorder, check_model_inputs
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
|
| 28 |
+
from fla.modules import FusedRMSNormGated, ShortConvolution
|
| 29 |
+
from fla.ops.kda import chunk_kda, fused_recurrent_kda
|
| 30 |
+
from fla.ops.kda.gate import fused_kda_gate
|
| 31 |
+
except ImportError:
|
| 32 |
+
raise ImportError("Plese run `pip install -U fla-core`")
|
| 33 |
+
|
| 34 |
+
from .configuration_kimi import KimiLinearConfig
|
| 35 |
+
|
| 36 |
+
assert version.parse(transformers.__version__) >= version.parse("4.56.0"), \
|
| 37 |
+
"Please upgrade transformers to >= 4.56.0"
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class KimiDynamicCache:
|
| 43 |
+
"""
|
| 44 |
+
Dynamic cache for Kimi model.
|
| 45 |
+
Inspired by Qwen3-Next
|
| 46 |
+
"""
|
| 47 |
+
is_compileable = False
|
| 48 |
+
|
| 49 |
+
def __init__(self, config: KimiLinearConfig):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.config = config
|
| 52 |
+
|
| 53 |
+
if config.linear_attn_config is not None:
|
| 54 |
+
self.layer_types = []
|
| 55 |
+
for i in range(config.num_hidden_layers):
|
| 56 |
+
if config.is_kda_layer(i):
|
| 57 |
+
self.layer_types.append("linear_attention")
|
| 58 |
+
else:
|
| 59 |
+
self.layer_types.append("full_attention")
|
| 60 |
+
else:
|
| 61 |
+
self.layer_types = ["full_attention"] * config.num_hidden_layers
|
| 62 |
+
|
| 63 |
+
self.transformer_layers = [
|
| 64 |
+
i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention"
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
linear_layers = [i for i in range(
|
| 68 |
+
config.num_hidden_layers) if self.layer_types[i] == "linear_attention"]
|
| 69 |
+
self.last_linear_layer = linear_layers[-1] if linear_layers else -1
|
| 70 |
+
|
| 71 |
+
self.conv_states = [None for _ in range(config.num_hidden_layers)]
|
| 72 |
+
self.recurrent_states = [None for _ in range(config.num_hidden_layers)]
|
| 73 |
+
self.key_cache = [None for _ in range(config.num_hidden_layers)]
|
| 74 |
+
self.value_cache = [None for _ in range(config.num_hidden_layers)]
|
| 75 |
+
|
| 76 |
+
def __len__(self):
|
| 77 |
+
return len(self.layer_types)
|
| 78 |
+
|
| 79 |
+
def update(
|
| 80 |
+
self,
|
| 81 |
+
key_states: torch.Tensor,
|
| 82 |
+
value_states: torch.Tensor,
|
| 83 |
+
layer_idx: int,
|
| 84 |
+
cache_kwargs: Optional[dict[str, Any]] = None,
|
| 85 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 86 |
+
if self.key_cache[layer_idx] is None:
|
| 87 |
+
self.key_cache[layer_idx] = key_states
|
| 88 |
+
self.value_cache[layer_idx] = value_states
|
| 89 |
+
else:
|
| 90 |
+
self.key_cache[layer_idx] = torch.cat(
|
| 91 |
+
[self.key_cache[layer_idx], key_states], dim=2)
|
| 92 |
+
self.value_cache[layer_idx] = torch.cat(
|
| 93 |
+
[self.value_cache[layer_idx], value_states], dim=2)
|
| 94 |
+
|
| 95 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
| 96 |
+
|
| 97 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
| 98 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
| 99 |
+
for layer_idx in range(len(self.key_cache)):
|
| 100 |
+
if self.key_cache[layer_idx] is not None:
|
| 101 |
+
device = self.key_cache[layer_idx].device
|
| 102 |
+
beam_idx = beam_idx.to(device)
|
| 103 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
|
| 104 |
+
0, beam_idx)
|
| 105 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
|
| 106 |
+
0, beam_idx)
|
| 107 |
+
|
| 108 |
+
if self.conv_states[layer_idx] is not None:
|
| 109 |
+
device = self.conv_states[layer_idx][0].device
|
| 110 |
+
beam_idx = beam_idx.to(device)
|
| 111 |
+
q_conv, k_conv, v_conv = self.conv_states[layer_idx]
|
| 112 |
+
self.conv_states[layer_idx] = (
|
| 113 |
+
q_conv.index_select(0, beam_idx),
|
| 114 |
+
k_conv.index_select(0, beam_idx),
|
| 115 |
+
v_conv.index_select(0, beam_idx)
|
| 116 |
+
)
|
| 117 |
+
self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(
|
| 118 |
+
0, beam_idx)
|
| 119 |
+
|
| 120 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
| 121 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
| 122 |
+
# take any layer that contains cache and not empty tensor
|
| 123 |
+
layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
|
| 124 |
+
if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
|
| 125 |
+
return 0
|
| 126 |
+
return self.key_cache[layer_idx].shape[-2]
|
| 127 |
+
|
| 128 |
+
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
| 129 |
+
"""
|
| 130 |
+
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
| 131 |
+
the given layer at `layer_idx`.
|
| 132 |
+
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
|
| 133 |
+
"""
|
| 134 |
+
kv_offset = 0
|
| 135 |
+
query_length = cache_position.shape[0]
|
| 136 |
+
past_seen_tokens = self.get_seq_length(layer_idx)
|
| 137 |
+
kv_length = query_length + past_seen_tokens
|
| 138 |
+
return kv_length, kv_offset
|
| 139 |
+
|
| 140 |
+
@property
|
| 141 |
+
def has_previous_state(self):
|
| 142 |
+
"""We have a previous state if the last linear (conv) layer was already updated."""
|
| 143 |
+
if self.last_linear_layer == -1:
|
| 144 |
+
return False
|
| 145 |
+
return self.conv_states[self.last_linear_layer] is not None
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class KimiRMSNorm(nn.Module):
|
| 149 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 150 |
+
"""
|
| 151 |
+
KimiRMSNorm is equivalent to T5LayerNorm
|
| 152 |
+
"""
|
| 153 |
+
super().__init__()
|
| 154 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 155 |
+
self.variance_epsilon = eps
|
| 156 |
+
|
| 157 |
+
def forward(self, hidden_states):
|
| 158 |
+
input_dtype = hidden_states.dtype
|
| 159 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 160 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 161 |
+
hidden_states = hidden_states * \
|
| 162 |
+
torch.rsqrt(variance + self.variance_epsilon)
|
| 163 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
ALL_LAYERNORM_LAYERS.append(KimiRMSNorm)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class KimiBlockSparseMLP(nn.Module):
|
| 170 |
+
def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
|
| 171 |
+
super().__init__()
|
| 172 |
+
self.config = config
|
| 173 |
+
self.ffn_dim = config.intermediate_size if intermediate_size is None else intermediate_size
|
| 174 |
+
self.hidden_dim = config.hidden_size if hidden_size is None else hidden_size
|
| 175 |
+
|
| 176 |
+
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # gate
|
| 177 |
+
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) # down
|
| 178 |
+
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # up
|
| 179 |
+
|
| 180 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 181 |
+
|
| 182 |
+
def forward(self, hidden_states):
|
| 183 |
+
current_hidden_states = self.act_fn(
|
| 184 |
+
self.w1(hidden_states)) * self.w3(hidden_states)
|
| 185 |
+
current_hidden_states = self.w2(current_hidden_states)
|
| 186 |
+
return current_hidden_states
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class KimiMLP(nn.Module):
|
| 190 |
+
def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
|
| 191 |
+
super().__init__()
|
| 192 |
+
self.config = config
|
| 193 |
+
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
| 194 |
+
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
|
| 195 |
+
self.gate_proj = nn.Linear(
|
| 196 |
+
self.hidden_size, self.intermediate_size, bias=False)
|
| 197 |
+
self.up_proj = nn.Linear(
|
| 198 |
+
self.hidden_size, self.intermediate_size, bias=False)
|
| 199 |
+
self.down_proj = nn.Linear(
|
| 200 |
+
self.intermediate_size, self.hidden_size, bias=False)
|
| 201 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 202 |
+
|
| 203 |
+
def forward(self, x):
|
| 204 |
+
down_proj = self.down_proj(self.act_fn(
|
| 205 |
+
self.gate_proj(x)) * self.up_proj(x))
|
| 206 |
+
return down_proj
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 210 |
+
"""
|
| 211 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 212 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 213 |
+
"""
|
| 214 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 215 |
+
if n_rep == 1:
|
| 216 |
+
return hidden_states
|
| 217 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
| 218 |
+
batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 219 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def eager_attention_forward(
|
| 223 |
+
module: nn.Module,
|
| 224 |
+
query: torch.Tensor,
|
| 225 |
+
key: torch.Tensor,
|
| 226 |
+
value: torch.Tensor,
|
| 227 |
+
attention_mask: Optional[torch.Tensor],
|
| 228 |
+
scaling: float,
|
| 229 |
+
dropout: float = 0.0,
|
| 230 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 231 |
+
):
|
| 232 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 233 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 234 |
+
|
| 235 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 236 |
+
if attention_mask is not None:
|
| 237 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 238 |
+
attn_weights = attn_weights + causal_mask
|
| 239 |
+
|
| 240 |
+
attn_weights = nn.functional.softmax(
|
| 241 |
+
attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 242 |
+
attn_weights = nn.functional.dropout(
|
| 243 |
+
attn_weights, p=dropout, training=module.training)
|
| 244 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 245 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 246 |
+
|
| 247 |
+
return attn_output, attn_weights
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class KimiMLAAttention(nn.Module):
|
| 251 |
+
"""
|
| 252 |
+
Multi-Latent Attention adapted from deepseek-v3
|
| 253 |
+
"""
|
| 254 |
+
|
| 255 |
+
def __init__(self, config: KimiLinearConfig, layer_idx: int):
|
| 256 |
+
nn.Module.__init__(self)
|
| 257 |
+
self.config = config
|
| 258 |
+
self.layer_idx = layer_idx
|
| 259 |
+
self.hidden_size = config.hidden_size
|
| 260 |
+
self.num_heads = config.num_attention_heads
|
| 261 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 262 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 263 |
+
|
| 264 |
+
self.rope_theta = config.rope_theta
|
| 265 |
+
self.attention_dropout = getattr(config, "attention_dropout", 0.0)
|
| 266 |
+
|
| 267 |
+
try:
|
| 268 |
+
self.q_lora_rank = config.q_lora_rank
|
| 269 |
+
self.qk_rope_head_dim = config.qk_rope_head_dim
|
| 270 |
+
self.kv_lora_rank = config.kv_lora_rank
|
| 271 |
+
self.v_head_dim = config.v_head_dim
|
| 272 |
+
self.qk_nope_head_dim = config.qk_nope_head_dim
|
| 273 |
+
self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
|
| 274 |
+
self.use_nope = config.mla_use_nope
|
| 275 |
+
self.scaling = self.q_head_dim ** (-0.5)
|
| 276 |
+
except Exception as e:
|
| 277 |
+
raise ValueError(
|
| 278 |
+
f"Kimi MLA config is not found or not properly formatted: {e}")
|
| 279 |
+
|
| 280 |
+
assert self.q_lora_rank is None
|
| 281 |
+
self.q_proj = nn.Linear(
|
| 282 |
+
self.hidden_size, self.num_heads * self.q_head_dim, bias=False,
|
| 283 |
+
)
|
| 284 |
+
self.kv_a_proj_with_mqa = nn.Linear(
|
| 285 |
+
self.hidden_size,
|
| 286 |
+
self.kv_lora_rank + self.qk_rope_head_dim,
|
| 287 |
+
bias=False,
|
| 288 |
+
)
|
| 289 |
+
self.kv_a_layernorm = KimiRMSNorm(self.kv_lora_rank)
|
| 290 |
+
self.kv_b_proj = nn.Linear(
|
| 291 |
+
self.kv_lora_rank,
|
| 292 |
+
self.num_heads
|
| 293 |
+
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
| 294 |
+
bias=False,
|
| 295 |
+
)
|
| 296 |
+
self.o_proj = nn.Linear(
|
| 297 |
+
self.num_heads * self.v_head_dim,
|
| 298 |
+
self.hidden_size,
|
| 299 |
+
bias=False,
|
| 300 |
+
)
|
| 301 |
+
self.is_causal = True
|
| 302 |
+
assert self.use_nope
|
| 303 |
+
|
| 304 |
+
def forward(
|
| 305 |
+
self,
|
| 306 |
+
hidden_states: torch.Tensor,
|
| 307 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 308 |
+
past_key_values: Optional[Cache] = None,
|
| 309 |
+
**kwargs,
|
| 310 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 311 |
+
batch_size, seq_length = hidden_states.shape[:-1]
|
| 312 |
+
query_shape = (batch_size, seq_length, -1, self.q_head_dim)
|
| 313 |
+
key_shape = (batch_size, seq_length, -1,
|
| 314 |
+
self.qk_nope_head_dim + self.v_head_dim)
|
| 315 |
+
|
| 316 |
+
q_states = self.q_proj(hidden_states)
|
| 317 |
+
q_states = q_states.view(query_shape).transpose(1, 2)
|
| 318 |
+
q_pass, q_rot = torch.split(
|
| 319 |
+
q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
| 320 |
+
|
| 321 |
+
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
| 322 |
+
k_pass, k_rot = torch.split(
|
| 323 |
+
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
| 324 |
+
|
| 325 |
+
k_pass = self.kv_b_proj(self.kv_a_layernorm(
|
| 326 |
+
k_pass)).view(key_shape).transpose(1, 2)
|
| 327 |
+
k_pass, value_states = torch.split(
|
| 328 |
+
k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 329 |
+
|
| 330 |
+
k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
|
| 331 |
+
k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
|
| 332 |
+
|
| 333 |
+
query_states = torch.cat((q_pass, q_rot), dim=-1)
|
| 334 |
+
key_states = torch.cat((k_pass, k_rot), dim=-1)
|
| 335 |
+
|
| 336 |
+
if past_key_values is not None:
|
| 337 |
+
key_states, value_states = past_key_values.update(
|
| 338 |
+
key_states, value_states, self.layer_idx)
|
| 339 |
+
|
| 340 |
+
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
|
| 341 |
+
value_states = F.pad(
|
| 342 |
+
value_states, [0, self.q_head_dim - self.v_head_dim])
|
| 343 |
+
|
| 344 |
+
attention_interface: Callable = eager_attention_forward
|
| 345 |
+
if self.config._attn_implementation != "eager":
|
| 346 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 347 |
+
|
| 348 |
+
attn_output, _ = attention_interface(
|
| 349 |
+
self,
|
| 350 |
+
query_states,
|
| 351 |
+
key_states,
|
| 352 |
+
value_states,
|
| 353 |
+
attention_mask,
|
| 354 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 355 |
+
scaling=self.scaling,
|
| 356 |
+
**kwargs,
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
|
| 360 |
+
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
| 361 |
+
|
| 362 |
+
attn_output = attn_output.reshape(
|
| 363 |
+
batch_size, seq_length, -1).contiguous()
|
| 364 |
+
attn_output = self.o_proj(attn_output)
|
| 365 |
+
return attn_output
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class KimiDeltaAttention(nn.Module):
|
| 369 |
+
def __init__(self, config: KimiLinearConfig, layer_idx: int):
|
| 370 |
+
super().__init__()
|
| 371 |
+
self.config = config
|
| 372 |
+
self.mode = "chunk"
|
| 373 |
+
|
| 374 |
+
self.hidden_size = config.hidden_size
|
| 375 |
+
self.conv_size = config.linear_attn_config["short_conv_kernel_size"]
|
| 376 |
+
self.head_dim = config.linear_attn_config["head_dim"]
|
| 377 |
+
self.num_heads = config.linear_attn_config["num_heads"]
|
| 378 |
+
self.head_k_dim = self.head_dim
|
| 379 |
+
self.num_k_heads = self.num_heads
|
| 380 |
+
|
| 381 |
+
self.layer_idx = layer_idx
|
| 382 |
+
|
| 383 |
+
assert self.mode in [
|
| 384 |
+
'chunk', 'fused_recurrent'], f"Not suppoerted mode `{self.mode}`."
|
| 385 |
+
|
| 386 |
+
projection_k_size = self.head_k_dim * self.num_k_heads
|
| 387 |
+
projection_size = self.head_dim * self.num_heads
|
| 388 |
+
|
| 389 |
+
self.q_proj = nn.Linear(
|
| 390 |
+
self.hidden_size, projection_k_size, bias=False)
|
| 391 |
+
self.k_proj = nn.Linear(
|
| 392 |
+
self.hidden_size, projection_k_size, bias=False)
|
| 393 |
+
self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False)
|
| 394 |
+
|
| 395 |
+
self.q_conv1d = ShortConvolution(
|
| 396 |
+
hidden_size=projection_k_size,
|
| 397 |
+
kernel_size=self.conv_size,
|
| 398 |
+
activation='silu',
|
| 399 |
+
)
|
| 400 |
+
self.k_conv1d = ShortConvolution(
|
| 401 |
+
hidden_size=projection_k_size,
|
| 402 |
+
kernel_size=self.conv_size,
|
| 403 |
+
activation='silu'
|
| 404 |
+
)
|
| 405 |
+
self.v_conv1d = ShortConvolution(
|
| 406 |
+
hidden_size=projection_size,
|
| 407 |
+
kernel_size=self.conv_size,
|
| 408 |
+
activation='silu'
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
self.A_log = torch.nn.Parameter(torch.log(torch.empty(
|
| 412 |
+
self.num_heads, dtype=torch.float32).uniform_(1, 16)).view(1, 1, -1, 1))
|
| 413 |
+
|
| 414 |
+
self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
|
| 415 |
+
self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
|
| 416 |
+
|
| 417 |
+
self.dt_bias = nn.Parameter(
|
| 418 |
+
torch.empty(projection_size, dtype=torch.float32))
|
| 419 |
+
|
| 420 |
+
self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
|
| 421 |
+
|
| 422 |
+
self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
|
| 423 |
+
self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
|
| 424 |
+
|
| 425 |
+
self.o_norm = FusedRMSNormGated(
|
| 426 |
+
self.head_dim, eps=config.rms_norm_eps, activation='sigmoid')
|
| 427 |
+
self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False)
|
| 428 |
+
|
| 429 |
+
def forward(
|
| 430 |
+
self,
|
| 431 |
+
hidden_states: torch.Tensor,
|
| 432 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 433 |
+
cache_params: Optional[KimiDynamicCache] = None,
|
| 434 |
+
**kwargs: Unpack[dict]
|
| 435 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
| 436 |
+
if attention_mask is not None:
|
| 437 |
+
if attention_mask.dim() != 2:
|
| 438 |
+
attention_mask = kwargs.get("padding_mask", None)
|
| 439 |
+
|
| 440 |
+
if attention_mask is not None and attention_mask.dim() != 2:
|
| 441 |
+
raise ValueError(
|
| 442 |
+
"attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] "
|
| 443 |
+
"(0 = padding). 3D masks are not supported here."
|
| 444 |
+
)
|
| 445 |
+
use_cache = cache_params is not None
|
| 446 |
+
batch_size, q_len, _ = hidden_states.shape
|
| 447 |
+
mode = 'fused_recurrent' if q_len <= 64 else self.mode
|
| 448 |
+
if self.training:
|
| 449 |
+
assert mode == 'chunk', "Only chunk mode is supported in training."
|
| 450 |
+
|
| 451 |
+
cu_seqlens = kwargs.get('cu_seqlens', None)
|
| 452 |
+
indices = None
|
| 453 |
+
if attention_mask is not None:
|
| 454 |
+
indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
|
| 455 |
+
hidden_states = index_first_axis(
|
| 456 |
+
rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
|
| 457 |
+
|
| 458 |
+
conv_state_q, conv_state_k, conv_state_v = None, None, None
|
| 459 |
+
recurrent_state = None
|
| 460 |
+
if cache_params is not None:
|
| 461 |
+
if cache_params.conv_states[self.layer_idx] is not None:
|
| 462 |
+
conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[
|
| 463 |
+
self.layer_idx]
|
| 464 |
+
recurrent_state = cache_params.recurrent_states[self.layer_idx]
|
| 465 |
+
q, conv_state_q = self.q_conv1d(
|
| 466 |
+
x=self.q_proj(hidden_states),
|
| 467 |
+
cache=conv_state_q,
|
| 468 |
+
output_final_state=use_cache,
|
| 469 |
+
cu_seqlens=cu_seqlens
|
| 470 |
+
)
|
| 471 |
+
k, conv_state_k = self.k_conv1d(
|
| 472 |
+
x=self.k_proj(hidden_states),
|
| 473 |
+
cache=conv_state_k,
|
| 474 |
+
output_final_state=use_cache,
|
| 475 |
+
cu_seqlens=cu_seqlens
|
| 476 |
+
)
|
| 477 |
+
v, conv_state_v = self.v_conv1d(
|
| 478 |
+
x=self.v_proj(hidden_states),
|
| 479 |
+
cache=conv_state_v,
|
| 480 |
+
output_final_state=use_cache,
|
| 481 |
+
cu_seqlens=cu_seqlens
|
| 482 |
+
)
|
| 483 |
+
g = self.f_b_proj(self.f_a_proj(hidden_states))
|
| 484 |
+
g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
|
| 485 |
+
beta = self.b_proj(hidden_states).float().sigmoid()
|
| 486 |
+
|
| 487 |
+
q, k = map(lambda x: rearrange(
|
| 488 |
+
x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
|
| 489 |
+
v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
|
| 490 |
+
|
| 491 |
+
if mode == 'chunk':
|
| 492 |
+
o, recurrent_state = chunk_kda(
|
| 493 |
+
q=q,
|
| 494 |
+
k=k,
|
| 495 |
+
v=v,
|
| 496 |
+
g=g,
|
| 497 |
+
beta=beta,
|
| 498 |
+
initial_state=recurrent_state,
|
| 499 |
+
output_final_state=True,
|
| 500 |
+
use_qk_l2norm_in_kernel=True,
|
| 501 |
+
cu_seqlens=cu_seqlens,
|
| 502 |
+
)
|
| 503 |
+
else:
|
| 504 |
+
o, recurrent_state = fused_recurrent_kda(
|
| 505 |
+
q=q,
|
| 506 |
+
k=k,
|
| 507 |
+
v=v,
|
| 508 |
+
g=g,
|
| 509 |
+
beta=beta,
|
| 510 |
+
initial_state=recurrent_state,
|
| 511 |
+
output_final_state=True,
|
| 512 |
+
use_qk_l2norm_in_kernel=True,
|
| 513 |
+
cu_seqlens=cu_seqlens,
|
| 514 |
+
)
|
| 515 |
+
if cache_params is not None:
|
| 516 |
+
cache_params.recurrent_states[self.layer_idx] = recurrent_state
|
| 517 |
+
cache_params.conv_states[self.layer_idx] = (
|
| 518 |
+
conv_state_q, conv_state_k, conv_state_v)
|
| 519 |
+
|
| 520 |
+
g = self.g_b_proj(self.g_a_proj(hidden_states))
|
| 521 |
+
g = rearrange(g, '... (h d) -> ... h d', d=self.head_dim)
|
| 522 |
+
o = self.o_norm(o, g)
|
| 523 |
+
|
| 524 |
+
o = rearrange(o, 'b t h d -> b t (h d)')
|
| 525 |
+
o = self.o_proj(o)
|
| 526 |
+
if attention_mask is not None:
|
| 527 |
+
o = pad_input(o.squeeze(0), indices, batch_size, q_len)
|
| 528 |
+
|
| 529 |
+
return o
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
class KimiMoEGate(nn.Module):
|
| 533 |
+
"""
|
| 534 |
+
MoEGate adapted from Deepseek-V3.
|
| 535 |
+
Parameter correspondences:
|
| 536 |
+
num_experts -> n_routed_experts
|
| 537 |
+
num_experts_per_token -> num_experts_per_tok
|
| 538 |
+
num_expert_group -> n_group
|
| 539 |
+
moe_router_activation_func -> scoring_func
|
| 540 |
+
"""
|
| 541 |
+
|
| 542 |
+
def __init__(self, config: KimiLinearConfig):
|
| 543 |
+
super().__init__()
|
| 544 |
+
self.config = config
|
| 545 |
+
self.top_k = config.num_experts_per_token
|
| 546 |
+
self.num_experts = config.num_experts
|
| 547 |
+
self.routed_scaling_factor = config.routed_scaling_factor
|
| 548 |
+
self.moe_router_activation_func = config.moe_router_activation_func
|
| 549 |
+
self.num_expert_group = getattr(config, "num_expert_group", 1)
|
| 550 |
+
self.topk_group = getattr(config, "topk_group", 1)
|
| 551 |
+
|
| 552 |
+
# topk selection algorithm
|
| 553 |
+
self.moe_renormalize = config.moe_renormalize
|
| 554 |
+
self.gating_dim = config.hidden_size
|
| 555 |
+
self.weight = nn.Parameter(
|
| 556 |
+
torch.empty((self.num_experts, self.gating_dim))
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
self.e_score_correction_bias = nn.Parameter(
|
| 560 |
+
torch.empty((self.num_experts))
|
| 561 |
+
)
|
| 562 |
+
self.reset_parameters()
|
| 563 |
+
|
| 564 |
+
def reset_parameters(self) -> None:
|
| 565 |
+
import torch.nn.init as init
|
| 566 |
+
|
| 567 |
+
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
| 568 |
+
|
| 569 |
+
def forward(self, hidden_states):
|
| 570 |
+
bsz, seq_len, h = hidden_states.shape
|
| 571 |
+
# compute gating score
|
| 572 |
+
hidden_states = hidden_states.view(-1, h)
|
| 573 |
+
logits = F.linear(
|
| 574 |
+
hidden_states.type(torch.float32), self.weight.type(
|
| 575 |
+
torch.float32), None
|
| 576 |
+
)
|
| 577 |
+
if self.moe_router_activation_func == "sigmoid":
|
| 578 |
+
scores = logits.sigmoid()
|
| 579 |
+
elif self.moe_router_activation_func == "softmax":
|
| 580 |
+
scores = logits.softmax(dim=1)
|
| 581 |
+
else:
|
| 582 |
+
raise NotImplementedError(
|
| 583 |
+
f"insupportable scoring function for MoE gating: {self.moe_router_activation_func}"
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
# select top-k experts
|
| 587 |
+
assert not self.training
|
| 588 |
+
scores_for_choice = scores.view(bsz * seq_len, -1)
|
| 589 |
+
scores_for_choice += self.e_score_correction_bias.unsqueeze(0)
|
| 590 |
+
group_scores = (
|
| 591 |
+
scores_for_choice.view(
|
| 592 |
+
bsz * seq_len, self.num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
|
| 593 |
+
) # [n, num_expert_group]
|
| 594 |
+
group_idx = torch.topk(
|
| 595 |
+
group_scores, k=self.topk_group, dim=-1, sorted=False
|
| 596 |
+
)[
|
| 597 |
+
1
|
| 598 |
+
] # [n, top_k_group]
|
| 599 |
+
group_mask = torch.zeros_like(group_scores) # [n, num_expert_group]
|
| 600 |
+
group_mask.scatter_(1, group_idx, 1) # [n, num_expert_group]
|
| 601 |
+
score_mask = (
|
| 602 |
+
group_mask.unsqueeze(-1)
|
| 603 |
+
.expand(
|
| 604 |
+
bsz * seq_len, self.num_expert_group, self.num_experts // self.num_expert_group
|
| 605 |
+
)
|
| 606 |
+
.reshape(bsz * seq_len, -1)
|
| 607 |
+
) # [n, e]
|
| 608 |
+
tmp_scores = scores_for_choice.masked_fill(
|
| 609 |
+
~score_mask.bool(), 0.0) # [n, e]
|
| 610 |
+
_, topk_idx = torch.topk(
|
| 611 |
+
tmp_scores, k=self.top_k, dim=-1, sorted=False
|
| 612 |
+
)
|
| 613 |
+
topk_weight = scores.gather(1, topk_idx)
|
| 614 |
+
|
| 615 |
+
# norm gate to sum 1
|
| 616 |
+
if self.top_k > 1 and self.moe_renormalize:
|
| 617 |
+
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
| 618 |
+
topk_weight = topk_weight / denominator
|
| 619 |
+
# must multiply the scaling factor
|
| 620 |
+
topk_weight = topk_weight * self.routed_scaling_factor
|
| 621 |
+
|
| 622 |
+
return topk_idx, topk_weight
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
class KimiSparseMoeBlock(nn.Module):
|
| 626 |
+
"""
|
| 627 |
+
Adapted from Deepseek-V3's MOE implementation
|
| 628 |
+
The namings are consistent with Kimi's version.
|
| 629 |
+
"""
|
| 630 |
+
|
| 631 |
+
def __init__(self, config: KimiLinearConfig):
|
| 632 |
+
super().__init__()
|
| 633 |
+
self.config = config
|
| 634 |
+
self.hidden_dim = config.hidden_size
|
| 635 |
+
self.num_experts = config.num_experts
|
| 636 |
+
self.top_k = config.num_experts_per_token
|
| 637 |
+
self.moe_renormalize = config.moe_renormalize
|
| 638 |
+
|
| 639 |
+
self.ep_size = 1
|
| 640 |
+
self.experts_per_rank = config.num_experts
|
| 641 |
+
self.ep_rank = 0
|
| 642 |
+
self.experts = nn.ModuleList(
|
| 643 |
+
[
|
| 644 |
+
KimiBlockSparseMLP(
|
| 645 |
+
config, intermediate_size=config.moe_intermediate_size
|
| 646 |
+
)
|
| 647 |
+
for _ in range(config.num_experts)
|
| 648 |
+
]
|
| 649 |
+
)
|
| 650 |
+
self.gate = KimiMoEGate(config)
|
| 651 |
+
if config.num_shared_experts is not None:
|
| 652 |
+
intermediate_size = config.moe_intermediate_size * config.num_shared_experts
|
| 653 |
+
self.shared_experts = KimiMLP(
|
| 654 |
+
config=config, intermediate_size=intermediate_size
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
def forward(self, hidden_states):
|
| 658 |
+
identity = hidden_states
|
| 659 |
+
orig_shape = hidden_states.shape
|
| 660 |
+
topk_idx, topk_weight = self.gate(hidden_states)
|
| 661 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 662 |
+
flat_topk_idx = topk_idx.view(-1)
|
| 663 |
+
if not self.training:
|
| 664 |
+
y = self.moe_infer(hidden_states, topk_idx,
|
| 665 |
+
topk_weight).view(*orig_shape)
|
| 666 |
+
else:
|
| 667 |
+
raise NotImplementedError(
|
| 668 |
+
"Training mode is not supported in KimiSparseMoeBlock")
|
| 669 |
+
if self.config.num_shared_experts is not None:
|
| 670 |
+
y = y + self.shared_experts(identity)
|
| 671 |
+
return y
|
| 672 |
+
|
| 673 |
+
@torch.no_grad()
|
| 674 |
+
def moe_infer(self, x, topk_ids, topk_weight):
|
| 675 |
+
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
| 676 |
+
cnts.scatter_(1, topk_ids, 1)
|
| 677 |
+
tokens_per_expert = cnts.sum(dim=0)
|
| 678 |
+
idxs = topk_ids.view(-1).argsort()
|
| 679 |
+
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
| 680 |
+
|
| 681 |
+
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
| 682 |
+
|
| 683 |
+
outputs = []
|
| 684 |
+
start_idx = 0
|
| 685 |
+
for i, num_tokens in enumerate(tokens_per_expert):
|
| 686 |
+
end_idx = start_idx + num_tokens
|
| 687 |
+
if num_tokens == 0:
|
| 688 |
+
continue
|
| 689 |
+
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
| 690 |
+
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
| 691 |
+
expert_out = expert(tokens_for_this_expert)
|
| 692 |
+
outputs.append(expert_out)
|
| 693 |
+
start_idx = end_idx
|
| 694 |
+
|
| 695 |
+
outs = torch.cat(outputs, dim=0) if len(
|
| 696 |
+
outputs) else sorted_tokens.new_empty(0)
|
| 697 |
+
|
| 698 |
+
new_x = torch.empty_like(outs)
|
| 699 |
+
new_x[idxs] = outs
|
| 700 |
+
final_out = (
|
| 701 |
+
new_x.view(*topk_ids.shape, -1)
|
| 702 |
+
.type(topk_weight.dtype)
|
| 703 |
+
.mul_(topk_weight.unsqueeze(dim=-1))
|
| 704 |
+
.sum(dim=1)
|
| 705 |
+
.type(new_x.dtype)
|
| 706 |
+
)
|
| 707 |
+
return final_out
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
class KimiDecoderLayer(nn.Module):
|
| 711 |
+
def __init__(self, config: KimiLinearConfig, layer_idx: int):
|
| 712 |
+
super().__init__()
|
| 713 |
+
self.hidden_size = config.hidden_size
|
| 714 |
+
self.config = config
|
| 715 |
+
if config.is_kda_layer(layer_idx):
|
| 716 |
+
self.is_linear_attn = True
|
| 717 |
+
self.self_attn = KimiDeltaAttention(
|
| 718 |
+
config=config, layer_idx=layer_idx)
|
| 719 |
+
elif config.is_mla:
|
| 720 |
+
self.is_linear_attn = False
|
| 721 |
+
self.self_attn = KimiMLAAttention(
|
| 722 |
+
config=config, layer_idx=layer_idx)
|
| 723 |
+
else:
|
| 724 |
+
raise NotImplementedError
|
| 725 |
+
if (
|
| 726 |
+
config.num_experts is not None
|
| 727 |
+
and layer_idx >= config.first_k_dense_replace
|
| 728 |
+
and layer_idx % getattr(config, "moe_layer_freq", 1) == 0
|
| 729 |
+
):
|
| 730 |
+
self.block_sparse_moe = KimiSparseMoeBlock(config)
|
| 731 |
+
else:
|
| 732 |
+
self.mlp = KimiMLP(config)
|
| 733 |
+
self.input_layernorm = KimiRMSNorm(
|
| 734 |
+
config.hidden_size, eps=config.rms_norm_eps)
|
| 735 |
+
self.post_attention_layernorm = KimiRMSNorm(
|
| 736 |
+
config.hidden_size, eps=config.rms_norm_eps)
|
| 737 |
+
|
| 738 |
+
def forward(
|
| 739 |
+
self,
|
| 740 |
+
hidden_states: torch.Tensor,
|
| 741 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 742 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 743 |
+
past_key_values: Optional[Tuple[torch.Tensor]] = None,
|
| 744 |
+
output_attentions: Optional[bool] = False,
|
| 745 |
+
use_cache: Optional[bool] = False,
|
| 746 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 747 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 748 |
+
"""
|
| 749 |
+
Args:
|
| 750 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 751 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
| 752 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
| 753 |
+
output_attentions (`bool`, *optional*):
|
| 754 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 755 |
+
returned tensors for more detail.
|
| 756 |
+
use_cache (`bool`, *optional*):
|
| 757 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 758 |
+
(see `past_key_values`).
|
| 759 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 760 |
+
"""
|
| 761 |
+
|
| 762 |
+
residual = hidden_states
|
| 763 |
+
|
| 764 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 765 |
+
|
| 766 |
+
# Self Attention
|
| 767 |
+
if self.is_linear_attn is False:
|
| 768 |
+
hidden_states = self.self_attn(
|
| 769 |
+
hidden_states=hidden_states,
|
| 770 |
+
attention_mask=attention_mask,
|
| 771 |
+
position_ids=position_ids,
|
| 772 |
+
past_key_values=past_key_values,
|
| 773 |
+
output_attentions=output_attentions,
|
| 774 |
+
use_cache=use_cache,
|
| 775 |
+
**kwargs,
|
| 776 |
+
)
|
| 777 |
+
else:
|
| 778 |
+
hidden_states = self.self_attn(
|
| 779 |
+
hidden_states=hidden_states,
|
| 780 |
+
attention_mask=attention_mask,
|
| 781 |
+
cache_params=past_key_values,
|
| 782 |
+
output_attentions=output_attentions,
|
| 783 |
+
use_cache=use_cache,
|
| 784 |
+
**kwargs,
|
| 785 |
+
)
|
| 786 |
+
hidden_states = residual + hidden_states
|
| 787 |
+
|
| 788 |
+
# Fully Connected
|
| 789 |
+
residual = hidden_states
|
| 790 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 791 |
+
if hasattr(self, "block_sparse_moe"):
|
| 792 |
+
hidden_states = self.block_sparse_moe(hidden_states)
|
| 793 |
+
else:
|
| 794 |
+
hidden_states = self.mlp(hidden_states)
|
| 795 |
+
hidden_states = residual + hidden_states
|
| 796 |
+
|
| 797 |
+
return hidden_states
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
class KimiPreTrainedModel(PreTrainedModel):
|
| 801 |
+
config_class = KimiLinearConfig
|
| 802 |
+
base_model_prefix = "model"
|
| 803 |
+
supports_gradient_checkpointing = True
|
| 804 |
+
_no_split_modules = ["KimiDecoderLayer"]
|
| 805 |
+
_skip_keys_device_placement = "past_key_values"
|
| 806 |
+
_supports_flash_attn_2 = True
|
| 807 |
+
_can_record_outputs = {
|
| 808 |
+
"router_logits": OutputRecorder(KimiBlockSparseMLP, index=1),
|
| 809 |
+
"hidden_states": KimiDecoderLayer,
|
| 810 |
+
"attentions": KimiMLAAttention,
|
| 811 |
+
}
|
| 812 |
+
_is_stateful = True
|
| 813 |
+
|
| 814 |
+
def _init_weights(self, module):
|
| 815 |
+
std = self.config.initializer_range
|
| 816 |
+
if isinstance(module, nn.Linear):
|
| 817 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 818 |
+
if module.bias is not None:
|
| 819 |
+
module.bias.data.zero_()
|
| 820 |
+
elif isinstance(module, nn.Embedding):
|
| 821 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 822 |
+
if module.padding_idx is not None:
|
| 823 |
+
module.weight.data[module.padding_idx].zero_()
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
class KimiLinearModel(KimiPreTrainedModel):
|
| 827 |
+
def __init__(self, config: KimiLinearConfig):
|
| 828 |
+
super().__init__(config)
|
| 829 |
+
self.padding_idx = config.pad_token_id
|
| 830 |
+
self.vocab_size = config.vocab_size
|
| 831 |
+
|
| 832 |
+
self.embed_tokens = nn.Embedding(
|
| 833 |
+
config.vocab_size, config.hidden_size, self.padding_idx)
|
| 834 |
+
self.layers = nn.ModuleList([KimiDecoderLayer(
|
| 835 |
+
config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
| 836 |
+
self.norm = KimiRMSNorm(
|
| 837 |
+
config.hidden_size, eps=config.rms_norm_eps)
|
| 838 |
+
|
| 839 |
+
if getattr(config, "_attn_implementation", None) is not None:
|
| 840 |
+
if config._attn_implementation != "flash_attention_2":
|
| 841 |
+
logger.warning_once(
|
| 842 |
+
f"Ignoring the provided attention implementation {config._attn_implementation}")
|
| 843 |
+
logger.warning_once("Using flash_attention_2 backend instead.")
|
| 844 |
+
config._attn_implementation = "flash_attention_2"
|
| 845 |
+
else:
|
| 846 |
+
config._attn_implementation = "flash_attention_2"
|
| 847 |
+
|
| 848 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
| 849 |
+
self.gradient_checkpointing = False
|
| 850 |
+
# Initialize weights and apply final processing
|
| 851 |
+
self.post_init()
|
| 852 |
+
|
| 853 |
+
def _update_linear_attn_mask(self, attention_mask, cache_position):
|
| 854 |
+
"""
|
| 855 |
+
NOTE: Left-padding is used for linear attention mask.
|
| 856 |
+
No need for zeroing states when
|
| 857 |
+
1. Cached forward
|
| 858 |
+
2. Attending to all inputs
|
| 859 |
+
"""
|
| 860 |
+
linear_attn_mask = attention_mask
|
| 861 |
+
if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
|
| 862 |
+
linear_attn_mask = None
|
| 863 |
+
return linear_attn_mask
|
| 864 |
+
|
| 865 |
+
@check_model_inputs
|
| 866 |
+
@auto_docstring
|
| 867 |
+
def forward(
|
| 868 |
+
self,
|
| 869 |
+
input_ids: torch.LongTensor = None,
|
| 870 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 871 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 872 |
+
past_key_values: Optional[Cache] = None,
|
| 873 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 874 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 875 |
+
use_cache: Optional[bool] = None,
|
| 876 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 877 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 878 |
+
|
| 879 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 880 |
+
|
| 881 |
+
if (input_ids is None) and (inputs_embeds is None):
|
| 882 |
+
raise ValueError(
|
| 883 |
+
"You must specify exactly one of input_ids or inputs_embeds")
|
| 884 |
+
|
| 885 |
+
# Get inputs_embeds
|
| 886 |
+
if inputs_embeds is None:
|
| 887 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 888 |
+
|
| 889 |
+
if use_cache and past_key_values is None:
|
| 890 |
+
past_key_values = KimiDynamicCache(config=self.config)
|
| 891 |
+
|
| 892 |
+
if cache_position is None:
|
| 893 |
+
past_seen_tokens = past_key_values.get_seq_length(
|
| 894 |
+
) if past_key_values is not None else 0
|
| 895 |
+
cache_position: torch.Tensor = torch.arange(
|
| 896 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
if position_ids is None:
|
| 900 |
+
position_ids = cache_position.unsqueeze(0)
|
| 901 |
+
|
| 902 |
+
causal_mask = create_causal_mask(
|
| 903 |
+
config=self.config,
|
| 904 |
+
input_embeds=inputs_embeds,
|
| 905 |
+
attention_mask=attention_mask,
|
| 906 |
+
cache_position=cache_position,
|
| 907 |
+
past_key_values=past_key_values,
|
| 908 |
+
position_ids=position_ids,
|
| 909 |
+
)
|
| 910 |
+
linear_attn_mask = self._update_linear_attn_mask(
|
| 911 |
+
attention_mask, cache_position)
|
| 912 |
+
|
| 913 |
+
hidden_states = inputs_embeds
|
| 914 |
+
if past_key_values is not None:
|
| 915 |
+
assert isinstance(past_key_values, KimiDynamicCache)
|
| 916 |
+
|
| 917 |
+
for decoder_layer in self.layers:
|
| 918 |
+
layer_mask = linear_attn_mask if decoder_layer.is_linear_attn else causal_mask
|
| 919 |
+
|
| 920 |
+
hidden_states = decoder_layer(
|
| 921 |
+
hidden_states,
|
| 922 |
+
attention_mask=layer_mask,
|
| 923 |
+
past_key_values=past_key_values,
|
| 924 |
+
cache_position=cache_position,
|
| 925 |
+
**kwargs,
|
| 926 |
+
)
|
| 927 |
+
|
| 928 |
+
hidden_states = self.norm(hidden_states)
|
| 929 |
+
|
| 930 |
+
return BaseModelOutputWithPast(
|
| 931 |
+
last_hidden_state=hidden_states,
|
| 932 |
+
past_key_values=past_key_values,
|
| 933 |
+
)
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
class KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin):
|
| 937 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 938 |
+
|
| 939 |
+
def __init__(self, config):
|
| 940 |
+
super().__init__(config)
|
| 941 |
+
self.model = KimiLinearModel(config)
|
| 942 |
+
self.vocab_size = config.vocab_size
|
| 943 |
+
self.lm_head = nn.Linear(
|
| 944 |
+
config.hidden_size, config.vocab_size, bias=False)
|
| 945 |
+
|
| 946 |
+
# Initialize weights and apply final processing
|
| 947 |
+
self.post_init()
|
| 948 |
+
|
| 949 |
+
@can_return_tuple
|
| 950 |
+
@auto_docstring
|
| 951 |
+
def forward(
|
| 952 |
+
self,
|
| 953 |
+
input_ids: torch.LongTensor = None,
|
| 954 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 955 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 956 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 957 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 958 |
+
labels: Optional[torch.LongTensor] = None,
|
| 959 |
+
use_cache: Optional[bool] = None,
|
| 960 |
+
output_attentions: Optional[bool] = None,
|
| 961 |
+
output_hidden_states: Optional[bool] = None,
|
| 962 |
+
generation_mode: Optional[bool] = None,
|
| 963 |
+
return_dict: Optional[bool] = None,
|
| 964 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 965 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 966 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 967 |
+
r"""
|
| 968 |
+
Args:
|
| 969 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 970 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 971 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 972 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 973 |
+
|
| 974 |
+
Returns:
|
| 975 |
+
|
| 976 |
+
Example:
|
| 977 |
+
|
| 978 |
+
```python
|
| 979 |
+
>>> from transformers import AutoTokenizer, KimiLinearForCausalLM
|
| 980 |
+
|
| 981 |
+
>>> model = KimiLinearForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
| 982 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
| 983 |
+
|
| 984 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 985 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 986 |
+
|
| 987 |
+
>>> # Generate
|
| 988 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 989 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 990 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 991 |
+
```"""
|
| 992 |
+
|
| 993 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 994 |
+
output_hidden_states = (
|
| 995 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 996 |
+
)
|
| 997 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 998 |
+
|
| 999 |
+
outputs = self.model(
|
| 1000 |
+
input_ids=input_ids,
|
| 1001 |
+
attention_mask=attention_mask,
|
| 1002 |
+
position_ids=position_ids,
|
| 1003 |
+
past_key_values=past_key_values,
|
| 1004 |
+
inputs_embeds=inputs_embeds,
|
| 1005 |
+
use_cache=use_cache,
|
| 1006 |
+
output_attentions=output_attentions,
|
| 1007 |
+
output_hidden_states=output_hidden_states,
|
| 1008 |
+
return_dict=return_dict,
|
| 1009 |
+
cache_position=cache_position,
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
logits = outputs[0]
|
| 1013 |
+
if generation_mode:
|
| 1014 |
+
logits = logits[:, -1:]
|
| 1015 |
+
logits = self.lm_head(logits)
|
| 1016 |
+
|
| 1017 |
+
loss = None
|
| 1018 |
+
if labels is not None:
|
| 1019 |
+
loss = self.loss_function(
|
| 1020 |
+
logits, labels, self.vocab_size, **kwargs)
|
| 1021 |
+
|
| 1022 |
+
return CausalLMOutputWithPast(
|
| 1023 |
+
loss=loss,
|
| 1024 |
+
logits=logits,
|
| 1025 |
+
past_key_values=outputs.past_key_values,
|
| 1026 |
+
hidden_states=outputs.hidden_states,
|
| 1027 |
+
attentions=outputs.attentions,
|
| 1028 |
+
)
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"[extra_id_0]",
|
| 4 |
+
"[extra_id_1]",
|
| 5 |
+
"[extra_id_2]",
|
| 6 |
+
"[extra_id_3]",
|
| 7 |
+
"[start_header_id]",
|
| 8 |
+
"[end_header_id]",
|
| 9 |
+
"[extra_id_4]",
|
| 10 |
+
"[EOT]",
|
| 11 |
+
"[extra_id_5]",
|
| 12 |
+
"[extra_id_6]",
|
| 13 |
+
"[extra_id_7]",
|
| 14 |
+
"[extra_id_8]",
|
| 15 |
+
"[extra_id_9]",
|
| 16 |
+
"[extra_id_10]",
|
| 17 |
+
"[extra_id_11]",
|
| 18 |
+
"[extra_id_12]",
|
| 19 |
+
"[extra_id_13]",
|
| 20 |
+
"[extra_id_14]",
|
| 21 |
+
"[extra_id_15]",
|
| 22 |
+
"[extra_id_16]",
|
| 23 |
+
"[extra_id_17]",
|
| 24 |
+
"[extra_id_18]",
|
| 25 |
+
"[extra_id_19]",
|
| 26 |
+
"[extra_id_20]",
|
| 27 |
+
"[extra_id_21]",
|
| 28 |
+
"[extra_id_22]",
|
| 29 |
+
"[extra_id_23]",
|
| 30 |
+
"[extra_id_24]",
|
| 31 |
+
"[extra_id_25]",
|
| 32 |
+
"[extra_id_26]",
|
| 33 |
+
"[extra_id_27]",
|
| 34 |
+
"[extra_id_28]",
|
| 35 |
+
"[extra_id_29]",
|
| 36 |
+
"[extra_id_30]",
|
| 37 |
+
"[extra_id_31]",
|
| 38 |
+
"[extra_id_32]",
|
| 39 |
+
"[extra_id_33]",
|
| 40 |
+
"[extra_id_34]",
|
| 41 |
+
"[extra_id_35]",
|
| 42 |
+
"[extra_id_36]",
|
| 43 |
+
"[extra_id_37]",
|
| 44 |
+
"[extra_id_38]",
|
| 45 |
+
"[extra_id_39]",
|
| 46 |
+
"[extra_id_40]",
|
| 47 |
+
"[extra_id_41]",
|
| 48 |
+
"[extra_id_42]",
|
| 49 |
+
"[extra_id_43]",
|
| 50 |
+
"[extra_id_44]",
|
| 51 |
+
"[extra_id_45]",
|
| 52 |
+
"[extra_id_46]",
|
| 53 |
+
"[extra_id_47]",
|
| 54 |
+
"[extra_id_48]",
|
| 55 |
+
"[extra_id_49]",
|
| 56 |
+
"[extra_id_50]",
|
| 57 |
+
"[extra_id_51]",
|
| 58 |
+
"[extra_id_52]",
|
| 59 |
+
"[extra_id_53]",
|
| 60 |
+
"[extra_id_54]",
|
| 61 |
+
"[extra_id_55]",
|
| 62 |
+
"[extra_id_56]",
|
| 63 |
+
"[extra_id_57]",
|
| 64 |
+
"[extra_id_58]",
|
| 65 |
+
"[extra_id_59]",
|
| 66 |
+
"[extra_id_60]",
|
| 67 |
+
"[extra_id_61]",
|
| 68 |
+
"[extra_id_62]",
|
| 69 |
+
"[extra_id_63]",
|
| 70 |
+
"[extra_id_64]",
|
| 71 |
+
"[extra_id_65]",
|
| 72 |
+
"[extra_id_66]",
|
| 73 |
+
"[extra_id_67]",
|
| 74 |
+
"[extra_id_68]",
|
| 75 |
+
"[extra_id_69]",
|
| 76 |
+
"[extra_id_70]",
|
| 77 |
+
"[extra_id_71]",
|
| 78 |
+
"[extra_id_72]",
|
| 79 |
+
"[extra_id_73]",
|
| 80 |
+
"[extra_id_74]",
|
| 81 |
+
"[extra_id_75]",
|
| 82 |
+
"[extra_id_76]",
|
| 83 |
+
"[extra_id_77]",
|
| 84 |
+
"[extra_id_78]",
|
| 85 |
+
"[extra_id_79]",
|
| 86 |
+
"[extra_id_80]",
|
| 87 |
+
"[extra_id_81]",
|
| 88 |
+
"[extra_id_82]",
|
| 89 |
+
"[extra_id_83]",
|
| 90 |
+
"[extra_id_84]",
|
| 91 |
+
"[extra_id_85]",
|
| 92 |
+
"[extra_id_86]",
|
| 93 |
+
"[extra_id_87]",
|
| 94 |
+
"[extra_id_88]",
|
| 95 |
+
"[extra_id_89]",
|
| 96 |
+
"[extra_id_90]",
|
| 97 |
+
"[extra_id_91]",
|
| 98 |
+
"[extra_id_92]",
|
| 99 |
+
"[extra_id_93]",
|
| 100 |
+
"[extra_id_94]",
|
| 101 |
+
"[extra_id_95]",
|
| 102 |
+
"[extra_id_96]",
|
| 103 |
+
"[extra_id_97]",
|
| 104 |
+
"[extra_id_98]",
|
| 105 |
+
"[extra_id_99]",
|
| 106 |
+
"[extra_id_100]",
|
| 107 |
+
"[extra_id_101]",
|
| 108 |
+
"[extra_id_102]",
|
| 109 |
+
"[extra_id_103]",
|
| 110 |
+
"[extra_id_104]",
|
| 111 |
+
"[extra_id_105]",
|
| 112 |
+
"[extra_id_106]",
|
| 113 |
+
"[extra_id_107]",
|
| 114 |
+
"[extra_id_108]",
|
| 115 |
+
"[extra_id_109]",
|
| 116 |
+
"[extra_id_110]",
|
| 117 |
+
"[extra_id_111]",
|
| 118 |
+
"[extra_id_112]",
|
| 119 |
+
"[extra_id_113]",
|
| 120 |
+
"[extra_id_114]",
|
| 121 |
+
"[extra_id_115]",
|
| 122 |
+
"[extra_id_116]",
|
| 123 |
+
"[extra_id_117]",
|
| 124 |
+
"[extra_id_118]",
|
| 125 |
+
"[extra_id_119]",
|
| 126 |
+
"[extra_id_120]",
|
| 127 |
+
"[extra_id_121]",
|
| 128 |
+
"[extra_id_122]",
|
| 129 |
+
"[extra_id_123]",
|
| 130 |
+
"[extra_id_124]",
|
| 131 |
+
"[extra_id_125]",
|
| 132 |
+
"[extra_id_126]",
|
| 133 |
+
"[extra_id_127]",
|
| 134 |
+
"[extra_id_128]",
|
| 135 |
+
"[extra_id_129]",
|
| 136 |
+
"[extra_id_130]",
|
| 137 |
+
"[extra_id_131]",
|
| 138 |
+
"[extra_id_132]",
|
| 139 |
+
"[extra_id_133]",
|
| 140 |
+
"[extra_id_134]",
|
| 141 |
+
"[extra_id_135]",
|
| 142 |
+
"[extra_id_136]",
|
| 143 |
+
"[extra_id_137]",
|
| 144 |
+
"[extra_id_138]",
|
| 145 |
+
"[extra_id_139]",
|
| 146 |
+
"[extra_id_140]",
|
| 147 |
+
"[extra_id_141]",
|
| 148 |
+
"[extra_id_142]",
|
| 149 |
+
"[extra_id_143]",
|
| 150 |
+
"[extra_id_144]",
|
| 151 |
+
"[extra_id_145]",
|
| 152 |
+
"[extra_id_146]",
|
| 153 |
+
"[extra_id_147]",
|
| 154 |
+
"[extra_id_148]",
|
| 155 |
+
"[extra_id_149]",
|
| 156 |
+
"[extra_id_150]",
|
| 157 |
+
"[extra_id_151]",
|
| 158 |
+
"[extra_id_152]",
|
| 159 |
+
"[extra_id_153]",
|
| 160 |
+
"[extra_id_154]",
|
| 161 |
+
"[extra_id_155]",
|
| 162 |
+
"[extra_id_156]",
|
| 163 |
+
"[extra_id_157]",
|
| 164 |
+
"[extra_id_158]",
|
| 165 |
+
"[extra_id_159]",
|
| 166 |
+
"[extra_id_160]",
|
| 167 |
+
"[extra_id_161]",
|
| 168 |
+
"[extra_id_162]",
|
| 169 |
+
"[extra_id_163]",
|
| 170 |
+
"[extra_id_164]",
|
| 171 |
+
"[extra_id_165]",
|
| 172 |
+
"[extra_id_166]",
|
| 173 |
+
"[extra_id_167]",
|
| 174 |
+
"[extra_id_168]",
|
| 175 |
+
"[extra_id_169]",
|
| 176 |
+
"[extra_id_170]",
|
| 177 |
+
"[extra_id_171]",
|
| 178 |
+
"[extra_id_172]",
|
| 179 |
+
"[extra_id_173]",
|
| 180 |
+
"[extra_id_174]",
|
| 181 |
+
"[extra_id_175]",
|
| 182 |
+
"[extra_id_176]",
|
| 183 |
+
"[extra_id_177]",
|
| 184 |
+
"[extra_id_178]",
|
| 185 |
+
"[extra_id_179]",
|
| 186 |
+
"[extra_id_180]",
|
| 187 |
+
"[extra_id_181]",
|
| 188 |
+
"[extra_id_182]",
|
| 189 |
+
"[extra_id_183]",
|
| 190 |
+
"[extra_id_184]",
|
| 191 |
+
"[extra_id_185]",
|
| 192 |
+
"[extra_id_186]",
|
| 193 |
+
"[extra_id_187]",
|
| 194 |
+
"[extra_id_188]",
|
| 195 |
+
"[extra_id_189]",
|
| 196 |
+
"[extra_id_190]",
|
| 197 |
+
"[extra_id_191]",
|
| 198 |
+
"[extra_id_192]",
|
| 199 |
+
"[extra_id_193]",
|
| 200 |
+
"[extra_id_194]",
|
| 201 |
+
"[extra_id_195]",
|
| 202 |
+
"[extra_id_196]",
|
| 203 |
+
"[extra_id_197]",
|
| 204 |
+
"[extra_id_198]",
|
| 205 |
+
"[extra_id_199]",
|
| 206 |
+
"[extra_id_200]",
|
| 207 |
+
"[extra_id_201]",
|
| 208 |
+
"[extra_id_202]",
|
| 209 |
+
"[extra_id_203]",
|
| 210 |
+
"[extra_id_204]",
|
| 211 |
+
"[extra_id_205]",
|
| 212 |
+
"[extra_id_206]",
|
| 213 |
+
"[extra_id_207]",
|
| 214 |
+
"[extra_id_208]",
|
| 215 |
+
"[extra_id_209]",
|
| 216 |
+
"[extra_id_210]",
|
| 217 |
+
"[extra_id_211]",
|
| 218 |
+
"[extra_id_212]",
|
| 219 |
+
"[extra_id_213]",
|
| 220 |
+
"[extra_id_214]",
|
| 221 |
+
"[extra_id_215]",
|
| 222 |
+
"[extra_id_216]",
|
| 223 |
+
"[extra_id_217]",
|
| 224 |
+
"[extra_id_218]",
|
| 225 |
+
"[extra_id_219]",
|
| 226 |
+
"[extra_id_220]",
|
| 227 |
+
"[extra_id_221]",
|
| 228 |
+
"[extra_id_222]",
|
| 229 |
+
"[extra_id_223]",
|
| 230 |
+
"[extra_id_224]",
|
| 231 |
+
"[extra_id_225]",
|
| 232 |
+
"[extra_id_226]",
|
| 233 |
+
"[extra_id_227]",
|
| 234 |
+
"[extra_id_228]",
|
| 235 |
+
"[extra_id_229]",
|
| 236 |
+
"[extra_id_230]",
|
| 237 |
+
"[extra_id_231]",
|
| 238 |
+
"[extra_id_232]",
|
| 239 |
+
"[extra_id_233]",
|
| 240 |
+
"[extra_id_234]",
|
| 241 |
+
"[extra_id_235]",
|
| 242 |
+
"[extra_id_236]",
|
| 243 |
+
"[extra_id_237]",
|
| 244 |
+
"[extra_id_238]",
|
| 245 |
+
"[extra_id_239]",
|
| 246 |
+
"[extra_id_240]",
|
| 247 |
+
"[extra_id_241]",
|
| 248 |
+
"[extra_id_242]",
|
| 249 |
+
"[extra_id_243]",
|
| 250 |
+
"[extra_id_244]",
|
| 251 |
+
"[extra_id_245]",
|
| 252 |
+
"[extra_id_246]",
|
| 253 |
+
"[extra_id_247]",
|
| 254 |
+
"[extra_id_248]"
|
| 255 |
+
],
|
| 256 |
+
"bos_token": "[BOS]",
|
| 257 |
+
"eos_token": "[EOS]",
|
| 258 |
+
"pad_token": "[extra_id_250]",
|
| 259 |
+
"unk_token": "[extra_id_249]"
|
| 260 |
+
}
|
tiktoken.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b6c497a7469b33ced9c38afb1ad6e47f03f5e5dc05f15930799210ec050c5103
|
| 3 |
+
size 2795286
|
tokenization_kimi.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tiktoken
|
| 3 |
+
|
| 4 |
+
from logging import getLogger
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import (
|
| 7 |
+
cast,
|
| 8 |
+
Tuple,
|
| 9 |
+
Dict,
|
| 10 |
+
Iterator,
|
| 11 |
+
List,
|
| 12 |
+
Union,
|
| 13 |
+
Optional,
|
| 14 |
+
)
|
| 15 |
+
from shutil import copyfile
|
| 16 |
+
from tiktoken.load import load_tiktoken_bpe
|
| 17 |
+
from tokenizers import AddedToken, pre_tokenizers, Regex
|
| 18 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 19 |
+
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = getLogger(__name__)
|
| 24 |
+
VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TikTokenTokenizer(PreTrainedTokenizer):
|
| 28 |
+
"""
|
| 29 |
+
Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.
|
| 30 |
+
|
| 31 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
| 32 |
+
this superclass for more information regarding those methods.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
vocab_file (`str`):
|
| 36 |
+
The path to the Tiktoken model file.
|
| 37 |
+
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|begin_of_text|>",`):
|
| 38 |
+
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
| 39 |
+
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|end_of_text|>"`):
|
| 40 |
+
The end of sequence token.
|
| 41 |
+
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_249|>"`):
|
| 42 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 43 |
+
token instead. The second to last item in special_tokens.
|
| 44 |
+
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_250|>"`):
|
| 45 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 46 |
+
additional_special_tokens (list of `str`, *optional*):
|
| 47 |
+
A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
|
| 48 |
+
skipped when decoding if `skip_special_tokens` is set to `True`.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 52 |
+
|
| 53 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 54 |
+
|
| 55 |
+
special_tokens: Dict[str, int]
|
| 56 |
+
|
| 57 |
+
num_reserved_special_tokens = 256
|
| 58 |
+
|
| 59 |
+
pat_str = "|".join(
|
| 60 |
+
[
|
| 61 |
+
r"""[\p{Han}]+""",
|
| 62 |
+
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
|
| 63 |
+
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
|
| 64 |
+
r"""\p{N}{1,3}""",
|
| 65 |
+
r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
|
| 66 |
+
r"""\s*[\r\n]+""",
|
| 67 |
+
r"""\s+(?!\S)""",
|
| 68 |
+
r"""\s+""",
|
| 69 |
+
]
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
vocab_file,
|
| 75 |
+
bos_token: Union[str, AddedToken]="[BOS]",
|
| 76 |
+
eos_token: Union[str, AddedToken]="[EOS]",
|
| 77 |
+
unk_token: Union[str, AddedToken, None]=None,
|
| 78 |
+
pad_token: Union[str, AddedToken, None]=None,
|
| 79 |
+
additional_special_tokens: List[str]=None,
|
| 80 |
+
added_tokens_decoder: Optional[dict] = None,
|
| 81 |
+
**kwargs,
|
| 82 |
+
):
|
| 83 |
+
assert os.path.isfile(vocab_file), vocab_file
|
| 84 |
+
|
| 85 |
+
if additional_special_tokens is None:
|
| 86 |
+
additional_special_tokens = [
|
| 87 |
+
"<|im_end|>",
|
| 88 |
+
"<|im_user|>",
|
| 89 |
+
"<|im_assistant|>",
|
| 90 |
+
"<|start_header_id|>",
|
| 91 |
+
"<|end_header_id|>",
|
| 92 |
+
"[EOT]",
|
| 93 |
+
"<|im_system|>",
|
| 94 |
+
"<|im_middle|>",
|
| 95 |
+
]
|
| 96 |
+
|
| 97 |
+
special_tokens_mapping = {
|
| 98 |
+
i: added_tokens_decoder[i].content for i in added_tokens_decoder
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
self.vocab_file = vocab_file
|
| 102 |
+
mergeable_ranks = load_tiktoken_bpe(vocab_file)
|
| 103 |
+
num_base_tokens = len(mergeable_ranks)
|
| 104 |
+
self.special_tokens = {
|
| 105 |
+
special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
|
| 106 |
+
for i in range(
|
| 107 |
+
num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2
|
| 108 |
+
)
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
self.model = tiktoken.Encoding(
|
| 114 |
+
name=Path(vocab_file).name,
|
| 115 |
+
pat_str=self.pat_str,
|
| 116 |
+
mergeable_ranks=mergeable_ranks,
|
| 117 |
+
special_tokens=self.special_tokens,
|
| 118 |
+
)
|
| 119 |
+
logger.info(f"Reloaded tiktoken model from {vocab_file}")
|
| 120 |
+
|
| 121 |
+
self.n_words: int = self.model.n_vocab
|
| 122 |
+
# BOS / EOS token IDs
|
| 123 |
+
self.bos_id: int = self.special_tokens[str(bos_token)]
|
| 124 |
+
self.eos_id: int = self.special_tokens[str(eos_token)]
|
| 125 |
+
logger.info(
|
| 126 |
+
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
self.pad_id: int = self.special_tokens[str(pad_token)]
|
| 130 |
+
self.unk_id: int = self.special_tokens[str(unk_token)]
|
| 131 |
+
|
| 132 |
+
self.byte_encoder = bytes_to_unicode()
|
| 133 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
| 134 |
+
|
| 135 |
+
self.decoder = {}
|
| 136 |
+
for i in range(self.n_words):
|
| 137 |
+
# Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
|
| 138 |
+
decoding = ''.join([
|
| 139 |
+
self.byte_encoder[ord(char)] for char in
|
| 140 |
+
self.model.decode_single_token_bytes(i).decode('latin-1')
|
| 141 |
+
])
|
| 142 |
+
self.decoder[i] = decoding
|
| 143 |
+
|
| 144 |
+
self.encoder = {}
|
| 145 |
+
for i in range(self.n_words):
|
| 146 |
+
if i in self.decoder:
|
| 147 |
+
self.encoder[self.decoder[i]] = i
|
| 148 |
+
|
| 149 |
+
super().__init__(
|
| 150 |
+
bos_token=bos_token,
|
| 151 |
+
eos_token=eos_token,
|
| 152 |
+
unk_token=unk_token,
|
| 153 |
+
pad_token=pad_token,
|
| 154 |
+
additional_special_tokens=additional_special_tokens,
|
| 155 |
+
**kwargs,
|
| 156 |
+
)
|
| 157 |
+
self.all_special_ids_set = set(self.all_special_ids)
|
| 158 |
+
|
| 159 |
+
def encode(
|
| 160 |
+
self,
|
| 161 |
+
text: str,
|
| 162 |
+
allow_special_tokens: bool = True,
|
| 163 |
+
**kwargs
|
| 164 |
+
) -> List[int]:
|
| 165 |
+
"""
|
| 166 |
+
Encodes a string into a list of token IDs.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
text (str): The input string to be encoded.
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
list[int]: A list of token IDs.
|
| 173 |
+
"""
|
| 174 |
+
# If there are other args, we should call super().encode because there are a lot of code
|
| 175 |
+
# to handle those args. supper().encode finally will call _tokenize and _convert_token_to_id.
|
| 176 |
+
# NOTE: our encode method is not compatible with the super().encode method,
|
| 177 |
+
# e.g. split_special_tokens' default is True in our encode method.
|
| 178 |
+
if len(kwargs) > 0:
|
| 179 |
+
logger.warning( f"Calling super().encode with {kwargs}" )
|
| 180 |
+
return super().encode(text, **kwargs)
|
| 181 |
+
|
| 182 |
+
assert type(text) is str
|
| 183 |
+
|
| 184 |
+
# The tiktoken tokenizer can handle <=400k chars without
|
| 185 |
+
# pyo3_runtime.PanicException.
|
| 186 |
+
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
| 187 |
+
|
| 188 |
+
# https://github.com/openai/tiktoken/issues/195
|
| 189 |
+
# Here we iterate over subsequences and split if we exceed the limit
|
| 190 |
+
# of max consecutive non-whitespace or whitespace characters.
|
| 191 |
+
MAX_NO_WHITESPACES_CHARS = 25_000
|
| 192 |
+
|
| 193 |
+
texts = self.pre_tokenizer_process(text)
|
| 194 |
+
|
| 195 |
+
all_substrs = []
|
| 196 |
+
for text in texts:
|
| 197 |
+
substrs = (
|
| 198 |
+
substr
|
| 199 |
+
for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
|
| 200 |
+
for substr in self._split_whitespaces_or_nonwhitespaces(
|
| 201 |
+
text[i: i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
| 202 |
+
)
|
| 203 |
+
)
|
| 204 |
+
all_substrs.extend(substrs)
|
| 205 |
+
|
| 206 |
+
t: List[int] = []
|
| 207 |
+
for substr in all_substrs:
|
| 208 |
+
if allow_special_tokens:
|
| 209 |
+
t.extend(
|
| 210 |
+
# we should consider special token as a common token
|
| 211 |
+
self.model.encode(
|
| 212 |
+
substr,
|
| 213 |
+
allowed_special="all",
|
| 214 |
+
)
|
| 215 |
+
)
|
| 216 |
+
else:
|
| 217 |
+
t.extend(
|
| 218 |
+
# we should consider special token as a common token
|
| 219 |
+
self.model.encode(
|
| 220 |
+
substr,
|
| 221 |
+
disallowed_special=(),
|
| 222 |
+
)
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
return t
|
| 226 |
+
|
| 227 |
+
def decode(
|
| 228 |
+
self,
|
| 229 |
+
token_ids: Union[int, List[int]],
|
| 230 |
+
**kwargs
|
| 231 |
+
) -> str:
|
| 232 |
+
"""
|
| 233 |
+
Decodes a list of token IDs into a string.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
token_ids (List[int]): The list of token IDs to be decoded.
|
| 237 |
+
|
| 238 |
+
Returns:
|
| 239 |
+
str: The decoded string.
|
| 240 |
+
"""
|
| 241 |
+
# If there are other args, we should call super().decode because there are a lot of code
|
| 242 |
+
# to handle those args. supper().encode finally will call convert_tokens_to_string and _convert_id_to_token.
|
| 243 |
+
if len(kwargs) > 0:
|
| 244 |
+
return super().decode(token_ids, **kwargs)
|
| 245 |
+
|
| 246 |
+
if type(token_ids) is int:
|
| 247 |
+
token_ids = [token_ids]
|
| 248 |
+
|
| 249 |
+
return self.model.decode(cast(List[int], token_ids))
|
| 250 |
+
|
| 251 |
+
@staticmethod
|
| 252 |
+
def _split_whitespaces_or_nonwhitespaces(
|
| 253 |
+
s: str, max_consecutive_slice_len: int
|
| 254 |
+
) -> Iterator[str]:
|
| 255 |
+
"""
|
| 256 |
+
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
| 257 |
+
consecutive whitespaces or consecutive non-whitespaces.
|
| 258 |
+
"""
|
| 259 |
+
current_slice_len = 0
|
| 260 |
+
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
| 261 |
+
slice_start = 0
|
| 262 |
+
|
| 263 |
+
for i in range(len(s)):
|
| 264 |
+
is_now_space = s[i].isspace()
|
| 265 |
+
|
| 266 |
+
if current_slice_is_space ^ is_now_space:
|
| 267 |
+
current_slice_len = 1
|
| 268 |
+
current_slice_is_space = is_now_space
|
| 269 |
+
else:
|
| 270 |
+
current_slice_len += 1
|
| 271 |
+
if current_slice_len > max_consecutive_slice_len:
|
| 272 |
+
yield s[slice_start:i]
|
| 273 |
+
slice_start = i
|
| 274 |
+
current_slice_len = 1
|
| 275 |
+
yield s[slice_start:]
|
| 276 |
+
|
| 277 |
+
def pre_tokenizer_process(self, text: str) -> List[str]:
|
| 278 |
+
"""
|
| 279 |
+
pre-tokenizes the input text into a list of tokens.
|
| 280 |
+
This method is used to split the input text into smaller chunks for internal processing.
|
| 281 |
+
"""
|
| 282 |
+
return [text]
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
""" ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
|
| 286 |
+
@property
|
| 287 |
+
def vocab_size(self) -> int:
|
| 288 |
+
return self.n_words
|
| 289 |
+
|
| 290 |
+
def get_vocab(self) -> Dict[str, int]:
|
| 291 |
+
return self.encoder
|
| 292 |
+
|
| 293 |
+
def _tokenize(self, text: str, **kwargs) -> List[str]:
|
| 294 |
+
return [
|
| 295 |
+
self.decoder[t]
|
| 296 |
+
for t in self.encode(text)
|
| 297 |
+
]
|
| 298 |
+
|
| 299 |
+
def _convert_token_to_id(self, token: str) -> int:
|
| 300 |
+
return self.encoder.get(token, self.unk_id)
|
| 301 |
+
|
| 302 |
+
def _convert_id_to_token(self, index: int) -> str:
|
| 303 |
+
return self.decoder.get(index)
|
| 304 |
+
|
| 305 |
+
@staticmethod
|
| 306 |
+
def clean_up_tokenization(out_string: str) -> str:
|
| 307 |
+
return out_string
|
| 308 |
+
|
| 309 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
| 310 |
+
text = ''.join(tokens)
|
| 311 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', 'replace')
|
| 312 |
+
return text
|
| 313 |
+
|
| 314 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 315 |
+
if not os.path.isdir(save_directory):
|
| 316 |
+
raise ValueError(f"vocabulary path ({save_directory}) should be a directory")
|
| 317 |
+
out_vocab_file = os.path.join(
|
| 318 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
| 322 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 323 |
+
|
| 324 |
+
return (out_vocab_file,)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def apply_chat_template(
|
| 329 |
+
self, conversation, tools: Optional[list[dict]] = None,
|
| 330 |
+
tokenize: bool = False,
|
| 331 |
+
add_generation_prompt: bool = True,
|
| 332 |
+
**kwargs
|
| 333 |
+
):
|
| 334 |
+
tools = deep_sort_dict(tools)
|
| 335 |
+
return super().apply_chat_template(conversation,
|
| 336 |
+
tools=tools,
|
| 337 |
+
tokenize=tokenize,
|
| 338 |
+
add_generation_prompt=add_generation_prompt,
|
| 339 |
+
**kwargs)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def deep_sort_dict(obj: Any) -> Any:
|
| 343 |
+
if isinstance(obj, dict):
|
| 344 |
+
return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
|
| 345 |
+
if isinstance(obj, list):
|
| 346 |
+
return [deep_sort_dict(item) for item in obj]
|
| 347 |
+
return obj
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"163584": {
|
| 4 |
+
"content": "[BOS]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"163585": {
|
| 12 |
+
"content": "[EOS]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"163586": {
|
| 20 |
+
"content": "<|im_end|>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"163587": {
|
| 28 |
+
"content": "<|im_user|>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"163588": {
|
| 36 |
+
"content": "<|im_assistant|>",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
},
|
| 43 |
+
"163590": {
|
| 44 |
+
"content": "<|start_header_id|>",
|
| 45 |
+
"lstrip": false,
|
| 46 |
+
"normalized": false,
|
| 47 |
+
"rstrip": false,
|
| 48 |
+
"single_word": false,
|
| 49 |
+
"special": true
|
| 50 |
+
},
|
| 51 |
+
"163591": {
|
| 52 |
+
"content": "<|end_header_id|>",
|
| 53 |
+
"lstrip": false,
|
| 54 |
+
"normalized": false,
|
| 55 |
+
"rstrip": false,
|
| 56 |
+
"single_word": false,
|
| 57 |
+
"special": true
|
| 58 |
+
},
|
| 59 |
+
"163593": {
|
| 60 |
+
"content": "[EOT]",
|
| 61 |
+
"lstrip": false,
|
| 62 |
+
"normalized": false,
|
| 63 |
+
"rstrip": false,
|
| 64 |
+
"single_word": false,
|
| 65 |
+
"special": true
|
| 66 |
+
},
|
| 67 |
+
"163594": {
|
| 68 |
+
"content": "<|im_system|>",
|
| 69 |
+
"lstrip": false,
|
| 70 |
+
"normalized": false,
|
| 71 |
+
"rstrip": false,
|
| 72 |
+
"single_word": false,
|
| 73 |
+
"special": true
|
| 74 |
+
},
|
| 75 |
+
"163595": {
|
| 76 |
+
"content": "<|tool_calls_section_begin|>",
|
| 77 |
+
"lstrip": false,
|
| 78 |
+
"normalized": false,
|
| 79 |
+
"rstrip": false,
|
| 80 |
+
"single_word": false,
|
| 81 |
+
"special": false
|
| 82 |
+
},
|
| 83 |
+
"163596": {
|
| 84 |
+
"content": "<|tool_calls_section_end|>",
|
| 85 |
+
"lstrip": false,
|
| 86 |
+
"normalized": false,
|
| 87 |
+
"rstrip": false,
|
| 88 |
+
"single_word": false,
|
| 89 |
+
"special": false
|
| 90 |
+
},
|
| 91 |
+
"163597": {
|
| 92 |
+
"content": "<|tool_call_begin|>",
|
| 93 |
+
"lstrip": false,
|
| 94 |
+
"normalized": false,
|
| 95 |
+
"rstrip": false,
|
| 96 |
+
"single_word": false,
|
| 97 |
+
"special": false
|
| 98 |
+
},
|
| 99 |
+
"163598": {
|
| 100 |
+
"content": "<|tool_call_argument_begin|>",
|
| 101 |
+
"lstrip": false,
|
| 102 |
+
"normalized": false,
|
| 103 |
+
"rstrip": false,
|
| 104 |
+
"single_word": false,
|
| 105 |
+
"special": false
|
| 106 |
+
},
|
| 107 |
+
"163599": {
|
| 108 |
+
"content": "<|tool_call_end|>",
|
| 109 |
+
"lstrip": false,
|
| 110 |
+
"normalized": false,
|
| 111 |
+
"rstrip": false,
|
| 112 |
+
"single_word": false,
|
| 113 |
+
"special": false
|
| 114 |
+
},
|
| 115 |
+
"163601": {
|
| 116 |
+
"content": "<|im_middle|>",
|
| 117 |
+
"lstrip": false,
|
| 118 |
+
"normalized": false,
|
| 119 |
+
"rstrip": false,
|
| 120 |
+
"single_word": false,
|
| 121 |
+
"special": true
|
| 122 |
+
},
|
| 123 |
+
"163838": {
|
| 124 |
+
"content": "[UNK]",
|
| 125 |
+
"lstrip": false,
|
| 126 |
+
"normalized": false,
|
| 127 |
+
"rstrip": false,
|
| 128 |
+
"single_word": false,
|
| 129 |
+
"special": true
|
| 130 |
+
},
|
| 131 |
+
"163839": {
|
| 132 |
+
"content": "[PAD]",
|
| 133 |
+
"lstrip": false,
|
| 134 |
+
"normalized": false,
|
| 135 |
+
"rstrip": false,
|
| 136 |
+
"single_word": false,
|
| 137 |
+
"special": true
|
| 138 |
+
}
|
| 139 |
+
},
|
| 140 |
+
"additional_special_tokens": [
|
| 141 |
+
"<|im_end|>",
|
| 142 |
+
"<|im_user|>",
|
| 143 |
+
"<|im_assistant|>",
|
| 144 |
+
"<|start_header_id|>",
|
| 145 |
+
"<|end_header_id|>",
|
| 146 |
+
"[EOT]",
|
| 147 |
+
"<|im_system|>",
|
| 148 |
+
"<|im_middle|>"
|
| 149 |
+
],
|
| 150 |
+
"bos_token": "[BOS]",
|
| 151 |
+
"clean_up_tokenization_spaces": false,
|
| 152 |
+
"eos_token": "[EOS]",
|
| 153 |
+
"extra_special_tokens": {},
|
| 154 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 155 |
+
"pad_token": "[PAD]",
|
| 156 |
+
"tokenizer_class": "TikTokenTokenizer",
|
| 157 |
+
"unk_token": "[UNK]",
|
| 158 |
+
"auto_map": {
|
| 159 |
+
"AutoTokenizer": [
|
| 160 |
+
"tokenization_kimi.TikTokenTokenizer",
|
| 161 |
+
null
|
| 162 |
+
]
|
| 163 |
+
}
|
| 164 |
+
}
|