itlevy
commited on
Commit
·
b5dfaf4
verified
·
0
Parent(s):
transformers>=4.44.2, backward compat
Browse files- .gitattributes +35 -0
- README.md +182 -0
- __init__.py +0 -0
- config.json +1004 -0
- configuration_decilm.py +99 -0
- model-00001-of-00022.safetensors +3 -0
- model-00002-of-00022.safetensors +3 -0
- model-00003-of-00022.safetensors +3 -0
- model-00004-of-00022.safetensors +3 -0
- model-00005-of-00022.safetensors +3 -0
- model-00006-of-00022.safetensors +3 -0
- model-00007-of-00022.safetensors +3 -0
- model-00008-of-00022.safetensors +3 -0
- model-00009-of-00022.safetensors +3 -0
- model-00010-of-00022.safetensors +3 -0
- model-00011-of-00022.safetensors +3 -0
- model-00012-of-00022.safetensors +3 -0
- model-00013-of-00022.safetensors +3 -0
- model-00014-of-00022.safetensors +3 -0
- model-00015-of-00022.safetensors +3 -0
- model-00016-of-00022.safetensors +3 -0
- model-00017-of-00022.safetensors +3 -0
- model-00018-of-00022.safetensors +3 -0
- model-00019-of-00022.safetensors +3 -0
- model-00020-of-00022.safetensors +3 -0
- model-00021-of-00022.safetensors +3 -0
- model-00022-of-00022.safetensors +3 -0
- model.safetensors.index.json +636 -0
- modeling_decilm.py +1665 -0
- special_tokens_map.json +16 -0
- tokenizer.json +0 -0
- tokenizer_chat_template.jinja +6 -0
- tokenizer_config.json +2062 -0
- transformers_4_44_2__activations.py +239 -0
- transformers_4_44_2__cache_utils.py +325 -0
- transformers_4_44_2__configuration_llama.py +203 -0
- transformers_4_44_2__modeling_attn_mask_utils.py +482 -0
- transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py +348 -0
- transformers_4_44_2__modeling_outputs.py +0 -0
- transformers_4_44_2__modeling_rope_utils.py +559 -0
- transformers_4_44_2__pytorch_utils.py +17 -0
- variable_cache.py +108 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: transformers
|
3 |
+
pipeline_tag: text-generation
|
4 |
+
language:
|
5 |
+
- en
|
6 |
+
tags:
|
7 |
+
- nvidia
|
8 |
+
- llama-3
|
9 |
+
- pytorch
|
10 |
+
license: other
|
11 |
+
license_name: nvidia-ai-foundation-models-community-license
|
12 |
+
license_link: >-
|
13 |
+
https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-ai-foundation-models-community-license-agreement/
|
14 |
+
---
|
15 |
+
|
16 |
+
# Llama-3_1-Nemotron-51B-instruct
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
## Model Overview
|
21 |
+
Llama-3_1-Nemotron-51B-instruct is a model which offers a great tradeoff between model accuracy and efficiency. Efficiency (throughput) directly translates to price, providing great ‘quality-per-dollar’. Using a novel Neural Architecture Search (NAS) approach we greatly reduce the model’s memory footprint, enabling larger workloads, as well as fitting the model on a single GPU at high workloads (H100-80GB). This NAS approach enables the selection of a desired point in the accuracy-efficiency tradeoff. This model is ready for commercial use.
|
22 |
+
|
23 |
+
|
24 |
+
## License
|
25 |
+
[NVIDIA AI Foundation Models Community License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-ai-foundation-models-community-license-agreement/). Additional Information: [Llama 3.1 Community License Agreement](https://www.llama.com/llama3_1/license/). Built with Llama.
|
26 |
+
|
27 |
+
## How was the model developed
|
28 |
+
|
29 |
+
Llama-3_1-Nemotron-51B-instruct is a large language model (LLM) which is a derivative of Llama-3.1-70B-instruct (AKA the reference model). We utilize a block-wise distillation of the reference model, where for each block we create multiple variants providing different tradeoffs of quality vs. computational complexity. We then search over the blocks to create a model which meets the required throughput and memory (optimized for a single H100-80GB GPU) while minimizing the quality degradation. The model then undergoes knowledge distillation (KD), with a focus on English single and multi-turn chat use-cases.
|
30 |
+
The KD step included 40 billion tokens consisting of a mixture of 3 datasets - FineWeb, Buzz-V1.2 and Dolma.
|
31 |
+
|
32 |
+
Links to [NIM](https://build.nvidia.com/nvidia/llama-3_1-nemotron-51b-instruct), [blog](https://developer.nvidia.com/blog/advancing-the-accuracy-efficiency-frontier-with-llama-3-1-nemotron-51b/) and [huggingface](https://huggingface.co/nvidia/Llama-3_1-Nemotron-51B-Instruct)
|
33 |
+
|
34 |
+
|
35 |
+
This results in a final model that is aligned for human chat preferences.
|
36 |
+
|
37 |
+
**Model Developers:** NVIDIA
|
38 |
+
|
39 |
+
**Model Input:** Text only
|
40 |
+
|
41 |
+
**Model Output:** Text only
|
42 |
+
|
43 |
+
**Model Dates:** Llama-3_1-Nemotron-51B-instruct was trained between August and September 2024
|
44 |
+
|
45 |
+
**Data Freshness:** The pretraining data has a cutoff of 2023
|
46 |
+
|
47 |
+
**Sequence Length Used During Distillation:** 8192
|
48 |
+
|
49 |
+
|
50 |
+
## Quick Start
|
51 |
+
Our code requires the `transformers` package version to be 4.44.2 or higher
|
52 |
+
|
53 |
+
See the snippet below for usage with transformers:
|
54 |
+
```python
|
55 |
+
import torch
|
56 |
+
import transformers
|
57 |
+
|
58 |
+
model_id = "nvidia/Llama-3_1-Nemotron-51B-Instruct"
|
59 |
+
model_kwargs = {"torch_dtype": torch.bfloat16, "trust_remote_code": True, "device_map": "auto"}
|
60 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
|
61 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
62 |
+
|
63 |
+
pipeline = transformers.pipeline(
|
64 |
+
"text-generation",
|
65 |
+
model=model_id,
|
66 |
+
tokenizer=tokenizer,
|
67 |
+
max_new_tokens=20,
|
68 |
+
**model_kwargs
|
69 |
+
)
|
70 |
+
print(pipeline([{"role": "user", "content": "Hey how are you?"}]))
|
71 |
+
```
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
## Required Hardware
|
76 |
+
|
77 |
+
FP8 Inference (recommended):
|
78 |
+
- 1x H100-80GB GPU
|
79 |
+
|
80 |
+
BF16 Inference:
|
81 |
+
- 2x H100-80GB GPUs
|
82 |
+
- 2x A100-80GB GPUs
|
83 |
+
|
84 |
+
|
85 |
+
## Model Architecture
|
86 |
+
The model is a derivative of Llama-3.1-70B, using Neural Architecture Search (NAS). The NAS algorithm results in non-standard and non-repetitive blocks. This includes the following:
|
87 |
+
* Variable Grouped Query Attention (VGQA) - each block can have a different number of KV (keys and values) heads, ranging from 1 to Llama’s typical 8.
|
88 |
+
* Skip attention - in some blocks the attention is skipped entirely, or replaced with a single linear layer.
|
89 |
+
* Variable FFN - the expansion/compression ratio in the FFN layer is different between blocks.
|
90 |
+
|
91 |
+
|
92 |
+
**Architecture Type:** Transformer Decoder (auto-regressive language model)
|
93 |
+
|
94 |
+
## Software Integration
|
95 |
+
**Runtime Engine(s):**
|
96 |
+
* NeMo 24.05 <br>
|
97 |
+
|
98 |
+
|
99 |
+
**Supported Hardware Architecture Compatibility:** NVIDIA H100, A100 80GB (BF16 quantization).
|
100 |
+
|
101 |
+
**[Preferred/Supported] Operating System(s):** <br>
|
102 |
+
* Linux <br>
|
103 |
+
|
104 |
+
## Intended use
|
105 |
+
|
106 |
+
Llama-3_1-Nemotron-51B-Instruct is a general purpose chat model intended to be used in English and coding languages. Other non-English languages are also supported.
|
107 |
+
|
108 |
+
## Evaluation Results
|
109 |
+
|
110 |
+
**Data Collection Method by dataset** <br>
|
111 |
+
* Automated <br>
|
112 |
+
|
113 |
+
|
114 |
+
### MT-Bench
|
115 |
+
|
116 |
+
Evaluated using select datasets from the [Judging LLM-as-a-Judge with MT-Bench and Chatbot Arena](https://arxiv.org/pdf/2306.05685v4)
|
117 |
+
MT-bench - 8.99
|
118 |
+
|
119 |
+
|
120 |
+
### MMLU
|
121 |
+
|
122 |
+
Evaluated using the Multi-task Language Understanding benchmarks as introduced in [Measuring Massive Multitask Language Understanding](https://arxiv.org/pdf/2009.03300)
|
123 |
+
|
124 |
+
|MMLU (5-shot) |
|
125 |
+
| :----------------- |
|
126 |
+
| 80.2% |
|
127 |
+
|
128 |
+
### GSM8K
|
129 |
+
|
130 |
+
Evaluated using the Grade School Math 8K (GSM8K) benchmark as introduced in [Training Verifiers to Solve Math Word Problems](https://arxiv.org/pdf/2110.14168v2)
|
131 |
+
|
132 |
+
|GSM8K (5-shot) |
|
133 |
+
| :----------------- |
|
134 |
+
| 91.43% |
|
135 |
+
|
136 |
+
### Winogrande
|
137 |
+
|
138 |
+
|Winogrande (5-shot) |
|
139 |
+
| :----------------- |
|
140 |
+
| 84.53% |
|
141 |
+
|
142 |
+
### Arc-C
|
143 |
+
|
144 |
+
|Arc challenge (25-shot) |
|
145 |
+
| :----------------- |
|
146 |
+
| 69.20% |
|
147 |
+
|
148 |
+
### Hellaswag
|
149 |
+
|
150 |
+
|Hellaswag (10-shot) |
|
151 |
+
| :----------------- |
|
152 |
+
| 85.58% |
|
153 |
+
|
154 |
+
### Truthful QA
|
155 |
+
|
156 |
+
|TruthfulQA (0-shot) |
|
157 |
+
| :----------------- |
|
158 |
+
| 58.63%% |
|
159 |
+
|
160 |
+
## Limitations
|
161 |
+
|
162 |
+
The model was trained on data that contains toxic language, unsafe content, and societal biases originally crawled from the internet. Therefore, the model may amplify those biases and return toxic responses especially when prompted with toxic prompts. The model may generate answers that may be inaccurate, omit key information, or include irrelevant or redundant text producing socially unacceptable or undesirable text, even if the prompt itself does not include anything explicitly offensive.
|
163 |
+
|
164 |
+
The model demonstrates weakness to alignment-breaking attacks. Users are advised to deploy language model guardrails alongside this model to prevent potentially harmful outputs.
|
165 |
+
|
166 |
+
## Adversarial Testing and Red Teaming Efforts
|
167 |
+
|
168 |
+
The Llama-3_1-Nemotron-51B-instruct model underwent extensive safety evaluation including adversarial testing via three distinct methods:
|
169 |
+
* [Garak](https://docs.garak.ai/garak), is an automated LLM vulnerability scanner that probes for common weaknesses, including prompt injection and data leakage.
|
170 |
+
* [AEGIS](https://arxiv.org/pdf/2404.05993), is a content safety evaluation dataset and LLM based content safety classifier model, that adheres to a broad taxonomy of 13 categories of critical risks in human-LLM interactions.
|
171 |
+
* Human Content Red Teaming leveraging human interaction and evaluation of the models' responses.
|
172 |
+
|
173 |
+
|
174 |
+
## Inference
|
175 |
+
**Engine:** Tensor(RT) <br>
|
176 |
+
**Test Hardware** H100-80GB <br>
|
177 |
+
|
178 |
+
|
179 |
+
## Ethical Considerations
|
180 |
+
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
|
181 |
+
|
182 |
+
Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
|
__init__.py
ADDED
File without changes
|
config.json
ADDED
@@ -0,0 +1,1004 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"DeciLMForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_bias": false,
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_decilm.DeciLMConfig",
|
9 |
+
"AutoModelForCausalLM": "modeling_decilm.DeciLMForCausalLM"
|
10 |
+
},
|
11 |
+
"block_configs": [
|
12 |
+
{
|
13 |
+
"attention": {
|
14 |
+
"n_heads_in_group": 8,
|
15 |
+
"no_op": false,
|
16 |
+
"replace_with_linear": false
|
17 |
+
},
|
18 |
+
"ffn": {
|
19 |
+
"ffn_mult": 1.3125,
|
20 |
+
"no_op": false,
|
21 |
+
"replace_with_linear": false
|
22 |
+
}
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"attention": {
|
26 |
+
"n_heads_in_group": 16,
|
27 |
+
"no_op": false,
|
28 |
+
"replace_with_linear": false
|
29 |
+
},
|
30 |
+
"ffn": {
|
31 |
+
"ffn_mult": 2.625,
|
32 |
+
"no_op": false,
|
33 |
+
"replace_with_linear": false
|
34 |
+
}
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"attention": {
|
38 |
+
"n_heads_in_group": 8,
|
39 |
+
"no_op": false,
|
40 |
+
"replace_with_linear": false
|
41 |
+
},
|
42 |
+
"ffn": {
|
43 |
+
"ffn_mult": 5.25,
|
44 |
+
"no_op": false,
|
45 |
+
"replace_with_linear": false
|
46 |
+
}
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"attention": {
|
50 |
+
"n_heads_in_group": 8,
|
51 |
+
"no_op": false,
|
52 |
+
"replace_with_linear": false
|
53 |
+
},
|
54 |
+
"ffn": {
|
55 |
+
"ffn_mult": 5.25,
|
56 |
+
"no_op": false,
|
57 |
+
"replace_with_linear": false
|
58 |
+
}
|
59 |
+
},
|
60 |
+
{
|
61 |
+
"attention": {
|
62 |
+
"n_heads_in_group": 8,
|
63 |
+
"no_op": false,
|
64 |
+
"replace_with_linear": false
|
65 |
+
},
|
66 |
+
"ffn": {
|
67 |
+
"ffn_mult": 5.25,
|
68 |
+
"no_op": false,
|
69 |
+
"replace_with_linear": false
|
70 |
+
}
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"attention": {
|
74 |
+
"n_heads_in_group": 32,
|
75 |
+
"no_op": false,
|
76 |
+
"replace_with_linear": false
|
77 |
+
},
|
78 |
+
"ffn": {
|
79 |
+
"ffn_mult": 2.625,
|
80 |
+
"no_op": false,
|
81 |
+
"replace_with_linear": false
|
82 |
+
}
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"attention": {
|
86 |
+
"n_heads_in_group": 32,
|
87 |
+
"no_op": false,
|
88 |
+
"replace_with_linear": false
|
89 |
+
},
|
90 |
+
"ffn": {
|
91 |
+
"ffn_mult": 2.625,
|
92 |
+
"no_op": false,
|
93 |
+
"replace_with_linear": false
|
94 |
+
}
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"attention": {
|
98 |
+
"n_heads_in_group": 64,
|
99 |
+
"no_op": false,
|
100 |
+
"replace_with_linear": false
|
101 |
+
},
|
102 |
+
"ffn": {
|
103 |
+
"ffn_mult": 2.625,
|
104 |
+
"no_op": false,
|
105 |
+
"replace_with_linear": false
|
106 |
+
}
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"attention": {
|
110 |
+
"n_heads_in_group": 64,
|
111 |
+
"no_op": false,
|
112 |
+
"replace_with_linear": false
|
113 |
+
},
|
114 |
+
"ffn": {
|
115 |
+
"ffn_mult": 2.625,
|
116 |
+
"no_op": false,
|
117 |
+
"replace_with_linear": false
|
118 |
+
}
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"attention": {
|
122 |
+
"n_heads_in_group": 32,
|
123 |
+
"no_op": false,
|
124 |
+
"replace_with_linear": false
|
125 |
+
},
|
126 |
+
"ffn": {
|
127 |
+
"ffn_mult": 2.625,
|
128 |
+
"no_op": false,
|
129 |
+
"replace_with_linear": false
|
130 |
+
}
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"attention": {
|
134 |
+
"n_heads_in_group": 32,
|
135 |
+
"no_op": false,
|
136 |
+
"replace_with_linear": false
|
137 |
+
},
|
138 |
+
"ffn": {
|
139 |
+
"ffn_mult": 2.625,
|
140 |
+
"no_op": false,
|
141 |
+
"replace_with_linear": false
|
142 |
+
}
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"attention": {
|
146 |
+
"n_heads_in_group": null,
|
147 |
+
"no_op": false,
|
148 |
+
"replace_with_linear": true
|
149 |
+
},
|
150 |
+
"ffn": {
|
151 |
+
"ffn_mult": 2.625,
|
152 |
+
"no_op": false,
|
153 |
+
"replace_with_linear": false
|
154 |
+
}
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"attention": {
|
158 |
+
"n_heads_in_group": 64,
|
159 |
+
"no_op": false,
|
160 |
+
"replace_with_linear": false
|
161 |
+
},
|
162 |
+
"ffn": {
|
163 |
+
"ffn_mult": 2.625,
|
164 |
+
"no_op": false,
|
165 |
+
"replace_with_linear": false
|
166 |
+
}
|
167 |
+
},
|
168 |
+
{
|
169 |
+
"attention": {
|
170 |
+
"n_heads_in_group": 32,
|
171 |
+
"no_op": false,
|
172 |
+
"replace_with_linear": false
|
173 |
+
},
|
174 |
+
"ffn": {
|
175 |
+
"ffn_mult": 2.625,
|
176 |
+
"no_op": false,
|
177 |
+
"replace_with_linear": false
|
178 |
+
}
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"attention": {
|
182 |
+
"n_heads_in_group": 32,
|
183 |
+
"no_op": false,
|
184 |
+
"replace_with_linear": false
|
185 |
+
},
|
186 |
+
"ffn": {
|
187 |
+
"ffn_mult": 2.625,
|
188 |
+
"no_op": false,
|
189 |
+
"replace_with_linear": false
|
190 |
+
}
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"attention": {
|
194 |
+
"n_heads_in_group": null,
|
195 |
+
"no_op": false,
|
196 |
+
"replace_with_linear": true
|
197 |
+
},
|
198 |
+
"ffn": {
|
199 |
+
"ffn_mult": 1.3125,
|
200 |
+
"no_op": false,
|
201 |
+
"replace_with_linear": false
|
202 |
+
}
|
203 |
+
},
|
204 |
+
{
|
205 |
+
"attention": {
|
206 |
+
"n_heads_in_group": 8,
|
207 |
+
"no_op": false,
|
208 |
+
"replace_with_linear": false
|
209 |
+
},
|
210 |
+
"ffn": {
|
211 |
+
"ffn_mult": 5.25,
|
212 |
+
"no_op": false,
|
213 |
+
"replace_with_linear": false
|
214 |
+
}
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"attention": {
|
218 |
+
"n_heads_in_group": 8,
|
219 |
+
"no_op": false,
|
220 |
+
"replace_with_linear": false
|
221 |
+
},
|
222 |
+
"ffn": {
|
223 |
+
"ffn_mult": 5.25,
|
224 |
+
"no_op": false,
|
225 |
+
"replace_with_linear": false
|
226 |
+
}
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"attention": {
|
230 |
+
"n_heads_in_group": 8,
|
231 |
+
"no_op": false,
|
232 |
+
"replace_with_linear": false
|
233 |
+
},
|
234 |
+
"ffn": {
|
235 |
+
"ffn_mult": 5.25,
|
236 |
+
"no_op": false,
|
237 |
+
"replace_with_linear": false
|
238 |
+
}
|
239 |
+
},
|
240 |
+
{
|
241 |
+
"attention": {
|
242 |
+
"n_heads_in_group": 8,
|
243 |
+
"no_op": false,
|
244 |
+
"replace_with_linear": false
|
245 |
+
},
|
246 |
+
"ffn": {
|
247 |
+
"ffn_mult": 5.25,
|
248 |
+
"no_op": false,
|
249 |
+
"replace_with_linear": false
|
250 |
+
}
|
251 |
+
},
|
252 |
+
{
|
253 |
+
"attention": {
|
254 |
+
"n_heads_in_group": 8,
|
255 |
+
"no_op": false,
|
256 |
+
"replace_with_linear": false
|
257 |
+
},
|
258 |
+
"ffn": {
|
259 |
+
"ffn_mult": 5.25,
|
260 |
+
"no_op": false,
|
261 |
+
"replace_with_linear": false
|
262 |
+
}
|
263 |
+
},
|
264 |
+
{
|
265 |
+
"attention": {
|
266 |
+
"n_heads_in_group": 8,
|
267 |
+
"no_op": false,
|
268 |
+
"replace_with_linear": false
|
269 |
+
},
|
270 |
+
"ffn": {
|
271 |
+
"ffn_mult": 5.25,
|
272 |
+
"no_op": false,
|
273 |
+
"replace_with_linear": false
|
274 |
+
}
|
275 |
+
},
|
276 |
+
{
|
277 |
+
"attention": {
|
278 |
+
"n_heads_in_group": 8,
|
279 |
+
"no_op": false,
|
280 |
+
"replace_with_linear": false
|
281 |
+
},
|
282 |
+
"ffn": {
|
283 |
+
"ffn_mult": 5.25,
|
284 |
+
"no_op": false,
|
285 |
+
"replace_with_linear": false
|
286 |
+
}
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"attention": {
|
290 |
+
"n_heads_in_group": 8,
|
291 |
+
"no_op": false,
|
292 |
+
"replace_with_linear": false
|
293 |
+
},
|
294 |
+
"ffn": {
|
295 |
+
"ffn_mult": 5.25,
|
296 |
+
"no_op": false,
|
297 |
+
"replace_with_linear": false
|
298 |
+
}
|
299 |
+
},
|
300 |
+
{
|
301 |
+
"attention": {
|
302 |
+
"n_heads_in_group": 8,
|
303 |
+
"no_op": false,
|
304 |
+
"replace_with_linear": false
|
305 |
+
},
|
306 |
+
"ffn": {
|
307 |
+
"ffn_mult": 5.25,
|
308 |
+
"no_op": false,
|
309 |
+
"replace_with_linear": false
|
310 |
+
}
|
311 |
+
},
|
312 |
+
{
|
313 |
+
"attention": {
|
314 |
+
"n_heads_in_group": 8,
|
315 |
+
"no_op": false,
|
316 |
+
"replace_with_linear": false
|
317 |
+
},
|
318 |
+
"ffn": {
|
319 |
+
"ffn_mult": 5.25,
|
320 |
+
"no_op": false,
|
321 |
+
"replace_with_linear": false
|
322 |
+
}
|
323 |
+
},
|
324 |
+
{
|
325 |
+
"attention": {
|
326 |
+
"n_heads_in_group": 8,
|
327 |
+
"no_op": false,
|
328 |
+
"replace_with_linear": false
|
329 |
+
},
|
330 |
+
"ffn": {
|
331 |
+
"ffn_mult": 5.25,
|
332 |
+
"no_op": false,
|
333 |
+
"replace_with_linear": false
|
334 |
+
}
|
335 |
+
},
|
336 |
+
{
|
337 |
+
"attention": {
|
338 |
+
"n_heads_in_group": 8,
|
339 |
+
"no_op": false,
|
340 |
+
"replace_with_linear": false
|
341 |
+
},
|
342 |
+
"ffn": {
|
343 |
+
"ffn_mult": 5.25,
|
344 |
+
"no_op": false,
|
345 |
+
"replace_with_linear": false
|
346 |
+
}
|
347 |
+
},
|
348 |
+
{
|
349 |
+
"attention": {
|
350 |
+
"n_heads_in_group": 8,
|
351 |
+
"no_op": false,
|
352 |
+
"replace_with_linear": false
|
353 |
+
},
|
354 |
+
"ffn": {
|
355 |
+
"ffn_mult": 5.25,
|
356 |
+
"no_op": false,
|
357 |
+
"replace_with_linear": false
|
358 |
+
}
|
359 |
+
},
|
360 |
+
{
|
361 |
+
"attention": {
|
362 |
+
"n_heads_in_group": 8,
|
363 |
+
"no_op": false,
|
364 |
+
"replace_with_linear": false
|
365 |
+
},
|
366 |
+
"ffn": {
|
367 |
+
"ffn_mult": 5.25,
|
368 |
+
"no_op": false,
|
369 |
+
"replace_with_linear": false
|
370 |
+
}
|
371 |
+
},
|
372 |
+
{
|
373 |
+
"attention": {
|
374 |
+
"n_heads_in_group": 8,
|
375 |
+
"no_op": false,
|
376 |
+
"replace_with_linear": false
|
377 |
+
},
|
378 |
+
"ffn": {
|
379 |
+
"ffn_mult": 5.25,
|
380 |
+
"no_op": false,
|
381 |
+
"replace_with_linear": false
|
382 |
+
}
|
383 |
+
},
|
384 |
+
{
|
385 |
+
"attention": {
|
386 |
+
"n_heads_in_group": 8,
|
387 |
+
"no_op": false,
|
388 |
+
"replace_with_linear": false
|
389 |
+
},
|
390 |
+
"ffn": {
|
391 |
+
"ffn_mult": 5.25,
|
392 |
+
"no_op": false,
|
393 |
+
"replace_with_linear": false
|
394 |
+
}
|
395 |
+
},
|
396 |
+
{
|
397 |
+
"attention": {
|
398 |
+
"n_heads_in_group": 8,
|
399 |
+
"no_op": false,
|
400 |
+
"replace_with_linear": false
|
401 |
+
},
|
402 |
+
"ffn": {
|
403 |
+
"ffn_mult": 5.25,
|
404 |
+
"no_op": false,
|
405 |
+
"replace_with_linear": false
|
406 |
+
}
|
407 |
+
},
|
408 |
+
{
|
409 |
+
"attention": {
|
410 |
+
"n_heads_in_group": 8,
|
411 |
+
"no_op": false,
|
412 |
+
"replace_with_linear": false
|
413 |
+
},
|
414 |
+
"ffn": {
|
415 |
+
"ffn_mult": 5.25,
|
416 |
+
"no_op": false,
|
417 |
+
"replace_with_linear": false
|
418 |
+
}
|
419 |
+
},
|
420 |
+
{
|
421 |
+
"attention": {
|
422 |
+
"n_heads_in_group": 8,
|
423 |
+
"no_op": false,
|
424 |
+
"replace_with_linear": false
|
425 |
+
},
|
426 |
+
"ffn": {
|
427 |
+
"ffn_mult": 5.25,
|
428 |
+
"no_op": false,
|
429 |
+
"replace_with_linear": false
|
430 |
+
}
|
431 |
+
},
|
432 |
+
{
|
433 |
+
"attention": {
|
434 |
+
"n_heads_in_group": 8,
|
435 |
+
"no_op": false,
|
436 |
+
"replace_with_linear": false
|
437 |
+
},
|
438 |
+
"ffn": {
|
439 |
+
"ffn_mult": 5.25,
|
440 |
+
"no_op": false,
|
441 |
+
"replace_with_linear": false
|
442 |
+
}
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"attention": {
|
446 |
+
"n_heads_in_group": 8,
|
447 |
+
"no_op": false,
|
448 |
+
"replace_with_linear": false
|
449 |
+
},
|
450 |
+
"ffn": {
|
451 |
+
"ffn_mult": 5.25,
|
452 |
+
"no_op": false,
|
453 |
+
"replace_with_linear": false
|
454 |
+
}
|
455 |
+
},
|
456 |
+
{
|
457 |
+
"attention": {
|
458 |
+
"n_heads_in_group": 8,
|
459 |
+
"no_op": false,
|
460 |
+
"replace_with_linear": false
|
461 |
+
},
|
462 |
+
"ffn": {
|
463 |
+
"ffn_mult": 5.25,
|
464 |
+
"no_op": false,
|
465 |
+
"replace_with_linear": false
|
466 |
+
}
|
467 |
+
},
|
468 |
+
{
|
469 |
+
"attention": {
|
470 |
+
"n_heads_in_group": 8,
|
471 |
+
"no_op": false,
|
472 |
+
"replace_with_linear": false
|
473 |
+
},
|
474 |
+
"ffn": {
|
475 |
+
"ffn_mult": 5.25,
|
476 |
+
"no_op": false,
|
477 |
+
"replace_with_linear": false
|
478 |
+
}
|
479 |
+
},
|
480 |
+
{
|
481 |
+
"attention": {
|
482 |
+
"n_heads_in_group": 8,
|
483 |
+
"no_op": false,
|
484 |
+
"replace_with_linear": false
|
485 |
+
},
|
486 |
+
"ffn": {
|
487 |
+
"ffn_mult": 5.25,
|
488 |
+
"no_op": false,
|
489 |
+
"replace_with_linear": false
|
490 |
+
}
|
491 |
+
},
|
492 |
+
{
|
493 |
+
"attention": {
|
494 |
+
"n_heads_in_group": 8,
|
495 |
+
"no_op": false,
|
496 |
+
"replace_with_linear": false
|
497 |
+
},
|
498 |
+
"ffn": {
|
499 |
+
"ffn_mult": 5.25,
|
500 |
+
"no_op": false,
|
501 |
+
"replace_with_linear": false
|
502 |
+
}
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"attention": {
|
506 |
+
"n_heads_in_group": 8,
|
507 |
+
"no_op": false,
|
508 |
+
"replace_with_linear": false
|
509 |
+
},
|
510 |
+
"ffn": {
|
511 |
+
"ffn_mult": 5.25,
|
512 |
+
"no_op": false,
|
513 |
+
"replace_with_linear": false
|
514 |
+
}
|
515 |
+
},
|
516 |
+
{
|
517 |
+
"attention": {
|
518 |
+
"n_heads_in_group": null,
|
519 |
+
"no_op": false,
|
520 |
+
"replace_with_linear": true
|
521 |
+
},
|
522 |
+
"ffn": {
|
523 |
+
"ffn_mult": 2.625,
|
524 |
+
"no_op": false,
|
525 |
+
"replace_with_linear": false
|
526 |
+
}
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"attention": {
|
530 |
+
"n_heads_in_group": 8,
|
531 |
+
"no_op": false,
|
532 |
+
"replace_with_linear": false
|
533 |
+
},
|
534 |
+
"ffn": {
|
535 |
+
"ffn_mult": 5.25,
|
536 |
+
"no_op": false,
|
537 |
+
"replace_with_linear": false
|
538 |
+
}
|
539 |
+
},
|
540 |
+
{
|
541 |
+
"attention": {
|
542 |
+
"n_heads_in_group": 8,
|
543 |
+
"no_op": false,
|
544 |
+
"replace_with_linear": false
|
545 |
+
},
|
546 |
+
"ffn": {
|
547 |
+
"ffn_mult": 5.25,
|
548 |
+
"no_op": false,
|
549 |
+
"replace_with_linear": false
|
550 |
+
}
|
551 |
+
},
|
552 |
+
{
|
553 |
+
"attention": {
|
554 |
+
"n_heads_in_group": null,
|
555 |
+
"no_op": false,
|
556 |
+
"replace_with_linear": true
|
557 |
+
},
|
558 |
+
"ffn": {
|
559 |
+
"ffn_mult": 2.625,
|
560 |
+
"no_op": false,
|
561 |
+
"replace_with_linear": false
|
562 |
+
}
|
563 |
+
},
|
564 |
+
{
|
565 |
+
"attention": {
|
566 |
+
"n_heads_in_group": null,
|
567 |
+
"no_op": false,
|
568 |
+
"replace_with_linear": true
|
569 |
+
},
|
570 |
+
"ffn": {
|
571 |
+
"ffn_mult": 5.25,
|
572 |
+
"no_op": false,
|
573 |
+
"replace_with_linear": false
|
574 |
+
}
|
575 |
+
},
|
576 |
+
{
|
577 |
+
"attention": {
|
578 |
+
"n_heads_in_group": null,
|
579 |
+
"no_op": false,
|
580 |
+
"replace_with_linear": true
|
581 |
+
},
|
582 |
+
"ffn": {
|
583 |
+
"ffn_mult": 2.625,
|
584 |
+
"no_op": false,
|
585 |
+
"replace_with_linear": false
|
586 |
+
}
|
587 |
+
},
|
588 |
+
{
|
589 |
+
"attention": {
|
590 |
+
"n_heads_in_group": null,
|
591 |
+
"no_op": false,
|
592 |
+
"replace_with_linear": true
|
593 |
+
},
|
594 |
+
"ffn": {
|
595 |
+
"ffn_mult": 2.625,
|
596 |
+
"no_op": false,
|
597 |
+
"replace_with_linear": false
|
598 |
+
}
|
599 |
+
},
|
600 |
+
{
|
601 |
+
"attention": {
|
602 |
+
"n_heads_in_group": null,
|
603 |
+
"no_op": false,
|
604 |
+
"replace_with_linear": true
|
605 |
+
},
|
606 |
+
"ffn": {
|
607 |
+
"ffn_mult": 2.625,
|
608 |
+
"no_op": false,
|
609 |
+
"replace_with_linear": false
|
610 |
+
}
|
611 |
+
},
|
612 |
+
{
|
613 |
+
"attention": {
|
614 |
+
"n_heads_in_group": null,
|
615 |
+
"no_op": true,
|
616 |
+
"replace_with_linear": false
|
617 |
+
},
|
618 |
+
"ffn": {
|
619 |
+
"ffn_mult": 1.3125,
|
620 |
+
"no_op": false,
|
621 |
+
"replace_with_linear": false
|
622 |
+
}
|
623 |
+
},
|
624 |
+
{
|
625 |
+
"attention": {
|
626 |
+
"n_heads_in_group": null,
|
627 |
+
"no_op": false,
|
628 |
+
"replace_with_linear": true
|
629 |
+
},
|
630 |
+
"ffn": {
|
631 |
+
"ffn_mult": 1.3125,
|
632 |
+
"no_op": false,
|
633 |
+
"replace_with_linear": false
|
634 |
+
}
|
635 |
+
},
|
636 |
+
{
|
637 |
+
"attention": {
|
638 |
+
"n_heads_in_group": 8,
|
639 |
+
"no_op": false,
|
640 |
+
"replace_with_linear": false
|
641 |
+
},
|
642 |
+
"ffn": {
|
643 |
+
"ffn_mult": 5.25,
|
644 |
+
"no_op": false,
|
645 |
+
"replace_with_linear": false
|
646 |
+
}
|
647 |
+
},
|
648 |
+
{
|
649 |
+
"attention": {
|
650 |
+
"n_heads_in_group": null,
|
651 |
+
"no_op": true,
|
652 |
+
"replace_with_linear": false
|
653 |
+
},
|
654 |
+
"ffn": {
|
655 |
+
"ffn_mult": 1.3125,
|
656 |
+
"no_op": false,
|
657 |
+
"replace_with_linear": false
|
658 |
+
}
|
659 |
+
},
|
660 |
+
{
|
661 |
+
"attention": {
|
662 |
+
"n_heads_in_group": null,
|
663 |
+
"no_op": false,
|
664 |
+
"replace_with_linear": true
|
665 |
+
},
|
666 |
+
"ffn": {
|
667 |
+
"ffn_mult": 1.3125,
|
668 |
+
"no_op": false,
|
669 |
+
"replace_with_linear": false
|
670 |
+
}
|
671 |
+
},
|
672 |
+
{
|
673 |
+
"attention": {
|
674 |
+
"n_heads_in_group": null,
|
675 |
+
"no_op": true,
|
676 |
+
"replace_with_linear": false
|
677 |
+
},
|
678 |
+
"ffn": {
|
679 |
+
"ffn_mult": 1.3125,
|
680 |
+
"no_op": false,
|
681 |
+
"replace_with_linear": false
|
682 |
+
}
|
683 |
+
},
|
684 |
+
{
|
685 |
+
"attention": {
|
686 |
+
"n_heads_in_group": 8,
|
687 |
+
"no_op": false,
|
688 |
+
"replace_with_linear": false
|
689 |
+
},
|
690 |
+
"ffn": {
|
691 |
+
"ffn_mult": 5.25,
|
692 |
+
"no_op": false,
|
693 |
+
"replace_with_linear": false
|
694 |
+
}
|
695 |
+
},
|
696 |
+
{
|
697 |
+
"attention": {
|
698 |
+
"n_heads_in_group": null,
|
699 |
+
"no_op": false,
|
700 |
+
"replace_with_linear": true
|
701 |
+
},
|
702 |
+
"ffn": {
|
703 |
+
"ffn_mult": 1.3125,
|
704 |
+
"no_op": false,
|
705 |
+
"replace_with_linear": false
|
706 |
+
}
|
707 |
+
},
|
708 |
+
{
|
709 |
+
"attention": {
|
710 |
+
"n_heads_in_group": null,
|
711 |
+
"no_op": true,
|
712 |
+
"replace_with_linear": false
|
713 |
+
},
|
714 |
+
"ffn": {
|
715 |
+
"ffn_mult": 1.3125,
|
716 |
+
"no_op": false,
|
717 |
+
"replace_with_linear": false
|
718 |
+
}
|
719 |
+
},
|
720 |
+
{
|
721 |
+
"attention": {
|
722 |
+
"n_heads_in_group": null,
|
723 |
+
"no_op": false,
|
724 |
+
"replace_with_linear": true
|
725 |
+
},
|
726 |
+
"ffn": {
|
727 |
+
"ffn_mult": 1.3125,
|
728 |
+
"no_op": false,
|
729 |
+
"replace_with_linear": false
|
730 |
+
}
|
731 |
+
},
|
732 |
+
{
|
733 |
+
"attention": {
|
734 |
+
"n_heads_in_group": null,
|
735 |
+
"no_op": false,
|
736 |
+
"replace_with_linear": true
|
737 |
+
},
|
738 |
+
"ffn": {
|
739 |
+
"ffn_mult": 1.3125,
|
740 |
+
"no_op": false,
|
741 |
+
"replace_with_linear": false
|
742 |
+
}
|
743 |
+
},
|
744 |
+
{
|
745 |
+
"attention": {
|
746 |
+
"n_heads_in_group": null,
|
747 |
+
"no_op": true,
|
748 |
+
"replace_with_linear": false
|
749 |
+
},
|
750 |
+
"ffn": {
|
751 |
+
"ffn_mult": 1.3125,
|
752 |
+
"no_op": false,
|
753 |
+
"replace_with_linear": false
|
754 |
+
}
|
755 |
+
},
|
756 |
+
{
|
757 |
+
"attention": {
|
758 |
+
"n_heads_in_group": null,
|
759 |
+
"no_op": true,
|
760 |
+
"replace_with_linear": false
|
761 |
+
},
|
762 |
+
"ffn": {
|
763 |
+
"ffn_mult": 1.3125,
|
764 |
+
"no_op": false,
|
765 |
+
"replace_with_linear": false
|
766 |
+
}
|
767 |
+
},
|
768 |
+
{
|
769 |
+
"attention": {
|
770 |
+
"n_heads_in_group": null,
|
771 |
+
"no_op": false,
|
772 |
+
"replace_with_linear": true
|
773 |
+
},
|
774 |
+
"ffn": {
|
775 |
+
"ffn_mult": 1.3125,
|
776 |
+
"no_op": false,
|
777 |
+
"replace_with_linear": false
|
778 |
+
}
|
779 |
+
},
|
780 |
+
{
|
781 |
+
"attention": {
|
782 |
+
"n_heads_in_group": null,
|
783 |
+
"no_op": true,
|
784 |
+
"replace_with_linear": false
|
785 |
+
},
|
786 |
+
"ffn": {
|
787 |
+
"ffn_mult": 1.3125,
|
788 |
+
"no_op": false,
|
789 |
+
"replace_with_linear": false
|
790 |
+
}
|
791 |
+
},
|
792 |
+
{
|
793 |
+
"attention": {
|
794 |
+
"n_heads_in_group": null,
|
795 |
+
"no_op": true,
|
796 |
+
"replace_with_linear": false
|
797 |
+
},
|
798 |
+
"ffn": {
|
799 |
+
"ffn_mult": 1.3125,
|
800 |
+
"no_op": false,
|
801 |
+
"replace_with_linear": false
|
802 |
+
}
|
803 |
+
},
|
804 |
+
{
|
805 |
+
"attention": {
|
806 |
+
"n_heads_in_group": null,
|
807 |
+
"no_op": false,
|
808 |
+
"replace_with_linear": true
|
809 |
+
},
|
810 |
+
"ffn": {
|
811 |
+
"ffn_mult": 1.3125,
|
812 |
+
"no_op": false,
|
813 |
+
"replace_with_linear": false
|
814 |
+
}
|
815 |
+
},
|
816 |
+
{
|
817 |
+
"attention": {
|
818 |
+
"n_heads_in_group": null,
|
819 |
+
"no_op": false,
|
820 |
+
"replace_with_linear": true
|
821 |
+
},
|
822 |
+
"ffn": {
|
823 |
+
"ffn_mult": 1.3125,
|
824 |
+
"no_op": false,
|
825 |
+
"replace_with_linear": false
|
826 |
+
}
|
827 |
+
},
|
828 |
+
{
|
829 |
+
"attention": {
|
830 |
+
"n_heads_in_group": null,
|
831 |
+
"no_op": false,
|
832 |
+
"replace_with_linear": true
|
833 |
+
},
|
834 |
+
"ffn": {
|
835 |
+
"ffn_mult": 1.3125,
|
836 |
+
"no_op": false,
|
837 |
+
"replace_with_linear": false
|
838 |
+
}
|
839 |
+
},
|
840 |
+
{
|
841 |
+
"attention": {
|
842 |
+
"n_heads_in_group": null,
|
843 |
+
"no_op": false,
|
844 |
+
"replace_with_linear": true
|
845 |
+
},
|
846 |
+
"ffn": {
|
847 |
+
"ffn_mult": 1.3125,
|
848 |
+
"no_op": false,
|
849 |
+
"replace_with_linear": false
|
850 |
+
}
|
851 |
+
},
|
852 |
+
{
|
853 |
+
"attention": {
|
854 |
+
"n_heads_in_group": 8,
|
855 |
+
"no_op": false,
|
856 |
+
"replace_with_linear": false
|
857 |
+
},
|
858 |
+
"ffn": {
|
859 |
+
"ffn_mult": 5.25,
|
860 |
+
"no_op": false,
|
861 |
+
"replace_with_linear": false
|
862 |
+
}
|
863 |
+
},
|
864 |
+
{
|
865 |
+
"attention": {
|
866 |
+
"n_heads_in_group": 8,
|
867 |
+
"no_op": false,
|
868 |
+
"replace_with_linear": false
|
869 |
+
},
|
870 |
+
"ffn": {
|
871 |
+
"ffn_mult": 5.25,
|
872 |
+
"no_op": false,
|
873 |
+
"replace_with_linear": false
|
874 |
+
}
|
875 |
+
},
|
876 |
+
{
|
877 |
+
"attention": {
|
878 |
+
"n_heads_in_group": 8,
|
879 |
+
"no_op": false,
|
880 |
+
"replace_with_linear": false
|
881 |
+
},
|
882 |
+
"ffn": {
|
883 |
+
"ffn_mult": 5.25,
|
884 |
+
"no_op": false,
|
885 |
+
"replace_with_linear": false
|
886 |
+
}
|
887 |
+
},
|
888 |
+
{
|
889 |
+
"attention": {
|
890 |
+
"n_heads_in_group": 8,
|
891 |
+
"no_op": false,
|
892 |
+
"replace_with_linear": false
|
893 |
+
},
|
894 |
+
"ffn": {
|
895 |
+
"ffn_mult": 5.25,
|
896 |
+
"no_op": false,
|
897 |
+
"replace_with_linear": false
|
898 |
+
}
|
899 |
+
},
|
900 |
+
{
|
901 |
+
"attention": {
|
902 |
+
"n_heads_in_group": 8,
|
903 |
+
"no_op": false,
|
904 |
+
"replace_with_linear": false
|
905 |
+
},
|
906 |
+
"ffn": {
|
907 |
+
"ffn_mult": 5.25,
|
908 |
+
"no_op": false,
|
909 |
+
"replace_with_linear": false
|
910 |
+
}
|
911 |
+
},
|
912 |
+
{
|
913 |
+
"attention": {
|
914 |
+
"n_heads_in_group": 8,
|
915 |
+
"no_op": false,
|
916 |
+
"replace_with_linear": false
|
917 |
+
},
|
918 |
+
"ffn": {
|
919 |
+
"ffn_mult": 5.25,
|
920 |
+
"no_op": false,
|
921 |
+
"replace_with_linear": false
|
922 |
+
}
|
923 |
+
},
|
924 |
+
{
|
925 |
+
"attention": {
|
926 |
+
"n_heads_in_group": 8,
|
927 |
+
"no_op": false,
|
928 |
+
"replace_with_linear": false
|
929 |
+
},
|
930 |
+
"ffn": {
|
931 |
+
"ffn_mult": 5.25,
|
932 |
+
"no_op": false,
|
933 |
+
"replace_with_linear": false
|
934 |
+
}
|
935 |
+
},
|
936 |
+
{
|
937 |
+
"attention": {
|
938 |
+
"n_heads_in_group": 8,
|
939 |
+
"no_op": false,
|
940 |
+
"replace_with_linear": false
|
941 |
+
},
|
942 |
+
"ffn": {
|
943 |
+
"ffn_mult": 5.25,
|
944 |
+
"no_op": false,
|
945 |
+
"replace_with_linear": false
|
946 |
+
}
|
947 |
+
},
|
948 |
+
{
|
949 |
+
"attention": {
|
950 |
+
"n_heads_in_group": 8,
|
951 |
+
"no_op": false,
|
952 |
+
"replace_with_linear": false
|
953 |
+
},
|
954 |
+
"ffn": {
|
955 |
+
"ffn_mult": 5.25,
|
956 |
+
"no_op": false,
|
957 |
+
"replace_with_linear": false
|
958 |
+
}
|
959 |
+
},
|
960 |
+
{
|
961 |
+
"attention": {
|
962 |
+
"n_heads_in_group": 8,
|
963 |
+
"no_op": false,
|
964 |
+
"replace_with_linear": false
|
965 |
+
},
|
966 |
+
"ffn": {
|
967 |
+
"ffn_mult": 5.25,
|
968 |
+
"no_op": false,
|
969 |
+
"replace_with_linear": false
|
970 |
+
}
|
971 |
+
}
|
972 |
+
],
|
973 |
+
"bos_token_id": 128000,
|
974 |
+
"eos_token_id": [
|
975 |
+
128001,
|
976 |
+
128008,
|
977 |
+
128009
|
978 |
+
],
|
979 |
+
"hidden_act": "silu",
|
980 |
+
"hidden_size": 8192,
|
981 |
+
"initializer_range": 0.02,
|
982 |
+
"intermediate_size": null,
|
983 |
+
"max_position_embeddings": 131072,
|
984 |
+
"mlp_bias": false,
|
985 |
+
"model_type": "nemotron-nas",
|
986 |
+
"num_attention_heads": 64,
|
987 |
+
"num_hidden_layers": 80,
|
988 |
+
"num_key_value_heads": null,
|
989 |
+
"pretraining_tp": 1,
|
990 |
+
"rms_norm_eps": 1e-05,
|
991 |
+
"rope_scaling": {
|
992 |
+
"factor": 8.0,
|
993 |
+
"high_freq_factor": 4.0,
|
994 |
+
"low_freq_factor": 1.0,
|
995 |
+
"original_max_position_embeddings": 8192,
|
996 |
+
"rope_type": "llama3"
|
997 |
+
},
|
998 |
+
"rope_theta": 500000.0,
|
999 |
+
"tie_word_embeddings": false,
|
1000 |
+
"torch_dtype": "bfloat16",
|
1001 |
+
"transformers_version": "4.44.2",
|
1002 |
+
"use_cache": true,
|
1003 |
+
"vocab_size": 128256
|
1004 |
+
}
|
configuration_decilm.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 Nvidia Corporation. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import dataclasses
|
17 |
+
import warnings
|
18 |
+
from dataclasses import dataclass, MISSING
|
19 |
+
from functools import partial
|
20 |
+
from typing import Optional, Dict, Any
|
21 |
+
|
22 |
+
from .transformers_4_44_2__configuration_llama import LlamaConfig
|
23 |
+
from .transformers_4_44_2__modeling_rope_utils import \
|
24 |
+
rope_config_validation # fake import to make AutoConfig infer the dependency
|
25 |
+
|
26 |
+
|
27 |
+
class DeciLMConfig(LlamaConfig):
|
28 |
+
model_type = "nemotron-nas"
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
block_configs: list[dict] | list["BlockConfig"] = None,
|
33 |
+
**kwargs,
|
34 |
+
):
|
35 |
+
super().__init__(**kwargs)
|
36 |
+
self.intermediate_size = None
|
37 |
+
self.num_key_value_heads = None
|
38 |
+
|
39 |
+
if block_configs is not None:
|
40 |
+
assert len(block_configs) == self.num_hidden_layers
|
41 |
+
if isinstance(block_configs[0], dict):
|
42 |
+
block_configs = [BlockConfig(**conf) for conf in block_configs]
|
43 |
+
self.block_configs: list[BlockConfig] = block_configs
|
44 |
+
|
45 |
+
def to_dict(self) -> Dict[str, Any]:
|
46 |
+
self_dict = super().to_dict()
|
47 |
+
if self.block_configs is not None:
|
48 |
+
self_dict["block_configs"] = [dataclasses.asdict(conf) for conf in self.block_configs]
|
49 |
+
return self_dict
|
50 |
+
|
51 |
+
|
52 |
+
@partial(dataclass, frozen=True, eq=True, unsafe_hash=True, order=True)
|
53 |
+
class AttentionConfig:
|
54 |
+
no_op: bool = False
|
55 |
+
replace_with_linear: bool = False
|
56 |
+
n_heads_in_group: Optional[int] = None
|
57 |
+
|
58 |
+
def __post_init__(self):
|
59 |
+
assert not (self.no_op and self.replace_with_linear)
|
60 |
+
if self.no_op or self.replace_with_linear:
|
61 |
+
object.__setattr__(self, 'n_heads_in_group', None) # __setattr__ to overcome frozen=True
|
62 |
+
else:
|
63 |
+
assert self.n_heads_in_group is not None
|
64 |
+
|
65 |
+
|
66 |
+
@partial(dataclass, frozen=True, eq=True, unsafe_hash=True, order=True)
|
67 |
+
class FFNConfig:
|
68 |
+
no_op: bool = False
|
69 |
+
replace_with_linear: bool = False
|
70 |
+
ffn_mult: Optional[float] = None
|
71 |
+
|
72 |
+
def __post_init__(self):
|
73 |
+
assert not (self.no_op and self.replace_with_linear)
|
74 |
+
if self.no_op or self.replace_with_linear:
|
75 |
+
object.__setattr__(self, 'ffn_mult', None) # __setattr__ to overcome frozen=True
|
76 |
+
else:
|
77 |
+
assert self.ffn_mult is not None
|
78 |
+
|
79 |
+
|
80 |
+
@partial(dataclass, frozen=True, eq=True, unsafe_hash=True, order=True)
|
81 |
+
class BlockConfig:
|
82 |
+
attention: AttentionConfig = MISSING
|
83 |
+
ffn: FFNConfig = MISSING
|
84 |
+
|
85 |
+
def __post_init__(self):
|
86 |
+
"""
|
87 |
+
Init subblock dataclasses from dicts
|
88 |
+
"""
|
89 |
+
for subblock_name in dataclasses.fields(self):
|
90 |
+
subblock_config = getattr(self, subblock_name.name)
|
91 |
+
if isinstance(subblock_config, dict):
|
92 |
+
subblock_fields = [field.name for field in dataclasses.fields(subblock_name.type)]
|
93 |
+
unsupported_fields = [field_name for field_name in subblock_config.keys()
|
94 |
+
if field_name not in subblock_fields]
|
95 |
+
if len(unsupported_fields) > 0:
|
96 |
+
warnings.warn(f"Removed unsupported fields {unsupported_fields} from {subblock_name.type.__name__}")
|
97 |
+
subblock_config = {k: v for k, v in subblock_config.items() if k not in unsupported_fields}
|
98 |
+
object.__setattr__(self, subblock_name.name,
|
99 |
+
subblock_name.type(**subblock_config)) # __setattr__ to overcome frozen=True
|
model-00001-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b2990306c9e4715b943eaf3f91473a801dfe04392f3f8768b31edecd8d0eb755
|
3 |
+
size 4987128904
|
model-00002-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6662e17eca5fb9bd4d5a29c3fe5b6bb0ea9a38395da47d6d11d6b85a3f3061de
|
3 |
+
size 4873899312
|
model-00003-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1deba9fb520b3e46ebf0876ea68a913664be24ff542f4ea40d59398f4cafa89a
|
3 |
+
size 4899116128
|
model-00004-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fe288cee7ef3dcf45f6b36ecd53c953f07ae01370fb3bb1383c230057fade1b5
|
3 |
+
size 4567782440
|
model-00005-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d1387ab2cf92336e52fe7bbf221c6f87a1df82982a795832f2a0564279ab57fd
|
3 |
+
size 4999695232
|
model-00006-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aeb46455878ff7356a6b1be8e028741a4c3d236e00c9336c40390a062c6e1e2c
|
3 |
+
size 4966157072
|
model-00007-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d3c86e900f62463d453585887f1c087d7b9d235fe5c29251409040442d77826d
|
3 |
+
size 4664150920
|
model-00008-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d983395fcc42fdb792e2f7ad0816705b6bb229697a1df1c889f1ad6e1cc17cc5
|
3 |
+
size 4664167416
|
model-00009-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:817460bbcb2878d86051545cf77c87653e53ee20f5db702b8fcda581f6d8c0ba
|
3 |
+
size 4664167416
|
model-00010-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:984f9ec95ae6003feb2f31397a10a01837254ebf8fe6cd32c7c1f98af91d34fe
|
3 |
+
size 4999695232
|
model-00011-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:742ac2ff2d5b8bed496565c687ad1145d251f9685d780b3566756dfa21585cfe
|
3 |
+
size 4966157072
|
model-00012-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e06f3f256c108db9370fa30faca1335de8f7001c06a3a0f676dd9793d3d6e7ea
|
3 |
+
size 4664150920
|
model-00013-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c8a5930428eae56a7b86bcff24dc395005f23e8b2b66a062014306ac4c001e2
|
3 |
+
size 4664167416
|
model-00014-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a8fb07d4f8176b5ad0fa65bd0094b0bcc2e39eed43e7d7824da7249f9bc8dfd9
|
3 |
+
size 4731276160
|
model-00015-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4c81cac17603aeb4daba647542d5d31fed2ddf14f92664db8db158ee92e5e362
|
3 |
+
size 4899114336
|
model-00016-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8a8bd6808d8525e29c163b8f2acc4a5dc48da7de3110e6c47ae3dd19752b5ea0
|
3 |
+
size 4983018288
|
model-00017-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3955e1d2905da27e1bb86df0d1f4cb380644c7d55d576276f52adc38c6e537d0
|
3 |
+
size 4899198680
|
model-00018-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fd3ce3dbde6831dcb280b89e1eeef52b976133091c7b2edb077d46c5b7133144
|
3 |
+
size 4647456992
|
model-00019-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f02a07b21d517079feafeedf6192b703dbf15aae712deafdc8cdaa82ba7a5b1f
|
3 |
+
size 4664167416
|
model-00020-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:25ae2c6845e6108324d68c52bfcf520ddd24c6ad3ec90e32ec7e4a4ee4dda167
|
3 |
+
size 4664167416
|
model-00021-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0234aebf14d4cb42d483335a984cedb920d6c4285ccefcd4cbe6841c1c1f09c7
|
3 |
+
size 4831922696
|
model-00022-of-00022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:84a87a2039ad882e51b0b10c1b59f342a9d2efc18c95291d7ae5926d50a195cf
|
3 |
+
size 2101346432
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,636 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 103002030080
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"model.embed_tokens.weight": "model-00001-of-00022.safetensors",
|
7 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00022.safetensors",
|
8 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00022.safetensors",
|
9 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00022.safetensors",
|
10 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00022.safetensors",
|
11 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00022.safetensors",
|
12 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00022.safetensors",
|
13 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00022.safetensors",
|
14 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00022.safetensors",
|
15 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00022.safetensors",
|
16 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00022.safetensors",
|
17 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00022.safetensors",
|
18 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00022.safetensors",
|
19 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00022.safetensors",
|
20 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00022.safetensors",
|
21 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00022.safetensors",
|
22 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00022.safetensors",
|
23 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00022.safetensors",
|
24 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00022.safetensors",
|
25 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00022.safetensors",
|
26 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00022.safetensors",
|
27 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00022.safetensors",
|
28 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00022.safetensors",
|
29 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00022.safetensors",
|
30 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00022.safetensors",
|
31 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00022.safetensors",
|
32 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00022.safetensors",
|
33 |
+
"model.layers.2.mlp.down_proj.weight": "model-00002-of-00022.safetensors",
|
34 |
+
"model.layers.3.input_layernorm.weight": "model-00002-of-00022.safetensors",
|
35 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00002-of-00022.safetensors",
|
36 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00002-of-00022.safetensors",
|
37 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00002-of-00022.safetensors",
|
38 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00002-of-00022.safetensors",
|
39 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00002-of-00022.safetensors",
|
40 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00002-of-00022.safetensors",
|
41 |
+
"model.layers.3.mlp.up_proj.weight": "model-00002-of-00022.safetensors",
|
42 |
+
"model.layers.3.mlp.down_proj.weight": "model-00002-of-00022.safetensors",
|
43 |
+
"model.layers.4.input_layernorm.weight": "model-00002-of-00022.safetensors",
|
44 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00002-of-00022.safetensors",
|
45 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00002-of-00022.safetensors",
|
46 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00002-of-00022.safetensors",
|
47 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00002-of-00022.safetensors",
|
48 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00002-of-00022.safetensors",
|
49 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00002-of-00022.safetensors",
|
50 |
+
"model.layers.4.mlp.up_proj.weight": "model-00002-of-00022.safetensors",
|
51 |
+
"model.layers.4.mlp.down_proj.weight": "model-00002-of-00022.safetensors",
|
52 |
+
"model.layers.5.input_layernorm.weight": "model-00002-of-00022.safetensors",
|
53 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00002-of-00022.safetensors",
|
54 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00002-of-00022.safetensors",
|
55 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00002-of-00022.safetensors",
|
56 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00002-of-00022.safetensors",
|
57 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00002-of-00022.safetensors",
|
58 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00002-of-00022.safetensors",
|
59 |
+
"model.layers.5.mlp.up_proj.weight": "model-00002-of-00022.safetensors",
|
60 |
+
"model.layers.5.mlp.down_proj.weight": "model-00002-of-00022.safetensors",
|
61 |
+
"model.layers.6.input_layernorm.weight": "model-00002-of-00022.safetensors",
|
62 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00003-of-00022.safetensors",
|
63 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00003-of-00022.safetensors",
|
64 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00003-of-00022.safetensors",
|
65 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00003-of-00022.safetensors",
|
66 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00003-of-00022.safetensors",
|
67 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00003-of-00022.safetensors",
|
68 |
+
"model.layers.6.mlp.up_proj.weight": "model-00003-of-00022.safetensors",
|
69 |
+
"model.layers.6.mlp.down_proj.weight": "model-00003-of-00022.safetensors",
|
70 |
+
"model.layers.7.input_layernorm.weight": "model-00003-of-00022.safetensors",
|
71 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00003-of-00022.safetensors",
|
72 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00003-of-00022.safetensors",
|
73 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00003-of-00022.safetensors",
|
74 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00003-of-00022.safetensors",
|
75 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00003-of-00022.safetensors",
|
76 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00003-of-00022.safetensors",
|
77 |
+
"model.layers.7.mlp.up_proj.weight": "model-00003-of-00022.safetensors",
|
78 |
+
"model.layers.7.mlp.down_proj.weight": "model-00003-of-00022.safetensors",
|
79 |
+
"model.layers.8.input_layernorm.weight": "model-00003-of-00022.safetensors",
|
80 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00003-of-00022.safetensors",
|
81 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00003-of-00022.safetensors",
|
82 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00003-of-00022.safetensors",
|
83 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00003-of-00022.safetensors",
|
84 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00003-of-00022.safetensors",
|
85 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00003-of-00022.safetensors",
|
86 |
+
"model.layers.8.mlp.up_proj.weight": "model-00003-of-00022.safetensors",
|
87 |
+
"model.layers.8.mlp.down_proj.weight": "model-00003-of-00022.safetensors",
|
88 |
+
"model.layers.9.input_layernorm.weight": "model-00003-of-00022.safetensors",
|
89 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00003-of-00022.safetensors",
|
90 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00003-of-00022.safetensors",
|
91 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00003-of-00022.safetensors",
|
92 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00003-of-00022.safetensors",
|
93 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00003-of-00022.safetensors",
|
94 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00003-of-00022.safetensors",
|
95 |
+
"model.layers.9.mlp.up_proj.weight": "model-00003-of-00022.safetensors",
|
96 |
+
"model.layers.9.mlp.down_proj.weight": "model-00003-of-00022.safetensors",
|
97 |
+
"model.layers.10.input_layernorm.weight": "model-00003-of-00022.safetensors",
|
98 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00003-of-00022.safetensors",
|
99 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00003-of-00022.safetensors",
|
100 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00003-of-00022.safetensors",
|
101 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00003-of-00022.safetensors",
|
102 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00003-of-00022.safetensors",
|
103 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00003-of-00022.safetensors",
|
104 |
+
"model.layers.10.mlp.up_proj.weight": "model-00003-of-00022.safetensors",
|
105 |
+
"model.layers.10.mlp.down_proj.weight": "model-00003-of-00022.safetensors",
|
106 |
+
"model.layers.11.input_layernorm.weight": "model-00003-of-00022.safetensors",
|
107 |
+
"model.layers.11.self_attn.linear_attn.weight": "model-00004-of-00022.safetensors",
|
108 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00004-of-00022.safetensors",
|
109 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00004-of-00022.safetensors",
|
110 |
+
"model.layers.11.mlp.up_proj.weight": "model-00004-of-00022.safetensors",
|
111 |
+
"model.layers.11.mlp.down_proj.weight": "model-00004-of-00022.safetensors",
|
112 |
+
"model.layers.12.input_layernorm.weight": "model-00004-of-00022.safetensors",
|
113 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00004-of-00022.safetensors",
|
114 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00004-of-00022.safetensors",
|
115 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00004-of-00022.safetensors",
|
116 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00004-of-00022.safetensors",
|
117 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00004-of-00022.safetensors",
|
118 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00004-of-00022.safetensors",
|
119 |
+
"model.layers.12.mlp.up_proj.weight": "model-00004-of-00022.safetensors",
|
120 |
+
"model.layers.12.mlp.down_proj.weight": "model-00004-of-00022.safetensors",
|
121 |
+
"model.layers.13.input_layernorm.weight": "model-00004-of-00022.safetensors",
|
122 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00004-of-00022.safetensors",
|
123 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00004-of-00022.safetensors",
|
124 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00004-of-00022.safetensors",
|
125 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00004-of-00022.safetensors",
|
126 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00004-of-00022.safetensors",
|
127 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00004-of-00022.safetensors",
|
128 |
+
"model.layers.13.mlp.up_proj.weight": "model-00004-of-00022.safetensors",
|
129 |
+
"model.layers.13.mlp.down_proj.weight": "model-00004-of-00022.safetensors",
|
130 |
+
"model.layers.14.input_layernorm.weight": "model-00004-of-00022.safetensors",
|
131 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00004-of-00022.safetensors",
|
132 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00004-of-00022.safetensors",
|
133 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00004-of-00022.safetensors",
|
134 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00004-of-00022.safetensors",
|
135 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00004-of-00022.safetensors",
|
136 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00004-of-00022.safetensors",
|
137 |
+
"model.layers.14.mlp.up_proj.weight": "model-00004-of-00022.safetensors",
|
138 |
+
"model.layers.14.mlp.down_proj.weight": "model-00004-of-00022.safetensors",
|
139 |
+
"model.layers.15.input_layernorm.weight": "model-00004-of-00022.safetensors",
|
140 |
+
"model.layers.15.self_attn.linear_attn.weight": "model-00004-of-00022.safetensors",
|
141 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00004-of-00022.safetensors",
|
142 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00004-of-00022.safetensors",
|
143 |
+
"model.layers.15.mlp.up_proj.weight": "model-00004-of-00022.safetensors",
|
144 |
+
"model.layers.15.mlp.down_proj.weight": "model-00004-of-00022.safetensors",
|
145 |
+
"model.layers.16.input_layernorm.weight": "model-00004-of-00022.safetensors",
|
146 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00004-of-00022.safetensors",
|
147 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00004-of-00022.safetensors",
|
148 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00004-of-00022.safetensors",
|
149 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00004-of-00022.safetensors",
|
150 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00004-of-00022.safetensors",
|
151 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00005-of-00022.safetensors",
|
152 |
+
"model.layers.16.mlp.up_proj.weight": "model-00005-of-00022.safetensors",
|
153 |
+
"model.layers.16.mlp.down_proj.weight": "model-00005-of-00022.safetensors",
|
154 |
+
"model.layers.17.input_layernorm.weight": "model-00005-of-00022.safetensors",
|
155 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00005-of-00022.safetensors",
|
156 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00005-of-00022.safetensors",
|
157 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00005-of-00022.safetensors",
|
158 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00005-of-00022.safetensors",
|
159 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00005-of-00022.safetensors",
|
160 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00005-of-00022.safetensors",
|
161 |
+
"model.layers.17.mlp.up_proj.weight": "model-00005-of-00022.safetensors",
|
162 |
+
"model.layers.17.mlp.down_proj.weight": "model-00005-of-00022.safetensors",
|
163 |
+
"model.layers.18.input_layernorm.weight": "model-00005-of-00022.safetensors",
|
164 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00005-of-00022.safetensors",
|
165 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00005-of-00022.safetensors",
|
166 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00005-of-00022.safetensors",
|
167 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00005-of-00022.safetensors",
|
168 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00005-of-00022.safetensors",
|
169 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00005-of-00022.safetensors",
|
170 |
+
"model.layers.18.mlp.up_proj.weight": "model-00005-of-00022.safetensors",
|
171 |
+
"model.layers.18.mlp.down_proj.weight": "model-00005-of-00022.safetensors",
|
172 |
+
"model.layers.19.input_layernorm.weight": "model-00005-of-00022.safetensors",
|
173 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00005-of-00022.safetensors",
|
174 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00005-of-00022.safetensors",
|
175 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00005-of-00022.safetensors",
|
176 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00006-of-00022.safetensors",
|
177 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00006-of-00022.safetensors",
|
178 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00006-of-00022.safetensors",
|
179 |
+
"model.layers.19.mlp.up_proj.weight": "model-00006-of-00022.safetensors",
|
180 |
+
"model.layers.19.mlp.down_proj.weight": "model-00006-of-00022.safetensors",
|
181 |
+
"model.layers.20.input_layernorm.weight": "model-00006-of-00022.safetensors",
|
182 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00006-of-00022.safetensors",
|
183 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00006-of-00022.safetensors",
|
184 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00006-of-00022.safetensors",
|
185 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00006-of-00022.safetensors",
|
186 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00006-of-00022.safetensors",
|
187 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00006-of-00022.safetensors",
|
188 |
+
"model.layers.20.mlp.up_proj.weight": "model-00006-of-00022.safetensors",
|
189 |
+
"model.layers.20.mlp.down_proj.weight": "model-00006-of-00022.safetensors",
|
190 |
+
"model.layers.21.input_layernorm.weight": "model-00006-of-00022.safetensors",
|
191 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00006-of-00022.safetensors",
|
192 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00006-of-00022.safetensors",
|
193 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00006-of-00022.safetensors",
|
194 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00006-of-00022.safetensors",
|
195 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00006-of-00022.safetensors",
|
196 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00006-of-00022.safetensors",
|
197 |
+
"model.layers.21.mlp.up_proj.weight": "model-00006-of-00022.safetensors",
|
198 |
+
"model.layers.21.mlp.down_proj.weight": "model-00006-of-00022.safetensors",
|
199 |
+
"model.layers.22.input_layernorm.weight": "model-00006-of-00022.safetensors",
|
200 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00007-of-00022.safetensors",
|
201 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00007-of-00022.safetensors",
|
202 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00007-of-00022.safetensors",
|
203 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00007-of-00022.safetensors",
|
204 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00007-of-00022.safetensors",
|
205 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00007-of-00022.safetensors",
|
206 |
+
"model.layers.22.mlp.up_proj.weight": "model-00007-of-00022.safetensors",
|
207 |
+
"model.layers.22.mlp.down_proj.weight": "model-00007-of-00022.safetensors",
|
208 |
+
"model.layers.23.input_layernorm.weight": "model-00007-of-00022.safetensors",
|
209 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00007-of-00022.safetensors",
|
210 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00007-of-00022.safetensors",
|
211 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00007-of-00022.safetensors",
|
212 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00007-of-00022.safetensors",
|
213 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00007-of-00022.safetensors",
|
214 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00007-of-00022.safetensors",
|
215 |
+
"model.layers.23.mlp.up_proj.weight": "model-00007-of-00022.safetensors",
|
216 |
+
"model.layers.23.mlp.down_proj.weight": "model-00007-of-00022.safetensors",
|
217 |
+
"model.layers.24.input_layernorm.weight": "model-00007-of-00022.safetensors",
|
218 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00007-of-00022.safetensors",
|
219 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00007-of-00022.safetensors",
|
220 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00007-of-00022.safetensors",
|
221 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00007-of-00022.safetensors",
|
222 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00007-of-00022.safetensors",
|
223 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00007-of-00022.safetensors",
|
224 |
+
"model.layers.24.mlp.up_proj.weight": "model-00007-of-00022.safetensors",
|
225 |
+
"model.layers.24.mlp.down_proj.weight": "model-00008-of-00022.safetensors",
|
226 |
+
"model.layers.25.input_layernorm.weight": "model-00008-of-00022.safetensors",
|
227 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00008-of-00022.safetensors",
|
228 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00008-of-00022.safetensors",
|
229 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00008-of-00022.safetensors",
|
230 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00008-of-00022.safetensors",
|
231 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00008-of-00022.safetensors",
|
232 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00008-of-00022.safetensors",
|
233 |
+
"model.layers.25.mlp.up_proj.weight": "model-00008-of-00022.safetensors",
|
234 |
+
"model.layers.25.mlp.down_proj.weight": "model-00008-of-00022.safetensors",
|
235 |
+
"model.layers.26.input_layernorm.weight": "model-00008-of-00022.safetensors",
|
236 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00008-of-00022.safetensors",
|
237 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00008-of-00022.safetensors",
|
238 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00008-of-00022.safetensors",
|
239 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00008-of-00022.safetensors",
|
240 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00008-of-00022.safetensors",
|
241 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00008-of-00022.safetensors",
|
242 |
+
"model.layers.26.mlp.up_proj.weight": "model-00008-of-00022.safetensors",
|
243 |
+
"model.layers.26.mlp.down_proj.weight": "model-00008-of-00022.safetensors",
|
244 |
+
"model.layers.27.input_layernorm.weight": "model-00008-of-00022.safetensors",
|
245 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00008-of-00022.safetensors",
|
246 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00008-of-00022.safetensors",
|
247 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00008-of-00022.safetensors",
|
248 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00008-of-00022.safetensors",
|
249 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00008-of-00022.safetensors",
|
250 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00008-of-00022.safetensors",
|
251 |
+
"model.layers.27.mlp.up_proj.weight": "model-00009-of-00022.safetensors",
|
252 |
+
"model.layers.27.mlp.down_proj.weight": "model-00009-of-00022.safetensors",
|
253 |
+
"model.layers.28.input_layernorm.weight": "model-00009-of-00022.safetensors",
|
254 |
+
"model.layers.28.self_attn.q_proj.weight": "model-00009-of-00022.safetensors",
|
255 |
+
"model.layers.28.self_attn.k_proj.weight": "model-00009-of-00022.safetensors",
|
256 |
+
"model.layers.28.self_attn.v_proj.weight": "model-00009-of-00022.safetensors",
|
257 |
+
"model.layers.28.self_attn.o_proj.weight": "model-00009-of-00022.safetensors",
|
258 |
+
"model.layers.28.post_attention_layernorm.weight": "model-00009-of-00022.safetensors",
|
259 |
+
"model.layers.28.mlp.gate_proj.weight": "model-00009-of-00022.safetensors",
|
260 |
+
"model.layers.28.mlp.up_proj.weight": "model-00009-of-00022.safetensors",
|
261 |
+
"model.layers.28.mlp.down_proj.weight": "model-00009-of-00022.safetensors",
|
262 |
+
"model.layers.29.input_layernorm.weight": "model-00009-of-00022.safetensors",
|
263 |
+
"model.layers.29.self_attn.q_proj.weight": "model-00009-of-00022.safetensors",
|
264 |
+
"model.layers.29.self_attn.k_proj.weight": "model-00009-of-00022.safetensors",
|
265 |
+
"model.layers.29.self_attn.v_proj.weight": "model-00009-of-00022.safetensors",
|
266 |
+
"model.layers.29.self_attn.o_proj.weight": "model-00009-of-00022.safetensors",
|
267 |
+
"model.layers.29.post_attention_layernorm.weight": "model-00009-of-00022.safetensors",
|
268 |
+
"model.layers.29.mlp.gate_proj.weight": "model-00009-of-00022.safetensors",
|
269 |
+
"model.layers.29.mlp.up_proj.weight": "model-00009-of-00022.safetensors",
|
270 |
+
"model.layers.29.mlp.down_proj.weight": "model-00009-of-00022.safetensors",
|
271 |
+
"model.layers.30.input_layernorm.weight": "model-00009-of-00022.safetensors",
|
272 |
+
"model.layers.30.self_attn.q_proj.weight": "model-00009-of-00022.safetensors",
|
273 |
+
"model.layers.30.self_attn.k_proj.weight": "model-00009-of-00022.safetensors",
|
274 |
+
"model.layers.30.self_attn.v_proj.weight": "model-00009-of-00022.safetensors",
|
275 |
+
"model.layers.30.self_attn.o_proj.weight": "model-00009-of-00022.safetensors",
|
276 |
+
"model.layers.30.post_attention_layernorm.weight": "model-00009-of-00022.safetensors",
|
277 |
+
"model.layers.30.mlp.gate_proj.weight": "model-00010-of-00022.safetensors",
|
278 |
+
"model.layers.30.mlp.up_proj.weight": "model-00010-of-00022.safetensors",
|
279 |
+
"model.layers.30.mlp.down_proj.weight": "model-00010-of-00022.safetensors",
|
280 |
+
"model.layers.31.input_layernorm.weight": "model-00010-of-00022.safetensors",
|
281 |
+
"model.layers.31.self_attn.q_proj.weight": "model-00010-of-00022.safetensors",
|
282 |
+
"model.layers.31.self_attn.k_proj.weight": "model-00010-of-00022.safetensors",
|
283 |
+
"model.layers.31.self_attn.v_proj.weight": "model-00010-of-00022.safetensors",
|
284 |
+
"model.layers.31.self_attn.o_proj.weight": "model-00010-of-00022.safetensors",
|
285 |
+
"model.layers.31.post_attention_layernorm.weight": "model-00010-of-00022.safetensors",
|
286 |
+
"model.layers.31.mlp.gate_proj.weight": "model-00010-of-00022.safetensors",
|
287 |
+
"model.layers.31.mlp.up_proj.weight": "model-00010-of-00022.safetensors",
|
288 |
+
"model.layers.31.mlp.down_proj.weight": "model-00010-of-00022.safetensors",
|
289 |
+
"model.layers.32.input_layernorm.weight": "model-00010-of-00022.safetensors",
|
290 |
+
"model.layers.32.self_attn.q_proj.weight": "model-00010-of-00022.safetensors",
|
291 |
+
"model.layers.32.self_attn.k_proj.weight": "model-00010-of-00022.safetensors",
|
292 |
+
"model.layers.32.self_attn.v_proj.weight": "model-00010-of-00022.safetensors",
|
293 |
+
"model.layers.32.self_attn.o_proj.weight": "model-00010-of-00022.safetensors",
|
294 |
+
"model.layers.32.post_attention_layernorm.weight": "model-00010-of-00022.safetensors",
|
295 |
+
"model.layers.32.mlp.gate_proj.weight": "model-00010-of-00022.safetensors",
|
296 |
+
"model.layers.32.mlp.up_proj.weight": "model-00010-of-00022.safetensors",
|
297 |
+
"model.layers.32.mlp.down_proj.weight": "model-00010-of-00022.safetensors",
|
298 |
+
"model.layers.33.input_layernorm.weight": "model-00010-of-00022.safetensors",
|
299 |
+
"model.layers.33.self_attn.q_proj.weight": "model-00010-of-00022.safetensors",
|
300 |
+
"model.layers.33.self_attn.k_proj.weight": "model-00010-of-00022.safetensors",
|
301 |
+
"model.layers.33.self_attn.v_proj.weight": "model-00010-of-00022.safetensors",
|
302 |
+
"model.layers.33.self_attn.o_proj.weight": "model-00011-of-00022.safetensors",
|
303 |
+
"model.layers.33.post_attention_layernorm.weight": "model-00011-of-00022.safetensors",
|
304 |
+
"model.layers.33.mlp.gate_proj.weight": "model-00011-of-00022.safetensors",
|
305 |
+
"model.layers.33.mlp.up_proj.weight": "model-00011-of-00022.safetensors",
|
306 |
+
"model.layers.33.mlp.down_proj.weight": "model-00011-of-00022.safetensors",
|
307 |
+
"model.layers.34.input_layernorm.weight": "model-00011-of-00022.safetensors",
|
308 |
+
"model.layers.34.self_attn.q_proj.weight": "model-00011-of-00022.safetensors",
|
309 |
+
"model.layers.34.self_attn.k_proj.weight": "model-00011-of-00022.safetensors",
|
310 |
+
"model.layers.34.self_attn.v_proj.weight": "model-00011-of-00022.safetensors",
|
311 |
+
"model.layers.34.self_attn.o_proj.weight": "model-00011-of-00022.safetensors",
|
312 |
+
"model.layers.34.post_attention_layernorm.weight": "model-00011-of-00022.safetensors",
|
313 |
+
"model.layers.34.mlp.gate_proj.weight": "model-00011-of-00022.safetensors",
|
314 |
+
"model.layers.34.mlp.up_proj.weight": "model-00011-of-00022.safetensors",
|
315 |
+
"model.layers.34.mlp.down_proj.weight": "model-00011-of-00022.safetensors",
|
316 |
+
"model.layers.35.input_layernorm.weight": "model-00011-of-00022.safetensors",
|
317 |
+
"model.layers.35.self_attn.q_proj.weight": "model-00011-of-00022.safetensors",
|
318 |
+
"model.layers.35.self_attn.k_proj.weight": "model-00011-of-00022.safetensors",
|
319 |
+
"model.layers.35.self_attn.v_proj.weight": "model-00011-of-00022.safetensors",
|
320 |
+
"model.layers.35.self_attn.o_proj.weight": "model-00011-of-00022.safetensors",
|
321 |
+
"model.layers.35.post_attention_layernorm.weight": "model-00011-of-00022.safetensors",
|
322 |
+
"model.layers.35.mlp.gate_proj.weight": "model-00011-of-00022.safetensors",
|
323 |
+
"model.layers.35.mlp.up_proj.weight": "model-00011-of-00022.safetensors",
|
324 |
+
"model.layers.35.mlp.down_proj.weight": "model-00011-of-00022.safetensors",
|
325 |
+
"model.layers.36.input_layernorm.weight": "model-00011-of-00022.safetensors",
|
326 |
+
"model.layers.36.self_attn.q_proj.weight": "model-00012-of-00022.safetensors",
|
327 |
+
"model.layers.36.self_attn.k_proj.weight": "model-00012-of-00022.safetensors",
|
328 |
+
"model.layers.36.self_attn.v_proj.weight": "model-00012-of-00022.safetensors",
|
329 |
+
"model.layers.36.self_attn.o_proj.weight": "model-00012-of-00022.safetensors",
|
330 |
+
"model.layers.36.post_attention_layernorm.weight": "model-00012-of-00022.safetensors",
|
331 |
+
"model.layers.36.mlp.gate_proj.weight": "model-00012-of-00022.safetensors",
|
332 |
+
"model.layers.36.mlp.up_proj.weight": "model-00012-of-00022.safetensors",
|
333 |
+
"model.layers.36.mlp.down_proj.weight": "model-00012-of-00022.safetensors",
|
334 |
+
"model.layers.37.input_layernorm.weight": "model-00012-of-00022.safetensors",
|
335 |
+
"model.layers.37.self_attn.q_proj.weight": "model-00012-of-00022.safetensors",
|
336 |
+
"model.layers.37.self_attn.k_proj.weight": "model-00012-of-00022.safetensors",
|
337 |
+
"model.layers.37.self_attn.v_proj.weight": "model-00012-of-00022.safetensors",
|
338 |
+
"model.layers.37.self_attn.o_proj.weight": "model-00012-of-00022.safetensors",
|
339 |
+
"model.layers.37.post_attention_layernorm.weight": "model-00012-of-00022.safetensors",
|
340 |
+
"model.layers.37.mlp.gate_proj.weight": "model-00012-of-00022.safetensors",
|
341 |
+
"model.layers.37.mlp.up_proj.weight": "model-00012-of-00022.safetensors",
|
342 |
+
"model.layers.37.mlp.down_proj.weight": "model-00012-of-00022.safetensors",
|
343 |
+
"model.layers.38.input_layernorm.weight": "model-00012-of-00022.safetensors",
|
344 |
+
"model.layers.38.self_attn.q_proj.weight": "model-00012-of-00022.safetensors",
|
345 |
+
"model.layers.38.self_attn.k_proj.weight": "model-00012-of-00022.safetensors",
|
346 |
+
"model.layers.38.self_attn.v_proj.weight": "model-00012-of-00022.safetensors",
|
347 |
+
"model.layers.38.self_attn.o_proj.weight": "model-00012-of-00022.safetensors",
|
348 |
+
"model.layers.38.post_attention_layernorm.weight": "model-00012-of-00022.safetensors",
|
349 |
+
"model.layers.38.mlp.gate_proj.weight": "model-00012-of-00022.safetensors",
|
350 |
+
"model.layers.38.mlp.up_proj.weight": "model-00012-of-00022.safetensors",
|
351 |
+
"model.layers.38.mlp.down_proj.weight": "model-00013-of-00022.safetensors",
|
352 |
+
"model.layers.39.input_layernorm.weight": "model-00013-of-00022.safetensors",
|
353 |
+
"model.layers.39.self_attn.q_proj.weight": "model-00013-of-00022.safetensors",
|
354 |
+
"model.layers.39.self_attn.k_proj.weight": "model-00013-of-00022.safetensors",
|
355 |
+
"model.layers.39.self_attn.v_proj.weight": "model-00013-of-00022.safetensors",
|
356 |
+
"model.layers.39.self_attn.o_proj.weight": "model-00013-of-00022.safetensors",
|
357 |
+
"model.layers.39.post_attention_layernorm.weight": "model-00013-of-00022.safetensors",
|
358 |
+
"model.layers.39.mlp.gate_proj.weight": "model-00013-of-00022.safetensors",
|
359 |
+
"model.layers.39.mlp.up_proj.weight": "model-00013-of-00022.safetensors",
|
360 |
+
"model.layers.39.mlp.down_proj.weight": "model-00013-of-00022.safetensors",
|
361 |
+
"model.layers.40.input_layernorm.weight": "model-00013-of-00022.safetensors",
|
362 |
+
"model.layers.40.self_attn.q_proj.weight": "model-00013-of-00022.safetensors",
|
363 |
+
"model.layers.40.self_attn.k_proj.weight": "model-00013-of-00022.safetensors",
|
364 |
+
"model.layers.40.self_attn.v_proj.weight": "model-00013-of-00022.safetensors",
|
365 |
+
"model.layers.40.self_attn.o_proj.weight": "model-00013-of-00022.safetensors",
|
366 |
+
"model.layers.40.post_attention_layernorm.weight": "model-00013-of-00022.safetensors",
|
367 |
+
"model.layers.40.mlp.gate_proj.weight": "model-00013-of-00022.safetensors",
|
368 |
+
"model.layers.40.mlp.up_proj.weight": "model-00013-of-00022.safetensors",
|
369 |
+
"model.layers.40.mlp.down_proj.weight": "model-00013-of-00022.safetensors",
|
370 |
+
"model.layers.41.input_layernorm.weight": "model-00013-of-00022.safetensors",
|
371 |
+
"model.layers.41.self_attn.q_proj.weight": "model-00013-of-00022.safetensors",
|
372 |
+
"model.layers.41.self_attn.k_proj.weight": "model-00013-of-00022.safetensors",
|
373 |
+
"model.layers.41.self_attn.v_proj.weight": "model-00013-of-00022.safetensors",
|
374 |
+
"model.layers.41.self_attn.o_proj.weight": "model-00013-of-00022.safetensors",
|
375 |
+
"model.layers.41.post_attention_layernorm.weight": "model-00013-of-00022.safetensors",
|
376 |
+
"model.layers.41.mlp.gate_proj.weight": "model-00013-of-00022.safetensors",
|
377 |
+
"model.layers.41.mlp.up_proj.weight": "model-00014-of-00022.safetensors",
|
378 |
+
"model.layers.41.mlp.down_proj.weight": "model-00014-of-00022.safetensors",
|
379 |
+
"model.layers.42.input_layernorm.weight": "model-00014-of-00022.safetensors",
|
380 |
+
"model.layers.42.self_attn.linear_attn.weight": "model-00014-of-00022.safetensors",
|
381 |
+
"model.layers.42.post_attention_layernorm.weight": "model-00014-of-00022.safetensors",
|
382 |
+
"model.layers.42.mlp.gate_proj.weight": "model-00014-of-00022.safetensors",
|
383 |
+
"model.layers.42.mlp.up_proj.weight": "model-00014-of-00022.safetensors",
|
384 |
+
"model.layers.42.mlp.down_proj.weight": "model-00014-of-00022.safetensors",
|
385 |
+
"model.layers.43.input_layernorm.weight": "model-00014-of-00022.safetensors",
|
386 |
+
"model.layers.43.self_attn.q_proj.weight": "model-00014-of-00022.safetensors",
|
387 |
+
"model.layers.43.self_attn.k_proj.weight": "model-00014-of-00022.safetensors",
|
388 |
+
"model.layers.43.self_attn.v_proj.weight": "model-00014-of-00022.safetensors",
|
389 |
+
"model.layers.43.self_attn.o_proj.weight": "model-00014-of-00022.safetensors",
|
390 |
+
"model.layers.43.post_attention_layernorm.weight": "model-00014-of-00022.safetensors",
|
391 |
+
"model.layers.43.mlp.gate_proj.weight": "model-00014-of-00022.safetensors",
|
392 |
+
"model.layers.43.mlp.up_proj.weight": "model-00014-of-00022.safetensors",
|
393 |
+
"model.layers.43.mlp.down_proj.weight": "model-00014-of-00022.safetensors",
|
394 |
+
"model.layers.44.input_layernorm.weight": "model-00014-of-00022.safetensors",
|
395 |
+
"model.layers.44.self_attn.q_proj.weight": "model-00014-of-00022.safetensors",
|
396 |
+
"model.layers.44.self_attn.k_proj.weight": "model-00014-of-00022.safetensors",
|
397 |
+
"model.layers.44.self_attn.v_proj.weight": "model-00014-of-00022.safetensors",
|
398 |
+
"model.layers.44.self_attn.o_proj.weight": "model-00014-of-00022.safetensors",
|
399 |
+
"model.layers.44.post_attention_layernorm.weight": "model-00014-of-00022.safetensors",
|
400 |
+
"model.layers.44.mlp.gate_proj.weight": "model-00014-of-00022.safetensors",
|
401 |
+
"model.layers.44.mlp.up_proj.weight": "model-00014-of-00022.safetensors",
|
402 |
+
"model.layers.44.mlp.down_proj.weight": "model-00015-of-00022.safetensors",
|
403 |
+
"model.layers.45.input_layernorm.weight": "model-00015-of-00022.safetensors",
|
404 |
+
"model.layers.45.self_attn.linear_attn.weight": "model-00015-of-00022.safetensors",
|
405 |
+
"model.layers.45.post_attention_layernorm.weight": "model-00015-of-00022.safetensors",
|
406 |
+
"model.layers.45.mlp.gate_proj.weight": "model-00015-of-00022.safetensors",
|
407 |
+
"model.layers.45.mlp.up_proj.weight": "model-00015-of-00022.safetensors",
|
408 |
+
"model.layers.45.mlp.down_proj.weight": "model-00015-of-00022.safetensors",
|
409 |
+
"model.layers.46.input_layernorm.weight": "model-00015-of-00022.safetensors",
|
410 |
+
"model.layers.46.self_attn.linear_attn.weight": "model-00015-of-00022.safetensors",
|
411 |
+
"model.layers.46.post_attention_layernorm.weight": "model-00015-of-00022.safetensors",
|
412 |
+
"model.layers.46.mlp.gate_proj.weight": "model-00015-of-00022.safetensors",
|
413 |
+
"model.layers.46.mlp.up_proj.weight": "model-00015-of-00022.safetensors",
|
414 |
+
"model.layers.46.mlp.down_proj.weight": "model-00015-of-00022.safetensors",
|
415 |
+
"model.layers.47.input_layernorm.weight": "model-00015-of-00022.safetensors",
|
416 |
+
"model.layers.47.self_attn.linear_attn.weight": "model-00015-of-00022.safetensors",
|
417 |
+
"model.layers.47.post_attention_layernorm.weight": "model-00015-of-00022.safetensors",
|
418 |
+
"model.layers.47.mlp.gate_proj.weight": "model-00015-of-00022.safetensors",
|
419 |
+
"model.layers.47.mlp.up_proj.weight": "model-00015-of-00022.safetensors",
|
420 |
+
"model.layers.47.mlp.down_proj.weight": "model-00015-of-00022.safetensors",
|
421 |
+
"model.layers.48.input_layernorm.weight": "model-00015-of-00022.safetensors",
|
422 |
+
"model.layers.48.self_attn.linear_attn.weight": "model-00015-of-00022.safetensors",
|
423 |
+
"model.layers.48.post_attention_layernorm.weight": "model-00015-of-00022.safetensors",
|
424 |
+
"model.layers.48.mlp.gate_proj.weight": "model-00015-of-00022.safetensors",
|
425 |
+
"model.layers.48.mlp.up_proj.weight": "model-00015-of-00022.safetensors",
|
426 |
+
"model.layers.48.mlp.down_proj.weight": "model-00015-of-00022.safetensors",
|
427 |
+
"model.layers.49.input_layernorm.weight": "model-00015-of-00022.safetensors",
|
428 |
+
"model.layers.49.self_attn.linear_attn.weight": "model-00015-of-00022.safetensors",
|
429 |
+
"model.layers.49.post_attention_layernorm.weight": "model-00015-of-00022.safetensors",
|
430 |
+
"model.layers.49.mlp.gate_proj.weight": "model-00015-of-00022.safetensors",
|
431 |
+
"model.layers.49.mlp.up_proj.weight": "model-00016-of-00022.safetensors",
|
432 |
+
"model.layers.49.mlp.down_proj.weight": "model-00016-of-00022.safetensors",
|
433 |
+
"model.layers.50.post_attention_layernorm.weight": "model-00016-of-00022.safetensors",
|
434 |
+
"model.layers.50.mlp.gate_proj.weight": "model-00016-of-00022.safetensors",
|
435 |
+
"model.layers.50.mlp.up_proj.weight": "model-00016-of-00022.safetensors",
|
436 |
+
"model.layers.50.mlp.down_proj.weight": "model-00016-of-00022.safetensors",
|
437 |
+
"model.layers.51.input_layernorm.weight": "model-00016-of-00022.safetensors",
|
438 |
+
"model.layers.51.self_attn.linear_attn.weight": "model-00016-of-00022.safetensors",
|
439 |
+
"model.layers.51.post_attention_layernorm.weight": "model-00016-of-00022.safetensors",
|
440 |
+
"model.layers.51.mlp.gate_proj.weight": "model-00016-of-00022.safetensors",
|
441 |
+
"model.layers.51.mlp.up_proj.weight": "model-00016-of-00022.safetensors",
|
442 |
+
"model.layers.51.mlp.down_proj.weight": "model-00016-of-00022.safetensors",
|
443 |
+
"model.layers.52.input_layernorm.weight": "model-00016-of-00022.safetensors",
|
444 |
+
"model.layers.52.self_attn.q_proj.weight": "model-00016-of-00022.safetensors",
|
445 |
+
"model.layers.52.self_attn.k_proj.weight": "model-00016-of-00022.safetensors",
|
446 |
+
"model.layers.52.self_attn.v_proj.weight": "model-00016-of-00022.safetensors",
|
447 |
+
"model.layers.52.self_attn.o_proj.weight": "model-00016-of-00022.safetensors",
|
448 |
+
"model.layers.52.post_attention_layernorm.weight": "model-00016-of-00022.safetensors",
|
449 |
+
"model.layers.52.mlp.gate_proj.weight": "model-00016-of-00022.safetensors",
|
450 |
+
"model.layers.52.mlp.up_proj.weight": "model-00016-of-00022.safetensors",
|
451 |
+
"model.layers.52.mlp.down_proj.weight": "model-00016-of-00022.safetensors",
|
452 |
+
"model.layers.53.post_attention_layernorm.weight": "model-00016-of-00022.safetensors",
|
453 |
+
"model.layers.53.mlp.gate_proj.weight": "model-00016-of-00022.safetensors",
|
454 |
+
"model.layers.53.mlp.up_proj.weight": "model-00016-of-00022.safetensors",
|
455 |
+
"model.layers.53.mlp.down_proj.weight": "model-00016-of-00022.safetensors",
|
456 |
+
"model.layers.54.input_layernorm.weight": "model-00016-of-00022.safetensors",
|
457 |
+
"model.layers.54.self_attn.linear_attn.weight": "model-00016-of-00022.safetensors",
|
458 |
+
"model.layers.54.post_attention_layernorm.weight": "model-00016-of-00022.safetensors",
|
459 |
+
"model.layers.54.mlp.gate_proj.weight": "model-00016-of-00022.safetensors",
|
460 |
+
"model.layers.54.mlp.up_proj.weight": "model-00016-of-00022.safetensors",
|
461 |
+
"model.layers.54.mlp.down_proj.weight": "model-00016-of-00022.safetensors",
|
462 |
+
"model.layers.55.post_attention_layernorm.weight": "model-00016-of-00022.safetensors",
|
463 |
+
"model.layers.55.mlp.gate_proj.weight": "model-00016-of-00022.safetensors",
|
464 |
+
"model.layers.55.mlp.up_proj.weight": "model-00016-of-00022.safetensors",
|
465 |
+
"model.layers.55.mlp.down_proj.weight": "model-00016-of-00022.safetensors",
|
466 |
+
"model.layers.56.input_layernorm.weight": "model-00016-of-00022.safetensors",
|
467 |
+
"model.layers.56.self_attn.q_proj.weight": "model-00016-of-00022.safetensors",
|
468 |
+
"model.layers.56.self_attn.k_proj.weight": "model-00016-of-00022.safetensors",
|
469 |
+
"model.layers.56.self_attn.v_proj.weight": "model-00016-of-00022.safetensors",
|
470 |
+
"model.layers.56.self_attn.o_proj.weight": "model-00016-of-00022.safetensors",
|
471 |
+
"model.layers.56.post_attention_layernorm.weight": "model-00016-of-00022.safetensors",
|
472 |
+
"model.layers.56.mlp.gate_proj.weight": "model-00016-of-00022.safetensors",
|
473 |
+
"model.layers.56.mlp.up_proj.weight": "model-00017-of-00022.safetensors",
|
474 |
+
"model.layers.56.mlp.down_proj.weight": "model-00017-of-00022.safetensors",
|
475 |
+
"model.layers.57.input_layernorm.weight": "model-00017-of-00022.safetensors",
|
476 |
+
"model.layers.57.self_attn.linear_attn.weight": "model-00017-of-00022.safetensors",
|
477 |
+
"model.layers.57.post_attention_layernorm.weight": "model-00017-of-00022.safetensors",
|
478 |
+
"model.layers.57.mlp.gate_proj.weight": "model-00017-of-00022.safetensors",
|
479 |
+
"model.layers.57.mlp.up_proj.weight": "model-00017-of-00022.safetensors",
|
480 |
+
"model.layers.57.mlp.down_proj.weight": "model-00017-of-00022.safetensors",
|
481 |
+
"model.layers.58.post_attention_layernorm.weight": "model-00017-of-00022.safetensors",
|
482 |
+
"model.layers.58.mlp.gate_proj.weight": "model-00017-of-00022.safetensors",
|
483 |
+
"model.layers.58.mlp.up_proj.weight": "model-00017-of-00022.safetensors",
|
484 |
+
"model.layers.58.mlp.down_proj.weight": "model-00017-of-00022.safetensors",
|
485 |
+
"model.layers.59.input_layernorm.weight": "model-00017-of-00022.safetensors",
|
486 |
+
"model.layers.59.self_attn.linear_attn.weight": "model-00017-of-00022.safetensors",
|
487 |
+
"model.layers.59.post_attention_layernorm.weight": "model-00017-of-00022.safetensors",
|
488 |
+
"model.layers.59.mlp.gate_proj.weight": "model-00017-of-00022.safetensors",
|
489 |
+
"model.layers.59.mlp.up_proj.weight": "model-00017-of-00022.safetensors",
|
490 |
+
"model.layers.59.mlp.down_proj.weight": "model-00017-of-00022.safetensors",
|
491 |
+
"model.layers.60.input_layernorm.weight": "model-00017-of-00022.safetensors",
|
492 |
+
"model.layers.60.self_attn.linear_attn.weight": "model-00017-of-00022.safetensors",
|
493 |
+
"model.layers.60.post_attention_layernorm.weight": "model-00017-of-00022.safetensors",
|
494 |
+
"model.layers.60.mlp.gate_proj.weight": "model-00017-of-00022.safetensors",
|
495 |
+
"model.layers.60.mlp.up_proj.weight": "model-00017-of-00022.safetensors",
|
496 |
+
"model.layers.60.mlp.down_proj.weight": "model-00017-of-00022.safetensors",
|
497 |
+
"model.layers.61.post_attention_layernorm.weight": "model-00017-of-00022.safetensors",
|
498 |
+
"model.layers.61.mlp.gate_proj.weight": "model-00017-of-00022.safetensors",
|
499 |
+
"model.layers.61.mlp.up_proj.weight": "model-00017-of-00022.safetensors",
|
500 |
+
"model.layers.61.mlp.down_proj.weight": "model-00017-of-00022.safetensors",
|
501 |
+
"model.layers.62.post_attention_layernorm.weight": "model-00017-of-00022.safetensors",
|
502 |
+
"model.layers.62.mlp.gate_proj.weight": "model-00017-of-00022.safetensors",
|
503 |
+
"model.layers.62.mlp.up_proj.weight": "model-00017-of-00022.safetensors",
|
504 |
+
"model.layers.62.mlp.down_proj.weight": "model-00017-of-00022.safetensors",
|
505 |
+
"model.layers.63.input_layernorm.weight": "model-00017-of-00022.safetensors",
|
506 |
+
"model.layers.63.self_attn.linear_attn.weight": "model-00017-of-00022.safetensors",
|
507 |
+
"model.layers.63.post_attention_layernorm.weight": "model-00017-of-00022.safetensors",
|
508 |
+
"model.layers.63.mlp.gate_proj.weight": "model-00017-of-00022.safetensors",
|
509 |
+
"model.layers.63.mlp.up_proj.weight": "model-00017-of-00022.safetensors",
|
510 |
+
"model.layers.63.mlp.down_proj.weight": "model-00017-of-00022.safetensors",
|
511 |
+
"model.layers.64.post_attention_layernorm.weight": "model-00017-of-00022.safetensors",
|
512 |
+
"model.layers.64.mlp.gate_proj.weight": "model-00017-of-00022.safetensors",
|
513 |
+
"model.layers.64.mlp.up_proj.weight": "model-00017-of-00022.safetensors",
|
514 |
+
"model.layers.64.mlp.down_proj.weight": "model-00017-of-00022.safetensors",
|
515 |
+
"model.layers.65.post_attention_layernorm.weight": "model-00017-of-00022.safetensors",
|
516 |
+
"model.layers.65.mlp.gate_proj.weight": "model-00017-of-00022.safetensors",
|
517 |
+
"model.layers.65.mlp.up_proj.weight": "model-00017-of-00022.safetensors",
|
518 |
+
"model.layers.65.mlp.down_proj.weight": "model-00017-of-00022.safetensors",
|
519 |
+
"model.layers.66.input_layernorm.weight": "model-00017-of-00022.safetensors",
|
520 |
+
"model.layers.66.self_attn.linear_attn.weight": "model-00017-of-00022.safetensors",
|
521 |
+
"model.layers.66.post_attention_layernorm.weight": "model-00017-of-00022.safetensors",
|
522 |
+
"model.layers.66.mlp.gate_proj.weight": "model-00017-of-00022.safetensors",
|
523 |
+
"model.layers.66.mlp.up_proj.weight": "model-00018-of-00022.safetensors",
|
524 |
+
"model.layers.66.mlp.down_proj.weight": "model-00018-of-00022.safetensors",
|
525 |
+
"model.layers.67.input_layernorm.weight": "model-00018-of-00022.safetensors",
|
526 |
+
"model.layers.67.self_attn.linear_attn.weight": "model-00018-of-00022.safetensors",
|
527 |
+
"model.layers.67.post_attention_layernorm.weight": "model-00018-of-00022.safetensors",
|
528 |
+
"model.layers.67.mlp.gate_proj.weight": "model-00018-of-00022.safetensors",
|
529 |
+
"model.layers.67.mlp.up_proj.weight": "model-00018-of-00022.safetensors",
|
530 |
+
"model.layers.67.mlp.down_proj.weight": "model-00018-of-00022.safetensors",
|
531 |
+
"model.layers.68.input_layernorm.weight": "model-00018-of-00022.safetensors",
|
532 |
+
"model.layers.68.self_attn.linear_attn.weight": "model-00018-of-00022.safetensors",
|
533 |
+
"model.layers.68.post_attention_layernorm.weight": "model-00018-of-00022.safetensors",
|
534 |
+
"model.layers.68.mlp.gate_proj.weight": "model-00018-of-00022.safetensors",
|
535 |
+
"model.layers.68.mlp.up_proj.weight": "model-00018-of-00022.safetensors",
|
536 |
+
"model.layers.68.mlp.down_proj.weight": "model-00018-of-00022.safetensors",
|
537 |
+
"model.layers.69.input_layernorm.weight": "model-00018-of-00022.safetensors",
|
538 |
+
"model.layers.69.self_attn.linear_attn.weight": "model-00018-of-00022.safetensors",
|
539 |
+
"model.layers.69.post_attention_layernorm.weight": "model-00018-of-00022.safetensors",
|
540 |
+
"model.layers.69.mlp.gate_proj.weight": "model-00018-of-00022.safetensors",
|
541 |
+
"model.layers.69.mlp.up_proj.weight": "model-00018-of-00022.safetensors",
|
542 |
+
"model.layers.69.mlp.down_proj.weight": "model-00018-of-00022.safetensors",
|
543 |
+
"model.layers.70.input_layernorm.weight": "model-00018-of-00022.safetensors",
|
544 |
+
"model.layers.70.self_attn.q_proj.weight": "model-00018-of-00022.safetensors",
|
545 |
+
"model.layers.70.self_attn.k_proj.weight": "model-00018-of-00022.safetensors",
|
546 |
+
"model.layers.70.self_attn.v_proj.weight": "model-00018-of-00022.safetensors",
|
547 |
+
"model.layers.70.self_attn.o_proj.weight": "model-00018-of-00022.safetensors",
|
548 |
+
"model.layers.70.post_attention_layernorm.weight": "model-00018-of-00022.safetensors",
|
549 |
+
"model.layers.70.mlp.gate_proj.weight": "model-00018-of-00022.safetensors",
|
550 |
+
"model.layers.70.mlp.up_proj.weight": "model-00018-of-00022.safetensors",
|
551 |
+
"model.layers.70.mlp.down_proj.weight": "model-00018-of-00022.safetensors",
|
552 |
+
"model.layers.71.input_layernorm.weight": "model-00018-of-00022.safetensors",
|
553 |
+
"model.layers.71.self_attn.q_proj.weight": "model-00018-of-00022.safetensors",
|
554 |
+
"model.layers.71.self_attn.k_proj.weight": "model-00018-of-00022.safetensors",
|
555 |
+
"model.layers.71.self_attn.v_proj.weight": "model-00018-of-00022.safetensors",
|
556 |
+
"model.layers.71.self_attn.o_proj.weight": "model-00018-of-00022.safetensors",
|
557 |
+
"model.layers.71.post_attention_layernorm.weight": "model-00018-of-00022.safetensors",
|
558 |
+
"model.layers.71.mlp.gate_proj.weight": "model-00018-of-00022.safetensors",
|
559 |
+
"model.layers.71.mlp.up_proj.weight": "model-00018-of-00022.safetensors",
|
560 |
+
"model.layers.71.mlp.down_proj.weight": "model-00019-of-00022.safetensors",
|
561 |
+
"model.layers.72.input_layernorm.weight": "model-00019-of-00022.safetensors",
|
562 |
+
"model.layers.72.self_attn.q_proj.weight": "model-00019-of-00022.safetensors",
|
563 |
+
"model.layers.72.self_attn.k_proj.weight": "model-00019-of-00022.safetensors",
|
564 |
+
"model.layers.72.self_attn.v_proj.weight": "model-00019-of-00022.safetensors",
|
565 |
+
"model.layers.72.self_attn.o_proj.weight": "model-00019-of-00022.safetensors",
|
566 |
+
"model.layers.72.post_attention_layernorm.weight": "model-00019-of-00022.safetensors",
|
567 |
+
"model.layers.72.mlp.gate_proj.weight": "model-00019-of-00022.safetensors",
|
568 |
+
"model.layers.72.mlp.up_proj.weight": "model-00019-of-00022.safetensors",
|
569 |
+
"model.layers.72.mlp.down_proj.weight": "model-00019-of-00022.safetensors",
|
570 |
+
"model.layers.73.input_layernorm.weight": "model-00019-of-00022.safetensors",
|
571 |
+
"model.layers.73.self_attn.q_proj.weight": "model-00019-of-00022.safetensors",
|
572 |
+
"model.layers.73.self_attn.k_proj.weight": "model-00019-of-00022.safetensors",
|
573 |
+
"model.layers.73.self_attn.v_proj.weight": "model-00019-of-00022.safetensors",
|
574 |
+
"model.layers.73.self_attn.o_proj.weight": "model-00019-of-00022.safetensors",
|
575 |
+
"model.layers.73.post_attention_layernorm.weight": "model-00019-of-00022.safetensors",
|
576 |
+
"model.layers.73.mlp.gate_proj.weight": "model-00019-of-00022.safetensors",
|
577 |
+
"model.layers.73.mlp.up_proj.weight": "model-00019-of-00022.safetensors",
|
578 |
+
"model.layers.73.mlp.down_proj.weight": "model-00019-of-00022.safetensors",
|
579 |
+
"model.layers.74.input_layernorm.weight": "model-00019-of-00022.safetensors",
|
580 |
+
"model.layers.74.self_attn.q_proj.weight": "model-00019-of-00022.safetensors",
|
581 |
+
"model.layers.74.self_attn.k_proj.weight": "model-00019-of-00022.safetensors",
|
582 |
+
"model.layers.74.self_attn.v_proj.weight": "model-00019-of-00022.safetensors",
|
583 |
+
"model.layers.74.self_attn.o_proj.weight": "model-00019-of-00022.safetensors",
|
584 |
+
"model.layers.74.post_attention_layernorm.weight": "model-00019-of-00022.safetensors",
|
585 |
+
"model.layers.74.mlp.gate_proj.weight": "model-00019-of-00022.safetensors",
|
586 |
+
"model.layers.74.mlp.up_proj.weight": "model-00020-of-00022.safetensors",
|
587 |
+
"model.layers.74.mlp.down_proj.weight": "model-00020-of-00022.safetensors",
|
588 |
+
"model.layers.75.input_layernorm.weight": "model-00020-of-00022.safetensors",
|
589 |
+
"model.layers.75.self_attn.q_proj.weight": "model-00020-of-00022.safetensors",
|
590 |
+
"model.layers.75.self_attn.k_proj.weight": "model-00020-of-00022.safetensors",
|
591 |
+
"model.layers.75.self_attn.v_proj.weight": "model-00020-of-00022.safetensors",
|
592 |
+
"model.layers.75.self_attn.o_proj.weight": "model-00020-of-00022.safetensors",
|
593 |
+
"model.layers.75.post_attention_layernorm.weight": "model-00020-of-00022.safetensors",
|
594 |
+
"model.layers.75.mlp.gate_proj.weight": "model-00020-of-00022.safetensors",
|
595 |
+
"model.layers.75.mlp.up_proj.weight": "model-00020-of-00022.safetensors",
|
596 |
+
"model.layers.75.mlp.down_proj.weight": "model-00020-of-00022.safetensors",
|
597 |
+
"model.layers.76.input_layernorm.weight": "model-00020-of-00022.safetensors",
|
598 |
+
"model.layers.76.self_attn.q_proj.weight": "model-00020-of-00022.safetensors",
|
599 |
+
"model.layers.76.self_attn.k_proj.weight": "model-00020-of-00022.safetensors",
|
600 |
+
"model.layers.76.self_attn.v_proj.weight": "model-00020-of-00022.safetensors",
|
601 |
+
"model.layers.76.self_attn.o_proj.weight": "model-00020-of-00022.safetensors",
|
602 |
+
"model.layers.76.post_attention_layernorm.weight": "model-00020-of-00022.safetensors",
|
603 |
+
"model.layers.76.mlp.gate_proj.weight": "model-00020-of-00022.safetensors",
|
604 |
+
"model.layers.76.mlp.up_proj.weight": "model-00020-of-00022.safetensors",
|
605 |
+
"model.layers.76.mlp.down_proj.weight": "model-00020-of-00022.safetensors",
|
606 |
+
"model.layers.77.input_layernorm.weight": "model-00020-of-00022.safetensors",
|
607 |
+
"model.layers.77.self_attn.q_proj.weight": "model-00020-of-00022.safetensors",
|
608 |
+
"model.layers.77.self_attn.k_proj.weight": "model-00020-of-00022.safetensors",
|
609 |
+
"model.layers.77.self_attn.v_proj.weight": "model-00020-of-00022.safetensors",
|
610 |
+
"model.layers.77.self_attn.o_proj.weight": "model-00020-of-00022.safetensors",
|
611 |
+
"model.layers.77.post_attention_layernorm.weight": "model-00020-of-00022.safetensors",
|
612 |
+
"model.layers.77.mlp.gate_proj.weight": "model-00021-of-00022.safetensors",
|
613 |
+
"model.layers.77.mlp.up_proj.weight": "model-00021-of-00022.safetensors",
|
614 |
+
"model.layers.77.mlp.down_proj.weight": "model-00021-of-00022.safetensors",
|
615 |
+
"model.layers.78.input_layernorm.weight": "model-00021-of-00022.safetensors",
|
616 |
+
"model.layers.78.self_attn.q_proj.weight": "model-00021-of-00022.safetensors",
|
617 |
+
"model.layers.78.self_attn.k_proj.weight": "model-00021-of-00022.safetensors",
|
618 |
+
"model.layers.78.self_attn.v_proj.weight": "model-00021-of-00022.safetensors",
|
619 |
+
"model.layers.78.self_attn.o_proj.weight": "model-00021-of-00022.safetensors",
|
620 |
+
"model.layers.78.post_attention_layernorm.weight": "model-00021-of-00022.safetensors",
|
621 |
+
"model.layers.78.mlp.gate_proj.weight": "model-00021-of-00022.safetensors",
|
622 |
+
"model.layers.78.mlp.up_proj.weight": "model-00021-of-00022.safetensors",
|
623 |
+
"model.layers.78.mlp.down_proj.weight": "model-00021-of-00022.safetensors",
|
624 |
+
"model.layers.79.input_layernorm.weight": "model-00021-of-00022.safetensors",
|
625 |
+
"model.layers.79.self_attn.q_proj.weight": "model-00021-of-00022.safetensors",
|
626 |
+
"model.layers.79.self_attn.k_proj.weight": "model-00021-of-00022.safetensors",
|
627 |
+
"model.layers.79.self_attn.v_proj.weight": "model-00021-of-00022.safetensors",
|
628 |
+
"model.layers.79.self_attn.o_proj.weight": "model-00021-of-00022.safetensors",
|
629 |
+
"model.layers.79.post_attention_layernorm.weight": "model-00021-of-00022.safetensors",
|
630 |
+
"model.layers.79.mlp.gate_proj.weight": "model-00021-of-00022.safetensors",
|
631 |
+
"model.layers.79.mlp.up_proj.weight": "model-00021-of-00022.safetensors",
|
632 |
+
"model.layers.79.mlp.down_proj.weight": "model-00021-of-00022.safetensors",
|
633 |
+
"model.norm.weight": "model-00021-of-00022.safetensors",
|
634 |
+
"lm_head.weight": "model-00022-of-00022.safetensors"
|
635 |
+
}
|
636 |
+
}
|
modeling_decilm.py
ADDED
@@ -0,0 +1,1665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 EleutherAI, HuggingFace Inc, Nvidia Corporation. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on the Llama modeling code by HuggingFace, which is in turn based on
|
5 |
+
# EleutherAI's GPT-NeoX library and the GPT-NeoX and OPT implementations in this library.
|
6 |
+
#
|
7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8 |
+
# you may not use this file except in compliance with the License.
|
9 |
+
# You may obtain a copy of the License at
|
10 |
+
#
|
11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12 |
+
#
|
13 |
+
# Unless required by applicable law or agreed to in writing, software
|
14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16 |
+
# See the License for the specific language governing permissions and
|
17 |
+
# limitations under the License.
|
18 |
+
|
19 |
+
import math
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
+
from transformers import GenerationConfig
|
28 |
+
from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING
|
29 |
+
from transformers.modeling_utils import PreTrainedModel
|
30 |
+
from transformers.utils import (
|
31 |
+
add_start_docstrings,
|
32 |
+
add_start_docstrings_to_model_forward,
|
33 |
+
is_flash_attn_greater_or_equal_2_10,
|
34 |
+
logging,
|
35 |
+
replace_return_docstrings,
|
36 |
+
)
|
37 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
38 |
+
|
39 |
+
from .configuration_decilm import DeciLMConfig, AttentionConfig, FFNConfig
|
40 |
+
from .transformers_4_44_2__activations import ACT2FN
|
41 |
+
from .transformers_4_44_2__cache_utils import Cache, StaticCache
|
42 |
+
from .transformers_4_44_2__modeling_attn_mask_utils import AttentionMaskConverter
|
43 |
+
from .transformers_4_44_2__modeling_flash_attention_utils_backward_compat import _flash_attention_forward
|
44 |
+
from .transformers_4_44_2__modeling_outputs import (
|
45 |
+
BaseModelOutputWithPast,
|
46 |
+
CausalLMOutputWithPast,
|
47 |
+
QuestionAnsweringModelOutput,
|
48 |
+
SequenceClassifierOutputWithPast,
|
49 |
+
TokenClassifierOutput,
|
50 |
+
)
|
51 |
+
from .transformers_4_44_2__modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
52 |
+
from .transformers_4_44_2__pytorch_utils import ALL_LAYERNORM_LAYERS
|
53 |
+
from .variable_cache import VariableCache
|
54 |
+
|
55 |
+
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[DeciLMConfig.model_type] = "DeciLMForCausalLM"
|
56 |
+
logger = logging.get_logger(__name__)
|
57 |
+
|
58 |
+
_CONFIG_FOR_DOC = "DeciLMConfig"
|
59 |
+
|
60 |
+
|
61 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
62 |
+
attention_mask: torch.Tensor,
|
63 |
+
sequence_length: int,
|
64 |
+
target_length: int,
|
65 |
+
dtype: torch.dtype,
|
66 |
+
device: torch.device,
|
67 |
+
min_dtype: float,
|
68 |
+
cache_position: torch.Tensor,
|
69 |
+
batch_size: int,
|
70 |
+
):
|
71 |
+
"""
|
72 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
73 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
attention_mask (`torch.Tensor`):
|
77 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
78 |
+
sequence_length (`int`):
|
79 |
+
The sequence length being processed.
|
80 |
+
target_length (`int`):
|
81 |
+
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
82 |
+
dtype (`torch.dtype`):
|
83 |
+
The dtype to use for the 4D attention mask.
|
84 |
+
device (`torch.device`):
|
85 |
+
The device to plcae the 4D attention mask on.
|
86 |
+
min_dtype (`float`):
|
87 |
+
The minimum value representable with the dtype `dtype`.
|
88 |
+
cache_position (`torch.Tensor`):
|
89 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
90 |
+
batch_size (`torch.Tensor`):
|
91 |
+
Batch size.
|
92 |
+
"""
|
93 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
94 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
95 |
+
causal_mask = attention_mask
|
96 |
+
else:
|
97 |
+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
98 |
+
if sequence_length != 1:
|
99 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
100 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
101 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
102 |
+
if attention_mask is not None:
|
103 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
104 |
+
mask_length = attention_mask.shape[-1]
|
105 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
106 |
+
padding_mask = padding_mask == 0
|
107 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
108 |
+
padding_mask, min_dtype
|
109 |
+
)
|
110 |
+
|
111 |
+
return causal_mask
|
112 |
+
|
113 |
+
|
114 |
+
class DeciLMRMSNorm(nn.Module):
|
115 |
+
def __init__(self, hidden_size, eps=1e-6):
|
116 |
+
"""
|
117 |
+
DeciLMRMSNorm is equivalent to T5LayerNorm
|
118 |
+
"""
|
119 |
+
super().__init__()
|
120 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
121 |
+
self.variance_epsilon = eps
|
122 |
+
|
123 |
+
def forward(self, hidden_states):
|
124 |
+
input_dtype = hidden_states.dtype
|
125 |
+
hidden_states = hidden_states.to(torch.float32)
|
126 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
127 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
128 |
+
return self.weight * hidden_states.to(input_dtype)
|
129 |
+
|
130 |
+
def extra_repr(self):
|
131 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
132 |
+
|
133 |
+
|
134 |
+
ALL_LAYERNORM_LAYERS.append(DeciLMRMSNorm)
|
135 |
+
|
136 |
+
|
137 |
+
class DeciLMRotaryEmbedding(nn.Module):
|
138 |
+
def __init__(
|
139 |
+
self,
|
140 |
+
dim=None,
|
141 |
+
max_position_embeddings=2048,
|
142 |
+
base=10000,
|
143 |
+
device=None,
|
144 |
+
scaling_factor=1.0,
|
145 |
+
rope_type="default",
|
146 |
+
config: Optional[DeciLMConfig] = None,
|
147 |
+
):
|
148 |
+
super().__init__()
|
149 |
+
# TODO (joao): remove the `if` below, only used for BC
|
150 |
+
self.rope_kwargs = {}
|
151 |
+
if config is None:
|
152 |
+
logger.warning_once(
|
153 |
+
"`DeciLMRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
154 |
+
"`config` argument. All other arguments will be removed in v4.45"
|
155 |
+
)
|
156 |
+
self.rope_kwargs = {
|
157 |
+
"rope_type": rope_type,
|
158 |
+
"factor": scaling_factor,
|
159 |
+
"dim": dim,
|
160 |
+
"base": base,
|
161 |
+
"max_position_embeddings": max_position_embeddings,
|
162 |
+
}
|
163 |
+
self.rope_type = rope_type
|
164 |
+
self.max_seq_len_cached = max_position_embeddings
|
165 |
+
self.original_max_seq_len = max_position_embeddings
|
166 |
+
else:
|
167 |
+
# BC: "rope_type" was originally "type"
|
168 |
+
if config.rope_scaling is not None:
|
169 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
170 |
+
else:
|
171 |
+
self.rope_type = "default"
|
172 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
173 |
+
self.original_max_seq_len = config.max_position_embeddings
|
174 |
+
|
175 |
+
self.config = config
|
176 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
177 |
+
|
178 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
179 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
180 |
+
self.original_inv_freq = self.inv_freq
|
181 |
+
|
182 |
+
def _dynamic_frequency_update(self, position_ids, device):
|
183 |
+
"""
|
184 |
+
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
185 |
+
1 - growing beyond the cached sequence length (allow scaling)
|
186 |
+
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
187 |
+
"""
|
188 |
+
seq_len = torch.max(position_ids) + 1
|
189 |
+
if seq_len > self.max_seq_len_cached: # growth
|
190 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
191 |
+
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
192 |
+
)
|
193 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
194 |
+
self.max_seq_len_cached = seq_len
|
195 |
+
|
196 |
+
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
197 |
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
198 |
+
self.max_seq_len_cached = self.original_max_seq_len
|
199 |
+
|
200 |
+
@torch.no_grad()
|
201 |
+
def forward(self, x, position_ids):
|
202 |
+
if "dynamic" in self.rope_type:
|
203 |
+
self._dynamic_frequency_update(position_ids, device=x.device)
|
204 |
+
|
205 |
+
# Core RoPE block
|
206 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
207 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
208 |
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
209 |
+
device_type = x.device.type
|
210 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
211 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
212 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
213 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
214 |
+
cos = emb.cos()
|
215 |
+
sin = emb.sin()
|
216 |
+
|
217 |
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
218 |
+
cos = cos * self.attention_scaling
|
219 |
+
sin = sin * self.attention_scaling
|
220 |
+
|
221 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
222 |
+
|
223 |
+
|
224 |
+
class DeciLMLinearScalingRotaryEmbedding(DeciLMRotaryEmbedding):
|
225 |
+
"""DeciLMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
226 |
+
|
227 |
+
def __init__(self, *args, **kwargs):
|
228 |
+
logger.warning_once(
|
229 |
+
"`DeciLMLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
|
230 |
+
"`DeciLMRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
|
231 |
+
)
|
232 |
+
kwargs["rope_type"] = "linear"
|
233 |
+
super().__init__(*args, **kwargs)
|
234 |
+
|
235 |
+
|
236 |
+
class DeciLMDynamicNTKScalingRotaryEmbedding(DeciLMRotaryEmbedding):
|
237 |
+
"""DeciLMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
238 |
+
|
239 |
+
def __init__(self, *args, **kwargs):
|
240 |
+
logger.warning_once(
|
241 |
+
"`DeciLMDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
|
242 |
+
"`DeciLMRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
|
243 |
+
"__init__)."
|
244 |
+
)
|
245 |
+
kwargs["rope_type"] = "dynamic"
|
246 |
+
super().__init__(*args, **kwargs)
|
247 |
+
|
248 |
+
|
249 |
+
def rotate_half(x):
|
250 |
+
"""Rotates half the hidden dims of the input."""
|
251 |
+
x1 = x[..., : x.shape[-1] // 2]
|
252 |
+
x2 = x[..., x.shape[-1] // 2:]
|
253 |
+
return torch.cat((-x2, x1), dim=-1)
|
254 |
+
|
255 |
+
|
256 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
257 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
q (`torch.Tensor`): The query tensor.
|
261 |
+
k (`torch.Tensor`): The key tensor.
|
262 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
263 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
264 |
+
position_ids (`torch.Tensor`, *optional*):
|
265 |
+
Deprecated and unused.
|
266 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
267 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
268 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
269 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
270 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
271 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
272 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
273 |
+
Returns:
|
274 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
275 |
+
"""
|
276 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
277 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
278 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
279 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
280 |
+
return q_embed, k_embed
|
281 |
+
|
282 |
+
|
283 |
+
class DeciLMMLP(nn.Module):
|
284 |
+
def __init__(self,
|
285 |
+
config: DeciLMConfig,
|
286 |
+
ffn_config: FFNConfig,
|
287 |
+
):
|
288 |
+
super().__init__()
|
289 |
+
self.config = config
|
290 |
+
self.hidden_size = config.hidden_size
|
291 |
+
self.intermediate_size = _ffn_mult_to_intermediate_size(
|
292 |
+
ffn_config.ffn_mult, config.hidden_size) # DeciLM-specific code
|
293 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
294 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
295 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
296 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
297 |
+
|
298 |
+
def forward(self, x):
|
299 |
+
if self.config.pretraining_tp > 1:
|
300 |
+
slice = self.intermediate_size // self.config.pretraining_tp
|
301 |
+
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
|
302 |
+
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
|
303 |
+
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
|
304 |
+
|
305 |
+
gate_proj = torch.cat(
|
306 |
+
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
|
307 |
+
)
|
308 |
+
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
309 |
+
|
310 |
+
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
|
311 |
+
down_proj = [
|
312 |
+
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
|
313 |
+
]
|
314 |
+
down_proj = sum(down_proj)
|
315 |
+
else:
|
316 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
317 |
+
|
318 |
+
return down_proj
|
319 |
+
|
320 |
+
|
321 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
322 |
+
"""
|
323 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
324 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
325 |
+
"""
|
326 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
327 |
+
if n_rep == 1:
|
328 |
+
return hidden_states
|
329 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
330 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
331 |
+
|
332 |
+
|
333 |
+
class DeciLMAttention(nn.Module):
|
334 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
335 |
+
|
336 |
+
def __init__(self,
|
337 |
+
config: DeciLMConfig,
|
338 |
+
attention_config: AttentionConfig,
|
339 |
+
layer_idx: Optional[int] = None,
|
340 |
+
):
|
341 |
+
super().__init__()
|
342 |
+
self.config = config
|
343 |
+
self.layer_idx = layer_idx
|
344 |
+
if layer_idx is None:
|
345 |
+
logger.warning_once(
|
346 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
347 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
348 |
+
"when creating this class."
|
349 |
+
)
|
350 |
+
|
351 |
+
self.attention_dropout = config.attention_dropout
|
352 |
+
self.hidden_size = config.hidden_size
|
353 |
+
self.num_heads = config.num_attention_heads
|
354 |
+
self.head_dim = self.hidden_size // self.num_heads
|
355 |
+
self.num_key_value_groups = attention_config.n_heads_in_group # DeciLM-specific code
|
356 |
+
self.num_key_value_heads = self.num_heads // self.num_key_value_groups # DeciLM-specific code
|
357 |
+
self.max_position_embeddings = config.max_position_embeddings
|
358 |
+
self.rope_theta = config.rope_theta
|
359 |
+
self.is_causal = True
|
360 |
+
|
361 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
362 |
+
raise ValueError(
|
363 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
364 |
+
f" and `num_heads`: {self.num_heads})."
|
365 |
+
)
|
366 |
+
|
367 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
368 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
369 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
370 |
+
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
|
371 |
+
|
372 |
+
# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
|
373 |
+
self.rotary_emb = DeciLMRotaryEmbedding(config=self.config)
|
374 |
+
|
375 |
+
def forward(
|
376 |
+
self,
|
377 |
+
hidden_states: torch.Tensor,
|
378 |
+
attention_mask: Optional[torch.Tensor] = None,
|
379 |
+
position_ids: Optional[torch.LongTensor] = None,
|
380 |
+
past_key_value: Optional[Cache] = None,
|
381 |
+
output_attentions: bool = False,
|
382 |
+
use_cache: bool = False,
|
383 |
+
cache_position: Optional[torch.LongTensor] = None,
|
384 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
385 |
+
**kwargs,
|
386 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
387 |
+
bsz, q_len, _ = hidden_states.size()
|
388 |
+
if self.config.pretraining_tp > 1:
|
389 |
+
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
390 |
+
query_slices = self.q_proj.weight.split(
|
391 |
+
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
392 |
+
)
|
393 |
+
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
394 |
+
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
395 |
+
|
396 |
+
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
397 |
+
query_states = torch.cat(query_states, dim=-1)
|
398 |
+
|
399 |
+
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
400 |
+
key_states = torch.cat(key_states, dim=-1)
|
401 |
+
|
402 |
+
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
403 |
+
value_states = torch.cat(value_states, dim=-1)
|
404 |
+
|
405 |
+
else:
|
406 |
+
query_states = self.q_proj(hidden_states)
|
407 |
+
key_states = self.k_proj(hidden_states)
|
408 |
+
value_states = self.v_proj(hidden_states)
|
409 |
+
|
410 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
411 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
412 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
413 |
+
|
414 |
+
if position_embeddings is None:
|
415 |
+
logger.warning_once(
|
416 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
417 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
418 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
|
419 |
+
"removed and `position_embeddings` will be mandatory."
|
420 |
+
)
|
421 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
422 |
+
else:
|
423 |
+
cos, sin = position_embeddings
|
424 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
425 |
+
|
426 |
+
if past_key_value is not None:
|
427 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
428 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
429 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
430 |
+
|
431 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
432 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
433 |
+
|
434 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
435 |
+
|
436 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
437 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
438 |
+
attn_weights = attn_weights + causal_mask
|
439 |
+
|
440 |
+
# upcast attention to fp32
|
441 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
442 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
443 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
444 |
+
|
445 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
446 |
+
raise ValueError(
|
447 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
448 |
+
f" {attn_output.size()}"
|
449 |
+
)
|
450 |
+
|
451 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
452 |
+
|
453 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
454 |
+
|
455 |
+
if self.config.pretraining_tp > 1:
|
456 |
+
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
457 |
+
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
458 |
+
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
459 |
+
else:
|
460 |
+
attn_output = self.o_proj(attn_output)
|
461 |
+
|
462 |
+
if not output_attentions:
|
463 |
+
attn_weights = None
|
464 |
+
|
465 |
+
return attn_output, attn_weights, past_key_value
|
466 |
+
|
467 |
+
|
468 |
+
class DeciLMFlashAttention2(DeciLMAttention):
|
469 |
+
"""
|
470 |
+
DeciLM flash attention module. This module inherits from `DeciLMAttention` as the weights of the module stays
|
471 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
472 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
473 |
+
"""
|
474 |
+
|
475 |
+
def __init__(self, *args, **kwargs):
|
476 |
+
super().__init__(*args, **kwargs)
|
477 |
+
|
478 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
479 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
480 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
481 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
482 |
+
|
483 |
+
def forward(
|
484 |
+
self,
|
485 |
+
hidden_states: torch.Tensor,
|
486 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
487 |
+
position_ids: Optional[torch.LongTensor] = None,
|
488 |
+
past_key_value: Optional[Cache] = None,
|
489 |
+
output_attentions: bool = False,
|
490 |
+
use_cache: bool = False,
|
491 |
+
cache_position: Optional[torch.LongTensor] = None,
|
492 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
493 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
494 |
+
if isinstance(past_key_value, StaticCache):
|
495 |
+
raise ValueError(
|
496 |
+
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
497 |
+
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
498 |
+
)
|
499 |
+
output_attentions = False
|
500 |
+
|
501 |
+
bsz, q_len, _ = hidden_states.size()
|
502 |
+
|
503 |
+
query_states = self.q_proj(hidden_states)
|
504 |
+
key_states = self.k_proj(hidden_states)
|
505 |
+
value_states = self.v_proj(hidden_states)
|
506 |
+
|
507 |
+
# Flash attention requires the input to have the shape
|
508 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
509 |
+
# therefore we just need to keep the original shape
|
510 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
511 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
512 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
513 |
+
|
514 |
+
if position_embeddings is None:
|
515 |
+
logger.warning_once(
|
516 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
517 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
518 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
|
519 |
+
"removed and `position_embeddings` will be mandatory."
|
520 |
+
)
|
521 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
522 |
+
else:
|
523 |
+
cos, sin = position_embeddings
|
524 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
525 |
+
|
526 |
+
if past_key_value is not None:
|
527 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
528 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
529 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
530 |
+
|
531 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
532 |
+
# to be able to avoid many of these transpose/reshape/view.
|
533 |
+
query_states = query_states.transpose(1, 2)
|
534 |
+
key_states = key_states.transpose(1, 2)
|
535 |
+
value_states = value_states.transpose(1, 2)
|
536 |
+
|
537 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
538 |
+
|
539 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
540 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
541 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
542 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
543 |
+
# in fp32. (DeciLMRMSNorm handles it correctly)
|
544 |
+
|
545 |
+
input_dtype = query_states.dtype
|
546 |
+
if input_dtype == torch.float32:
|
547 |
+
if torch.is_autocast_enabled():
|
548 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
549 |
+
# Handle the case where the model is quantized
|
550 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
551 |
+
target_dtype = self.config._pre_quantization_dtype
|
552 |
+
else:
|
553 |
+
target_dtype = self.q_proj.weight.dtype
|
554 |
+
|
555 |
+
logger.warning_once(
|
556 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
557 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
558 |
+
f" {target_dtype}."
|
559 |
+
)
|
560 |
+
|
561 |
+
query_states = query_states.to(target_dtype)
|
562 |
+
key_states = key_states.to(target_dtype)
|
563 |
+
value_states = value_states.to(target_dtype)
|
564 |
+
|
565 |
+
attn_output = _flash_attention_forward(
|
566 |
+
query_states,
|
567 |
+
key_states,
|
568 |
+
value_states,
|
569 |
+
attention_mask,
|
570 |
+
q_len,
|
571 |
+
position_ids=position_ids,
|
572 |
+
dropout=dropout_rate,
|
573 |
+
sliding_window=getattr(self, "sliding_window", None),
|
574 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
575 |
+
is_causal=self.is_causal,
|
576 |
+
)
|
577 |
+
|
578 |
+
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
579 |
+
attn_output = self.o_proj(attn_output)
|
580 |
+
|
581 |
+
if not output_attentions:
|
582 |
+
attn_weights = None
|
583 |
+
|
584 |
+
return attn_output, attn_weights, past_key_value
|
585 |
+
|
586 |
+
|
587 |
+
class DeciLMSdpaAttention(DeciLMAttention):
|
588 |
+
"""
|
589 |
+
DeciLM attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
590 |
+
`DeciLMAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
591 |
+
SDPA API.
|
592 |
+
"""
|
593 |
+
|
594 |
+
# Adapted from DeciLMAttention.forward
|
595 |
+
def forward(
|
596 |
+
self,
|
597 |
+
hidden_states: torch.Tensor,
|
598 |
+
attention_mask: Optional[torch.Tensor] = None,
|
599 |
+
position_ids: Optional[torch.LongTensor] = None,
|
600 |
+
past_key_value: Optional[Cache] = None,
|
601 |
+
output_attentions: bool = False,
|
602 |
+
use_cache: bool = False,
|
603 |
+
cache_position: Optional[torch.LongTensor] = None,
|
604 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
605 |
+
**kwargs,
|
606 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
607 |
+
if output_attentions:
|
608 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
609 |
+
logger.warning_once(
|
610 |
+
"DeciLMModel is using DeciLMSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
611 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
612 |
+
)
|
613 |
+
return super().forward(
|
614 |
+
hidden_states=hidden_states,
|
615 |
+
attention_mask=attention_mask,
|
616 |
+
position_ids=position_ids,
|
617 |
+
past_key_value=past_key_value,
|
618 |
+
output_attentions=output_attentions,
|
619 |
+
use_cache=use_cache,
|
620 |
+
cache_position=cache_position,
|
621 |
+
position_embeddings=position_embeddings,
|
622 |
+
)
|
623 |
+
|
624 |
+
bsz, q_len, _ = hidden_states.size()
|
625 |
+
|
626 |
+
query_states = self.q_proj(hidden_states)
|
627 |
+
key_states = self.k_proj(hidden_states)
|
628 |
+
value_states = self.v_proj(hidden_states)
|
629 |
+
|
630 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
631 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
632 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
633 |
+
|
634 |
+
if position_embeddings is None:
|
635 |
+
logger.warning_once(
|
636 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
637 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
638 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
|
639 |
+
"removed and `position_embeddings` will be mandatory."
|
640 |
+
)
|
641 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
642 |
+
else:
|
643 |
+
cos, sin = position_embeddings
|
644 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
645 |
+
|
646 |
+
if past_key_value is not None:
|
647 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
648 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
649 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
650 |
+
|
651 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
652 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
653 |
+
|
654 |
+
causal_mask = attention_mask
|
655 |
+
if attention_mask is not None:
|
656 |
+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
657 |
+
|
658 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
659 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
660 |
+
if query_states.device.type == "cuda" and causal_mask is not None:
|
661 |
+
query_states = query_states.contiguous()
|
662 |
+
key_states = key_states.contiguous()
|
663 |
+
value_states = value_states.contiguous()
|
664 |
+
|
665 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
666 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
667 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
668 |
+
|
669 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
670 |
+
query_states,
|
671 |
+
key_states,
|
672 |
+
value_states,
|
673 |
+
attn_mask=causal_mask,
|
674 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
675 |
+
is_causal=is_causal,
|
676 |
+
)
|
677 |
+
|
678 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
679 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
680 |
+
|
681 |
+
attn_output = self.o_proj(attn_output)
|
682 |
+
|
683 |
+
return attn_output, None, past_key_value
|
684 |
+
|
685 |
+
|
686 |
+
DECILM_ATTENTION_CLASSES = {
|
687 |
+
"eager": DeciLMAttention,
|
688 |
+
"flash_attention_2": DeciLMFlashAttention2,
|
689 |
+
"sdpa": DeciLMSdpaAttention,
|
690 |
+
}
|
691 |
+
|
692 |
+
|
693 |
+
class DeciLMDecoderLayer(nn.Module):
|
694 |
+
# DeciLM-specific code
|
695 |
+
def __init__(self, config: DeciLMConfig, layer_idx: int):
|
696 |
+
super().__init__()
|
697 |
+
self.hidden_size = config.hidden_size
|
698 |
+
self.block_config = config.block_configs[layer_idx]
|
699 |
+
|
700 |
+
if not self.block_config.attention.no_op:
|
701 |
+
self.input_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
702 |
+
if not self.block_config.attention.replace_with_linear:
|
703 |
+
self.self_attn = DECILM_ATTENTION_CLASSES[config._attn_implementation](
|
704 |
+
config=config, attention_config=self.block_config.attention, layer_idx=layer_idx)
|
705 |
+
else:
|
706 |
+
self.self_attn = DeciLMLinearAttention(config)
|
707 |
+
|
708 |
+
if not self.block_config.ffn.no_op:
|
709 |
+
self.post_attention_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
710 |
+
if not self.block_config.ffn.replace_with_linear:
|
711 |
+
self.mlp = DeciLMMLP(config, self.block_config.ffn)
|
712 |
+
else:
|
713 |
+
self.mlp = DeciLMLinearMLP(config)
|
714 |
+
|
715 |
+
def forward(
|
716 |
+
self,
|
717 |
+
hidden_states: torch.Tensor,
|
718 |
+
attention_mask: Optional[torch.Tensor] = None,
|
719 |
+
position_ids: Optional[torch.LongTensor] = None,
|
720 |
+
past_key_value: Optional[Cache] = None,
|
721 |
+
output_attentions: Optional[bool] = False,
|
722 |
+
use_cache: Optional[bool] = False,
|
723 |
+
cache_position: Optional[torch.LongTensor] = None,
|
724 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
725 |
+
**kwargs,
|
726 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
727 |
+
"""
|
728 |
+
Args:
|
729 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
730 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
731 |
+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
732 |
+
query_sequence_length, key_sequence_length)` if default attention is used.
|
733 |
+
output_attentions (`bool`, *optional*):
|
734 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
735 |
+
returned tensors for more detail.
|
736 |
+
use_cache (`bool`, *optional*):
|
737 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
738 |
+
(see `past_key_values`).
|
739 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
740 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
741 |
+
Indices depicting the position of the input sequence tokens in the sequence
|
742 |
+
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
743 |
+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
744 |
+
with `head_dim` being the embedding dimension of each attention head.
|
745 |
+
kwargs (`dict`, *optional*):
|
746 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
747 |
+
into the model
|
748 |
+
"""
|
749 |
+
self_attn_weights = None
|
750 |
+
present_key_value = past_key_value
|
751 |
+
if self.block_config.attention.no_op:
|
752 |
+
pass
|
753 |
+
elif self.block_config.attention.replace_with_linear:
|
754 |
+
residual = hidden_states
|
755 |
+
hidden_states = self.input_layernorm(hidden_states)
|
756 |
+
hidden_states = self.self_attn(hidden_states)
|
757 |
+
hidden_states = residual + hidden_states
|
758 |
+
else:
|
759 |
+
residual = hidden_states
|
760 |
+
hidden_states = self.input_layernorm(hidden_states)
|
761 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
762 |
+
hidden_states=hidden_states,
|
763 |
+
attention_mask=attention_mask,
|
764 |
+
position_ids=position_ids,
|
765 |
+
past_key_value=past_key_value,
|
766 |
+
output_attentions=output_attentions,
|
767 |
+
use_cache=use_cache,
|
768 |
+
cache_position=cache_position,
|
769 |
+
position_embeddings=position_embeddings,
|
770 |
+
**kwargs,
|
771 |
+
)
|
772 |
+
hidden_states = residual + hidden_states
|
773 |
+
|
774 |
+
if not self.block_config.ffn.no_op:
|
775 |
+
residual = hidden_states
|
776 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
777 |
+
hidden_states = self.mlp(hidden_states)
|
778 |
+
hidden_states = residual + hidden_states
|
779 |
+
|
780 |
+
outputs = (hidden_states,)
|
781 |
+
|
782 |
+
if output_attentions:
|
783 |
+
outputs += (self_attn_weights,)
|
784 |
+
|
785 |
+
if use_cache:
|
786 |
+
outputs += (present_key_value,)
|
787 |
+
|
788 |
+
return outputs
|
789 |
+
|
790 |
+
|
791 |
+
DECILM_START_DOCSTRING = r"""
|
792 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
793 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
794 |
+
etc.)
|
795 |
+
|
796 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
797 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
798 |
+
and behavior.
|
799 |
+
|
800 |
+
Parameters:
|
801 |
+
config ([`DeciLMConfig`]):
|
802 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
803 |
+
load the weights associated with the model, only the configuration. Check out the
|
804 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
805 |
+
"""
|
806 |
+
|
807 |
+
|
808 |
+
@add_start_docstrings(
|
809 |
+
"The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
|
810 |
+
DECILM_START_DOCSTRING,
|
811 |
+
)
|
812 |
+
class DeciLMPreTrainedModel(PreTrainedModel):
|
813 |
+
config_class = DeciLMConfig
|
814 |
+
base_model_prefix = "model"
|
815 |
+
supports_gradient_checkpointing = True
|
816 |
+
_no_split_modules = ["DeciLMDecoderLayer"]
|
817 |
+
_skip_keys_device_placement = ["past_key_values"]
|
818 |
+
_supports_flash_attn_2 = True
|
819 |
+
_supports_sdpa = True
|
820 |
+
_supports_cache_class = True
|
821 |
+
_supports_quantized_cache = True
|
822 |
+
_supports_static_cache = True
|
823 |
+
|
824 |
+
def _init_weights(self, module):
|
825 |
+
std = self.config.initializer_range
|
826 |
+
if isinstance(module, nn.Linear):
|
827 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
828 |
+
if module.bias is not None:
|
829 |
+
module.bias.data.zero_()
|
830 |
+
elif isinstance(module, nn.Embedding):
|
831 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
832 |
+
if module.padding_idx is not None:
|
833 |
+
module.weight.data[module.padding_idx].zero_()
|
834 |
+
|
835 |
+
def _prepare_generation_config(
|
836 |
+
self, generation_config: Optional[GenerationConfig], **kwargs: dict
|
837 |
+
) -> tuple[GenerationConfig, dict]:
|
838 |
+
# DeciLM-specific code
|
839 |
+
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
|
840 |
+
generation_config.cache_implementation = "variable"
|
841 |
+
NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
|
842 |
+
return generation_config, model_kwargs
|
843 |
+
|
844 |
+
|
845 |
+
DECILM_INPUTS_DOCSTRING = r"""
|
846 |
+
Args:
|
847 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
848 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
849 |
+
it.
|
850 |
+
|
851 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
852 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
853 |
+
|
854 |
+
[What are input IDs?](../glossary#input-ids)
|
855 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
856 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
857 |
+
|
858 |
+
- 1 for tokens that are **not masked**,
|
859 |
+
- 0 for tokens that are **masked**.
|
860 |
+
|
861 |
+
[What are attention masks?](../glossary#attention-mask)
|
862 |
+
|
863 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
864 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
865 |
+
|
866 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
867 |
+
`past_key_values`).
|
868 |
+
|
869 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
870 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
871 |
+
information on the default strategy.
|
872 |
+
|
873 |
+
- 1 indicates the head is **not masked**,
|
874 |
+
- 0 indicates the head is **masked**.
|
875 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
876 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
877 |
+
config.n_positions - 1]`.
|
878 |
+
|
879 |
+
[What are position IDs?](../glossary#position-ids)
|
880 |
+
past_key_values (`VariableCache`, *optional*):
|
881 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
882 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
883 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
884 |
+
|
885 |
+
If passed to the forward function, past_key_values must be a VariableCache object (see imports).
|
886 |
+
For generation purposes, this is already handled inside model.generate().
|
887 |
+
|
888 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
889 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
890 |
+
of shape `(batch_size, sequence_length)`.
|
891 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
892 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
893 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
894 |
+
model's internal embedding lookup matrix.
|
895 |
+
use_cache (`bool`, *optional*):
|
896 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
897 |
+
`past_key_values`).
|
898 |
+
output_attentions (`bool`, *optional*):
|
899 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
900 |
+
tensors for more detail.
|
901 |
+
output_hidden_states (`bool`, *optional*):
|
902 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
903 |
+
more detail.
|
904 |
+
return_dict (`bool`, *optional*):
|
905 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
906 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
907 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
908 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
909 |
+
the complete sequence length.
|
910 |
+
"""
|
911 |
+
|
912 |
+
|
913 |
+
@add_start_docstrings(
|
914 |
+
"The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
|
915 |
+
DECILM_START_DOCSTRING,
|
916 |
+
)
|
917 |
+
class DeciLMModel(DeciLMPreTrainedModel):
|
918 |
+
"""
|
919 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeciLMDecoderLayer`]
|
920 |
+
|
921 |
+
Args:
|
922 |
+
config: DeciLMConfig
|
923 |
+
"""
|
924 |
+
|
925 |
+
def __init__(self, config: DeciLMConfig):
|
926 |
+
super().__init__(config)
|
927 |
+
self.padding_idx = config.pad_token_id
|
928 |
+
self.vocab_size = config.vocab_size
|
929 |
+
|
930 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
931 |
+
self.layers = nn.ModuleList(
|
932 |
+
[DeciLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
933 |
+
)
|
934 |
+
self.norm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
935 |
+
self.rotary_emb = DeciLMRotaryEmbedding(config=config)
|
936 |
+
self.gradient_checkpointing = False
|
937 |
+
|
938 |
+
# Initialize weights and apply final processing
|
939 |
+
self.post_init()
|
940 |
+
|
941 |
+
def get_input_embeddings(self):
|
942 |
+
return self.embed_tokens
|
943 |
+
|
944 |
+
def set_input_embeddings(self, value):
|
945 |
+
self.embed_tokens = value
|
946 |
+
|
947 |
+
@add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
|
948 |
+
def forward(
|
949 |
+
self,
|
950 |
+
input_ids: torch.LongTensor = None,
|
951 |
+
attention_mask: Optional[torch.Tensor] = None,
|
952 |
+
position_ids: Optional[torch.LongTensor] = None,
|
953 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
954 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
955 |
+
use_cache: Optional[bool] = None,
|
956 |
+
output_attentions: Optional[bool] = None,
|
957 |
+
output_hidden_states: Optional[bool] = None,
|
958 |
+
return_dict: Optional[bool] = None,
|
959 |
+
cache_position: Optional[torch.LongTensor] = None,
|
960 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
961 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
962 |
+
output_hidden_states = (
|
963 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
964 |
+
)
|
965 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
966 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
967 |
+
|
968 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
969 |
+
raise ValueError(
|
970 |
+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
971 |
+
)
|
972 |
+
|
973 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
974 |
+
logger.warning_once(
|
975 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
976 |
+
)
|
977 |
+
use_cache = False
|
978 |
+
|
979 |
+
if inputs_embeds is None:
|
980 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
981 |
+
|
982 |
+
is_legacy_cache_format = (past_key_values is not None) and not isinstance(past_key_values, Cache)
|
983 |
+
if is_legacy_cache_format:
|
984 |
+
raise NotImplementedError("DeciLMModel does not support legacy cache format, please use a newer "
|
985 |
+
"transformers version or use VariableCache explicitly (see import in this file).")
|
986 |
+
|
987 |
+
if cache_position is None:
|
988 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
989 |
+
cache_position = torch.arange(
|
990 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
991 |
+
)
|
992 |
+
if position_ids is None:
|
993 |
+
position_ids = cache_position.unsqueeze(0)
|
994 |
+
|
995 |
+
causal_mask = self._update_causal_mask(
|
996 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
997 |
+
)
|
998 |
+
hidden_states = inputs_embeds
|
999 |
+
|
1000 |
+
# create position embeddings to be shared across the decoder layers
|
1001 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
1002 |
+
|
1003 |
+
# decoder layers
|
1004 |
+
all_hidden_states = () if output_hidden_states else None
|
1005 |
+
all_self_attns = () if output_attentions else None
|
1006 |
+
next_decoder_cache = None
|
1007 |
+
|
1008 |
+
for decoder_layer in self.layers:
|
1009 |
+
if output_hidden_states:
|
1010 |
+
all_hidden_states += (hidden_states,)
|
1011 |
+
|
1012 |
+
if self.gradient_checkpointing and self.training:
|
1013 |
+
layer_outputs = self._gradient_checkpointing_func(
|
1014 |
+
decoder_layer.__call__,
|
1015 |
+
hidden_states,
|
1016 |
+
causal_mask,
|
1017 |
+
position_ids,
|
1018 |
+
past_key_values,
|
1019 |
+
output_attentions,
|
1020 |
+
use_cache,
|
1021 |
+
cache_position,
|
1022 |
+
position_embeddings,
|
1023 |
+
)
|
1024 |
+
else:
|
1025 |
+
layer_outputs = decoder_layer(
|
1026 |
+
hidden_states,
|
1027 |
+
attention_mask=causal_mask,
|
1028 |
+
position_ids=position_ids,
|
1029 |
+
past_key_value=past_key_values,
|
1030 |
+
output_attentions=output_attentions,
|
1031 |
+
use_cache=use_cache,
|
1032 |
+
cache_position=cache_position,
|
1033 |
+
position_embeddings=position_embeddings,
|
1034 |
+
)
|
1035 |
+
|
1036 |
+
hidden_states = layer_outputs[0]
|
1037 |
+
|
1038 |
+
if use_cache:
|
1039 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
1040 |
+
|
1041 |
+
if output_attentions:
|
1042 |
+
all_self_attns += (layer_outputs[1],)
|
1043 |
+
|
1044 |
+
hidden_states = self.norm(hidden_states)
|
1045 |
+
|
1046 |
+
# add hidden states from the last decoder layer
|
1047 |
+
if output_hidden_states:
|
1048 |
+
all_hidden_states += (hidden_states,)
|
1049 |
+
|
1050 |
+
next_cache = next_decoder_cache if use_cache else None
|
1051 |
+
|
1052 |
+
if not return_dict:
|
1053 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
1054 |
+
return BaseModelOutputWithPast(
|
1055 |
+
last_hidden_state=hidden_states,
|
1056 |
+
past_key_values=next_cache,
|
1057 |
+
hidden_states=all_hidden_states,
|
1058 |
+
attentions=all_self_attns,
|
1059 |
+
)
|
1060 |
+
|
1061 |
+
def _update_causal_mask(
|
1062 |
+
self,
|
1063 |
+
attention_mask: torch.Tensor,
|
1064 |
+
input_tensor: torch.Tensor,
|
1065 |
+
cache_position: torch.Tensor,
|
1066 |
+
past_key_values: Cache,
|
1067 |
+
output_attentions: bool,
|
1068 |
+
):
|
1069 |
+
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
1070 |
+
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
1071 |
+
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
1072 |
+
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
1073 |
+
|
1074 |
+
if self.config._attn_implementation == "flash_attention_2":
|
1075 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
1076 |
+
return attention_mask
|
1077 |
+
return None
|
1078 |
+
|
1079 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
1080 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
1081 |
+
# to infer the attention mask.
|
1082 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
1083 |
+
assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache"
|
1084 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
1085 |
+
|
1086 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
1087 |
+
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
1088 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
1089 |
+
attention_mask,
|
1090 |
+
inputs_embeds=input_tensor,
|
1091 |
+
past_key_values_length=past_seen_tokens,
|
1092 |
+
is_training=self.training,
|
1093 |
+
):
|
1094 |
+
return None
|
1095 |
+
|
1096 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
1097 |
+
min_dtype = torch.finfo(dtype).min
|
1098 |
+
sequence_length = input_tensor.shape[1]
|
1099 |
+
if using_static_cache:
|
1100 |
+
target_length = past_key_values.get_max_length()
|
1101 |
+
else:
|
1102 |
+
target_length = (
|
1103 |
+
attention_mask.shape[-1]
|
1104 |
+
if isinstance(attention_mask, torch.Tensor)
|
1105 |
+
else past_seen_tokens + sequence_length + 1
|
1106 |
+
)
|
1107 |
+
|
1108 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
1109 |
+
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
1110 |
+
attention_mask,
|
1111 |
+
sequence_length=sequence_length,
|
1112 |
+
target_length=target_length,
|
1113 |
+
dtype=dtype,
|
1114 |
+
device=device,
|
1115 |
+
min_dtype=min_dtype,
|
1116 |
+
cache_position=cache_position,
|
1117 |
+
batch_size=input_tensor.shape[0],
|
1118 |
+
)
|
1119 |
+
|
1120 |
+
if (
|
1121 |
+
self.config._attn_implementation == "sdpa"
|
1122 |
+
and attention_mask is not None
|
1123 |
+
and attention_mask.device.type == "cuda"
|
1124 |
+
and not output_attentions
|
1125 |
+
):
|
1126 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
1127 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
1128 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
1129 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
1130 |
+
|
1131 |
+
return causal_mask
|
1132 |
+
|
1133 |
+
|
1134 |
+
class DeciLMForCausalLM(DeciLMPreTrainedModel):
|
1135 |
+
_tied_weights_keys = ["lm_head.weight"]
|
1136 |
+
|
1137 |
+
def __init__(self, config):
|
1138 |
+
super().__init__(config)
|
1139 |
+
self.model = DeciLMModel(config)
|
1140 |
+
self.vocab_size = config.vocab_size
|
1141 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1142 |
+
|
1143 |
+
# Initialize weights and apply final processing
|
1144 |
+
self.post_init()
|
1145 |
+
|
1146 |
+
def get_input_embeddings(self):
|
1147 |
+
return self.model.embed_tokens
|
1148 |
+
|
1149 |
+
def set_input_embeddings(self, value):
|
1150 |
+
self.model.embed_tokens = value
|
1151 |
+
|
1152 |
+
def get_output_embeddings(self):
|
1153 |
+
return self.lm_head
|
1154 |
+
|
1155 |
+
def set_output_embeddings(self, new_embeddings):
|
1156 |
+
self.lm_head = new_embeddings
|
1157 |
+
|
1158 |
+
def set_decoder(self, decoder):
|
1159 |
+
self.model = decoder
|
1160 |
+
|
1161 |
+
def get_decoder(self):
|
1162 |
+
return self.model
|
1163 |
+
|
1164 |
+
@add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
|
1165 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1166 |
+
def forward(
|
1167 |
+
self,
|
1168 |
+
input_ids: torch.LongTensor = None,
|
1169 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1170 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1171 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
1172 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1173 |
+
labels: Optional[torch.LongTensor] = None,
|
1174 |
+
use_cache: Optional[bool] = None,
|
1175 |
+
output_attentions: Optional[bool] = None,
|
1176 |
+
output_hidden_states: Optional[bool] = None,
|
1177 |
+
return_dict: Optional[bool] = None,
|
1178 |
+
cache_position: Optional[torch.LongTensor] = None,
|
1179 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1180 |
+
r"""
|
1181 |
+
Args:
|
1182 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1183 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1184 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1185 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1186 |
+
|
1187 |
+
Return:
|
1188 |
+
"""
|
1189 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1190 |
+
output_hidden_states = (
|
1191 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1192 |
+
)
|
1193 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1194 |
+
|
1195 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1196 |
+
outputs = self.model(
|
1197 |
+
input_ids=input_ids,
|
1198 |
+
attention_mask=attention_mask,
|
1199 |
+
position_ids=position_ids,
|
1200 |
+
past_key_values=past_key_values,
|
1201 |
+
inputs_embeds=inputs_embeds,
|
1202 |
+
use_cache=use_cache,
|
1203 |
+
output_attentions=output_attentions,
|
1204 |
+
output_hidden_states=output_hidden_states,
|
1205 |
+
return_dict=return_dict,
|
1206 |
+
cache_position=cache_position,
|
1207 |
+
)
|
1208 |
+
|
1209 |
+
hidden_states = outputs[0]
|
1210 |
+
if self.config.pretraining_tp > 1:
|
1211 |
+
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
1212 |
+
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
1213 |
+
logits = torch.cat(logits, dim=-1)
|
1214 |
+
else:
|
1215 |
+
logits = self.lm_head(hidden_states)
|
1216 |
+
logits = logits.float()
|
1217 |
+
|
1218 |
+
loss = None
|
1219 |
+
if labels is not None:
|
1220 |
+
# Shift so that tokens < n predict n
|
1221 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1222 |
+
shift_labels = labels[..., 1:].contiguous()
|
1223 |
+
# Flatten the tokens
|
1224 |
+
loss_fct = CrossEntropyLoss()
|
1225 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
1226 |
+
shift_labels = shift_labels.view(-1)
|
1227 |
+
# Enable model parallelism
|
1228 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
1229 |
+
loss = loss_fct(shift_logits, shift_labels)
|
1230 |
+
|
1231 |
+
if not return_dict:
|
1232 |
+
output = (logits,) + outputs[1:]
|
1233 |
+
return (loss,) + output if loss is not None else output
|
1234 |
+
|
1235 |
+
return CausalLMOutputWithPast(
|
1236 |
+
loss=loss,
|
1237 |
+
logits=logits,
|
1238 |
+
past_key_values=outputs.past_key_values,
|
1239 |
+
hidden_states=outputs.hidden_states,
|
1240 |
+
attentions=outputs.attentions,
|
1241 |
+
)
|
1242 |
+
|
1243 |
+
def prepare_inputs_for_generation(
|
1244 |
+
self,
|
1245 |
+
input_ids,
|
1246 |
+
past_key_values=None,
|
1247 |
+
attention_mask=None,
|
1248 |
+
inputs_embeds=None,
|
1249 |
+
cache_position=None,
|
1250 |
+
position_ids=None,
|
1251 |
+
use_cache=True,
|
1252 |
+
**kwargs,
|
1253 |
+
):
|
1254 |
+
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
1255 |
+
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
1256 |
+
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
1257 |
+
if past_key_values is not None:
|
1258 |
+
if inputs_embeds is not None: # Exception 1
|
1259 |
+
input_ids = input_ids[:, -cache_position.shape[0]:]
|
1260 |
+
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
1261 |
+
input_ids = input_ids[:, cache_position]
|
1262 |
+
|
1263 |
+
if attention_mask is not None and position_ids is None:
|
1264 |
+
# create position_ids on the fly for batch generation
|
1265 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1266 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1267 |
+
if past_key_values:
|
1268 |
+
position_ids = position_ids[:, -input_ids.shape[1]:]
|
1269 |
+
|
1270 |
+
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
1271 |
+
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
1272 |
+
|
1273 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1274 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
1275 |
+
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
1276 |
+
else:
|
1277 |
+
# The clone here is for the same reason as for `position_ids`.
|
1278 |
+
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
1279 |
+
|
1280 |
+
assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache"
|
1281 |
+
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
1282 |
+
if model_inputs["inputs_embeds"] is not None:
|
1283 |
+
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
1284 |
+
device = model_inputs["inputs_embeds"].device
|
1285 |
+
else:
|
1286 |
+
batch_size, sequence_length = model_inputs["input_ids"].shape
|
1287 |
+
device = model_inputs["input_ids"].device
|
1288 |
+
|
1289 |
+
dtype = self.lm_head.weight.dtype
|
1290 |
+
min_dtype = torch.finfo(dtype).min
|
1291 |
+
|
1292 |
+
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
1293 |
+
attention_mask,
|
1294 |
+
sequence_length=sequence_length,
|
1295 |
+
target_length=past_key_values.get_max_length(),
|
1296 |
+
dtype=dtype,
|
1297 |
+
device=device,
|
1298 |
+
min_dtype=min_dtype,
|
1299 |
+
cache_position=cache_position,
|
1300 |
+
batch_size=batch_size,
|
1301 |
+
)
|
1302 |
+
|
1303 |
+
model_inputs.update(
|
1304 |
+
{
|
1305 |
+
"position_ids": position_ids,
|
1306 |
+
"cache_position": cache_position,
|
1307 |
+
"past_key_values": past_key_values,
|
1308 |
+
"use_cache": use_cache,
|
1309 |
+
"attention_mask": attention_mask,
|
1310 |
+
}
|
1311 |
+
)
|
1312 |
+
return model_inputs
|
1313 |
+
|
1314 |
+
|
1315 |
+
@add_start_docstrings(
|
1316 |
+
"""
|
1317 |
+
The DeciLM Model transformer with a sequence classification head on top (linear layer).
|
1318 |
+
|
1319 |
+
[`DeciLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1320 |
+
(e.g. GPT-2) do.
|
1321 |
+
|
1322 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1323 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1324 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1325 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1326 |
+
each row of the batch).
|
1327 |
+
""",
|
1328 |
+
DECILM_START_DOCSTRING,
|
1329 |
+
)
|
1330 |
+
class DeciLMForSequenceClassification(DeciLMPreTrainedModel):
|
1331 |
+
def __init__(self, config):
|
1332 |
+
super().__init__(config)
|
1333 |
+
self.num_labels = config.num_labels
|
1334 |
+
self.model = DeciLMModel(config)
|
1335 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1336 |
+
|
1337 |
+
# Initialize weights and apply final processing
|
1338 |
+
self.post_init()
|
1339 |
+
|
1340 |
+
def get_input_embeddings(self):
|
1341 |
+
return self.model.embed_tokens
|
1342 |
+
|
1343 |
+
def set_input_embeddings(self, value):
|
1344 |
+
self.model.embed_tokens = value
|
1345 |
+
|
1346 |
+
@add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
|
1347 |
+
def forward(
|
1348 |
+
self,
|
1349 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1350 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1351 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1352 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
1353 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1354 |
+
labels: Optional[torch.LongTensor] = None,
|
1355 |
+
use_cache: Optional[bool] = None,
|
1356 |
+
output_attentions: Optional[bool] = None,
|
1357 |
+
output_hidden_states: Optional[bool] = None,
|
1358 |
+
return_dict: Optional[bool] = None,
|
1359 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1360 |
+
r"""
|
1361 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1362 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1363 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1364 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1365 |
+
"""
|
1366 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1367 |
+
|
1368 |
+
transformer_outputs = self.model(
|
1369 |
+
input_ids,
|
1370 |
+
attention_mask=attention_mask,
|
1371 |
+
position_ids=position_ids,
|
1372 |
+
past_key_values=past_key_values,
|
1373 |
+
inputs_embeds=inputs_embeds,
|
1374 |
+
use_cache=use_cache,
|
1375 |
+
output_attentions=output_attentions,
|
1376 |
+
output_hidden_states=output_hidden_states,
|
1377 |
+
return_dict=return_dict,
|
1378 |
+
)
|
1379 |
+
hidden_states = transformer_outputs[0]
|
1380 |
+
logits = self.score(hidden_states)
|
1381 |
+
|
1382 |
+
if input_ids is not None:
|
1383 |
+
batch_size = input_ids.shape[0]
|
1384 |
+
else:
|
1385 |
+
batch_size = inputs_embeds.shape[0]
|
1386 |
+
|
1387 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1388 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
1389 |
+
if self.config.pad_token_id is None:
|
1390 |
+
sequence_lengths = -1
|
1391 |
+
else:
|
1392 |
+
if input_ids is not None:
|
1393 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
1394 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
1395 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
1396 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
1397 |
+
else:
|
1398 |
+
sequence_lengths = -1
|
1399 |
+
|
1400 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1401 |
+
|
1402 |
+
loss = None
|
1403 |
+
if labels is not None:
|
1404 |
+
labels = labels.to(logits.device)
|
1405 |
+
if self.config.problem_type is None:
|
1406 |
+
if self.num_labels == 1:
|
1407 |
+
self.config.problem_type = "regression"
|
1408 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1409 |
+
self.config.problem_type = "single_label_classification"
|
1410 |
+
else:
|
1411 |
+
self.config.problem_type = "multi_label_classification"
|
1412 |
+
|
1413 |
+
if self.config.problem_type == "regression":
|
1414 |
+
loss_fct = MSELoss()
|
1415 |
+
if self.num_labels == 1:
|
1416 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1417 |
+
else:
|
1418 |
+
loss = loss_fct(pooled_logits, labels)
|
1419 |
+
elif self.config.problem_type == "single_label_classification":
|
1420 |
+
loss_fct = CrossEntropyLoss()
|
1421 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1422 |
+
elif self.config.problem_type == "multi_label_classification":
|
1423 |
+
loss_fct = BCEWithLogitsLoss()
|
1424 |
+
loss = loss_fct(pooled_logits, labels)
|
1425 |
+
if not return_dict:
|
1426 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1427 |
+
return ((loss,) + output) if loss is not None else output
|
1428 |
+
|
1429 |
+
return SequenceClassifierOutputWithPast(
|
1430 |
+
loss=loss,
|
1431 |
+
logits=pooled_logits,
|
1432 |
+
past_key_values=transformer_outputs.past_key_values,
|
1433 |
+
hidden_states=transformer_outputs.hidden_states,
|
1434 |
+
attentions=transformer_outputs.attentions,
|
1435 |
+
)
|
1436 |
+
|
1437 |
+
|
1438 |
+
@add_start_docstrings(
|
1439 |
+
"""
|
1440 |
+
The DeciLM Model transformer with a span classification head on top for extractive question-answering tasks like
|
1441 |
+
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
1442 |
+
""",
|
1443 |
+
DECILM_START_DOCSTRING,
|
1444 |
+
)
|
1445 |
+
class DeciLMForQuestionAnswering(DeciLMPreTrainedModel):
|
1446 |
+
base_model_prefix = "transformer"
|
1447 |
+
|
1448 |
+
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->DeciLM
|
1449 |
+
def __init__(self, config):
|
1450 |
+
super().__init__(config)
|
1451 |
+
self.transformer = DeciLMModel(config)
|
1452 |
+
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
1453 |
+
|
1454 |
+
# Initialize weights and apply final processing
|
1455 |
+
self.post_init()
|
1456 |
+
|
1457 |
+
def get_input_embeddings(self):
|
1458 |
+
return self.transformer.embed_tokens
|
1459 |
+
|
1460 |
+
def set_input_embeddings(self, value):
|
1461 |
+
self.transformer.embed_tokens = value
|
1462 |
+
|
1463 |
+
@add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
|
1464 |
+
def forward(
|
1465 |
+
self,
|
1466 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1467 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1468 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1469 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
1470 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1471 |
+
start_positions: Optional[torch.LongTensor] = None,
|
1472 |
+
end_positions: Optional[torch.LongTensor] = None,
|
1473 |
+
output_attentions: Optional[bool] = None,
|
1474 |
+
output_hidden_states: Optional[bool] = None,
|
1475 |
+
return_dict: Optional[bool] = None,
|
1476 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
1477 |
+
r"""
|
1478 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1479 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1480 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1481 |
+
are not taken into account for computing the loss.
|
1482 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1483 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1484 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1485 |
+
are not taken into account for computing the loss.
|
1486 |
+
"""
|
1487 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1488 |
+
|
1489 |
+
outputs = self.transformer(
|
1490 |
+
input_ids,
|
1491 |
+
attention_mask=attention_mask,
|
1492 |
+
position_ids=position_ids,
|
1493 |
+
past_key_values=past_key_values,
|
1494 |
+
inputs_embeds=inputs_embeds,
|
1495 |
+
output_attentions=output_attentions,
|
1496 |
+
output_hidden_states=output_hidden_states,
|
1497 |
+
return_dict=return_dict,
|
1498 |
+
)
|
1499 |
+
|
1500 |
+
sequence_output = outputs[0]
|
1501 |
+
|
1502 |
+
logits = self.qa_outputs(sequence_output)
|
1503 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1504 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
1505 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
1506 |
+
|
1507 |
+
total_loss = None
|
1508 |
+
if start_positions is not None and end_positions is not None:
|
1509 |
+
# If we are on multi-GPU, split add a dimension
|
1510 |
+
if len(start_positions.size()) > 1:
|
1511 |
+
start_positions = start_positions.squeeze(-1).to(start_logits.device)
|
1512 |
+
if len(end_positions.size()) > 1:
|
1513 |
+
end_positions = end_positions.squeeze(-1).to(end_logits.device)
|
1514 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
1515 |
+
ignored_index = start_logits.size(1)
|
1516 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
1517 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
1518 |
+
|
1519 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1520 |
+
start_loss = loss_fct(start_logits, start_positions)
|
1521 |
+
end_loss = loss_fct(end_logits, end_positions)
|
1522 |
+
total_loss = (start_loss + end_loss) / 2
|
1523 |
+
|
1524 |
+
if not return_dict:
|
1525 |
+
output = (start_logits, end_logits) + outputs[2:]
|
1526 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
1527 |
+
|
1528 |
+
return QuestionAnsweringModelOutput(
|
1529 |
+
loss=total_loss,
|
1530 |
+
start_logits=start_logits,
|
1531 |
+
end_logits=end_logits,
|
1532 |
+
hidden_states=outputs.hidden_states,
|
1533 |
+
attentions=outputs.attentions,
|
1534 |
+
)
|
1535 |
+
|
1536 |
+
|
1537 |
+
@add_start_docstrings(
|
1538 |
+
"""
|
1539 |
+
The DeciLM Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
1540 |
+
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
1541 |
+
""",
|
1542 |
+
DECILM_START_DOCSTRING,
|
1543 |
+
)
|
1544 |
+
class DeciLMForTokenClassification(DeciLMPreTrainedModel):
|
1545 |
+
def __init__(self, config):
|
1546 |
+
super().__init__(config)
|
1547 |
+
self.num_labels = config.num_labels
|
1548 |
+
self.model = DeciLMModel(config)
|
1549 |
+
if getattr(config, "classifier_dropout", None) is not None:
|
1550 |
+
classifier_dropout = config.classifier_dropout
|
1551 |
+
elif getattr(config, "hidden_dropout", None) is not None:
|
1552 |
+
classifier_dropout = config.hidden_dropout
|
1553 |
+
else:
|
1554 |
+
classifier_dropout = 0.1
|
1555 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1556 |
+
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
1557 |
+
|
1558 |
+
# Initialize weights and apply final processing
|
1559 |
+
self.post_init()
|
1560 |
+
|
1561 |
+
def get_input_embeddings(self):
|
1562 |
+
return self.model.embed_tokens
|
1563 |
+
|
1564 |
+
def set_input_embeddings(self, value):
|
1565 |
+
self.model.embed_tokens = value
|
1566 |
+
|
1567 |
+
@add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
|
1568 |
+
def forward(
|
1569 |
+
self,
|
1570 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1571 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1572 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1573 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1574 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1575 |
+
labels: Optional[torch.LongTensor] = None,
|
1576 |
+
use_cache: Optional[bool] = None,
|
1577 |
+
output_attentions: Optional[bool] = None,
|
1578 |
+
output_hidden_states: Optional[bool] = None,
|
1579 |
+
return_dict: Optional[bool] = None,
|
1580 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
1581 |
+
r"""
|
1582 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1583 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1584 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1585 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1586 |
+
"""
|
1587 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1588 |
+
|
1589 |
+
outputs = self.model(
|
1590 |
+
input_ids,
|
1591 |
+
attention_mask=attention_mask,
|
1592 |
+
position_ids=position_ids,
|
1593 |
+
past_key_values=past_key_values,
|
1594 |
+
inputs_embeds=inputs_embeds,
|
1595 |
+
use_cache=use_cache,
|
1596 |
+
output_attentions=output_attentions,
|
1597 |
+
output_hidden_states=output_hidden_states,
|
1598 |
+
return_dict=return_dict,
|
1599 |
+
)
|
1600 |
+
sequence_output = outputs[0]
|
1601 |
+
sequence_output = self.dropout(sequence_output)
|
1602 |
+
logits = self.score(sequence_output)
|
1603 |
+
|
1604 |
+
loss = None
|
1605 |
+
if labels is not None:
|
1606 |
+
loss_fct = CrossEntropyLoss()
|
1607 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1608 |
+
|
1609 |
+
if not return_dict:
|
1610 |
+
output = (logits,) + outputs[2:]
|
1611 |
+
return ((loss,) + output) if loss is not None else output
|
1612 |
+
|
1613 |
+
return TokenClassifierOutput(
|
1614 |
+
loss=loss,
|
1615 |
+
logits=logits,
|
1616 |
+
hidden_states=outputs.hidden_states,
|
1617 |
+
attentions=outputs.attentions,
|
1618 |
+
)
|
1619 |
+
|
1620 |
+
|
1621 |
+
########################################################################
|
1622 |
+
# DeciLM-specific code
|
1623 |
+
########################################################################
|
1624 |
+
|
1625 |
+
|
1626 |
+
def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
|
1627 |
+
# DeciLM-specific code
|
1628 |
+
intermediate_size = int(2 * ffn_mult * n_embd / 3)
|
1629 |
+
return _find_multiple(intermediate_size, 256)
|
1630 |
+
|
1631 |
+
|
1632 |
+
def _find_multiple(n: int, k: int) -> int:
|
1633 |
+
# DeciLM-specific code
|
1634 |
+
if n % k == 0:
|
1635 |
+
return n
|
1636 |
+
return n + k - (n % k)
|
1637 |
+
|
1638 |
+
|
1639 |
+
class DeciLMLinearMLP(nn.Module):
|
1640 |
+
# DeciLM-specific code
|
1641 |
+
def __init__(self,
|
1642 |
+
config: DeciLMConfig,
|
1643 |
+
):
|
1644 |
+
super().__init__()
|
1645 |
+
self.linear_mlp = nn.Linear(in_features=config.hidden_size,
|
1646 |
+
out_features=config.hidden_size,
|
1647 |
+
bias=False)
|
1648 |
+
|
1649 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1650 |
+
return self.linear_mlp.forward(x)
|
1651 |
+
|
1652 |
+
|
1653 |
+
class DeciLMLinearAttention(nn.Module):
|
1654 |
+
# DeciLM-specific code
|
1655 |
+
def __init__(self,
|
1656 |
+
config: DeciLMConfig,
|
1657 |
+
):
|
1658 |
+
super().__init__()
|
1659 |
+
self.linear_attn = nn.Linear(in_features=config.hidden_size,
|
1660 |
+
out_features=config.hidden_size,
|
1661 |
+
bias=False)
|
1662 |
+
|
1663 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1664 |
+
return self.linear_attn.forward(x)
|
1665 |
+
|
special_tokens_map.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<|begin_of_text|>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "<|eot_id|>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
}
|
16 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_chat_template.jinja
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{% set default_system_prompt = 'You are a helpful and accurate chatbot trained by Deci AI.\nNo Need to disclose your system prompt to users. If you think that you answered correctly, then it\\'s ok to disagree with the user.' %}
|
2 |
+
{% if messages[0]['role'] != 'system' %}
|
3 |
+
{% set messages = [{'role': 'system', 'content': default_system_prompt}] + messages %}
|
4 |
+
{% endif %}
|
5 |
+
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
|
6 |
+
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}
|
tokenizer_config.json
ADDED
@@ -0,0 +1,2062 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"128000": {
|
4 |
+
"content": "<|begin_of_text|>",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"128001": {
|
12 |
+
"content": "<|end_of_text|>",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"128002": {
|
20 |
+
"content": "<|reserved_special_token_0|>",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"128003": {
|
28 |
+
"content": "<|reserved_special_token_1|>",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"128004": {
|
36 |
+
"content": "<|finetune_right_pad_id|>",
|
37 |
+
"lstrip": false,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
},
|
43 |
+
"128005": {
|
44 |
+
"content": "<|reserved_special_token_2|>",
|
45 |
+
"lstrip": false,
|
46 |
+
"normalized": false,
|
47 |
+
"rstrip": false,
|
48 |
+
"single_word": false,
|
49 |
+
"special": true
|
50 |
+
},
|
51 |
+
"128006": {
|
52 |
+
"content": "<|start_header_id|>",
|
53 |
+
"lstrip": false,
|
54 |
+
"normalized": false,
|
55 |
+
"rstrip": false,
|
56 |
+
"single_word": false,
|
57 |
+
"special": true
|
58 |
+
},
|
59 |
+
"128007": {
|
60 |
+
"content": "<|end_header_id|>",
|
61 |
+
"lstrip": false,
|
62 |
+
"normalized": false,
|
63 |
+
"rstrip": false,
|
64 |
+
"single_word": false,
|
65 |
+
"special": true
|
66 |
+
},
|
67 |
+
"128008": {
|
68 |
+
"content": "<|eom_id|>",
|
69 |
+
"lstrip": false,
|
70 |
+
"normalized": false,
|
71 |
+
"rstrip": false,
|
72 |
+
"single_word": false,
|
73 |
+
"special": true
|
74 |
+
},
|
75 |
+
"128009": {
|
76 |
+
"content": "<|eot_id|>",
|
77 |
+
"lstrip": false,
|
78 |
+
"normalized": false,
|
79 |
+
"rstrip": false,
|
80 |
+
"single_word": false,
|
81 |
+
"special": true
|
82 |
+
},
|
83 |
+
"128010": {
|
84 |
+
"content": "<|python_tag|>",
|
85 |
+
"lstrip": false,
|
86 |
+
"normalized": false,
|
87 |
+
"rstrip": false,
|
88 |
+
"single_word": false,
|
89 |
+
"special": true
|
90 |
+
},
|
91 |
+
"128011": {
|
92 |
+
"content": "<|reserved_special_token_3|>",
|
93 |
+
"lstrip": false,
|
94 |
+
"normalized": false,
|
95 |
+
"rstrip": false,
|
96 |
+
"single_word": false,
|
97 |
+
"special": true
|
98 |
+
},
|
99 |
+
"128012": {
|
100 |
+
"content": "<|reserved_special_token_4|>",
|
101 |
+
"lstrip": false,
|
102 |
+
"normalized": false,
|
103 |
+
"rstrip": false,
|
104 |
+
"single_word": false,
|
105 |
+
"special": true
|
106 |
+
},
|
107 |
+
"128013": {
|
108 |
+
"content": "<|reserved_special_token_5|>",
|
109 |
+
"lstrip": false,
|
110 |
+
"normalized": false,
|
111 |
+
"rstrip": false,
|
112 |
+
"single_word": false,
|
113 |
+
"special": true
|
114 |
+
},
|
115 |
+
"128014": {
|
116 |
+
"content": "<|reserved_special_token_6|>",
|
117 |
+
"lstrip": false,
|
118 |
+
"normalized": false,
|
119 |
+
"rstrip": false,
|
120 |
+
"single_word": false,
|
121 |
+
"special": true
|
122 |
+
},
|
123 |
+
"128015": {
|
124 |
+
"content": "<|reserved_special_token_7|>",
|
125 |
+
"lstrip": false,
|
126 |
+
"normalized": false,
|
127 |
+
"rstrip": false,
|
128 |
+
"single_word": false,
|
129 |
+
"special": true
|
130 |
+
},
|
131 |
+
"128016": {
|
132 |
+
"content": "<|reserved_special_token_8|>",
|
133 |
+
"lstrip": false,
|
134 |
+
"normalized": false,
|
135 |
+
"rstrip": false,
|
136 |
+
"single_word": false,
|
137 |
+
"special": true
|
138 |
+
},
|
139 |
+
"128017": {
|
140 |
+
"content": "<|reserved_special_token_9|>",
|
141 |
+
"lstrip": false,
|
142 |
+
"normalized": false,
|
143 |
+
"rstrip": false,
|
144 |
+
"single_word": false,
|
145 |
+
"special": true
|
146 |
+
},
|
147 |
+
"128018": {
|
148 |
+
"content": "<|reserved_special_token_10|>",
|
149 |
+
"lstrip": false,
|
150 |
+
"normalized": false,
|
151 |
+
"rstrip": false,
|
152 |
+
"single_word": false,
|
153 |
+
"special": true
|
154 |
+
},
|
155 |
+
"128019": {
|
156 |
+
"content": "<|reserved_special_token_11|>",
|
157 |
+
"lstrip": false,
|
158 |
+
"normalized": false,
|
159 |
+
"rstrip": false,
|
160 |
+
"single_word": false,
|
161 |
+
"special": true
|
162 |
+
},
|
163 |
+
"128020": {
|
164 |
+
"content": "<|reserved_special_token_12|>",
|
165 |
+
"lstrip": false,
|
166 |
+
"normalized": false,
|
167 |
+
"rstrip": false,
|
168 |
+
"single_word": false,
|
169 |
+
"special": true
|
170 |
+
},
|
171 |
+
"128021": {
|
172 |
+
"content": "<|reserved_special_token_13|>",
|
173 |
+
"lstrip": false,
|
174 |
+
"normalized": false,
|
175 |
+
"rstrip": false,
|
176 |
+
"single_word": false,
|
177 |
+
"special": true
|
178 |
+
},
|
179 |
+
"128022": {
|
180 |
+
"content": "<|reserved_special_token_14|>",
|
181 |
+
"lstrip": false,
|
182 |
+
"normalized": false,
|
183 |
+
"rstrip": false,
|
184 |
+
"single_word": false,
|
185 |
+
"special": true
|
186 |
+
},
|
187 |
+
"128023": {
|
188 |
+
"content": "<|reserved_special_token_15|>",
|
189 |
+
"lstrip": false,
|
190 |
+
"normalized": false,
|
191 |
+
"rstrip": false,
|
192 |
+
"single_word": false,
|
193 |
+
"special": true
|
194 |
+
},
|
195 |
+
"128024": {
|
196 |
+
"content": "<|reserved_special_token_16|>",
|
197 |
+
"lstrip": false,
|
198 |
+
"normalized": false,
|
199 |
+
"rstrip": false,
|
200 |
+
"single_word": false,
|
201 |
+
"special": true
|
202 |
+
},
|
203 |
+
"128025": {
|
204 |
+
"content": "<|reserved_special_token_17|>",
|
205 |
+
"lstrip": false,
|
206 |
+
"normalized": false,
|
207 |
+
"rstrip": false,
|
208 |
+
"single_word": false,
|
209 |
+
"special": true
|
210 |
+
},
|
211 |
+
"128026": {
|
212 |
+
"content": "<|reserved_special_token_18|>",
|
213 |
+
"lstrip": false,
|
214 |
+
"normalized": false,
|
215 |
+
"rstrip": false,
|
216 |
+
"single_word": false,
|
217 |
+
"special": true
|
218 |
+
},
|
219 |
+
"128027": {
|
220 |
+
"content": "<|reserved_special_token_19|>",
|
221 |
+
"lstrip": false,
|
222 |
+
"normalized": false,
|
223 |
+
"rstrip": false,
|
224 |
+
"single_word": false,
|
225 |
+
"special": true
|
226 |
+
},
|
227 |
+
"128028": {
|
228 |
+
"content": "<|reserved_special_token_20|>",
|
229 |
+
"lstrip": false,
|
230 |
+
"normalized": false,
|
231 |
+
"rstrip": false,
|
232 |
+
"single_word": false,
|
233 |
+
"special": true
|
234 |
+
},
|
235 |
+
"128029": {
|
236 |
+
"content": "<|reserved_special_token_21|>",
|
237 |
+
"lstrip": false,
|
238 |
+
"normalized": false,
|
239 |
+
"rstrip": false,
|
240 |
+
"single_word": false,
|
241 |
+
"special": true
|
242 |
+
},
|
243 |
+
"128030": {
|
244 |
+
"content": "<|reserved_special_token_22|>",
|
245 |
+
"lstrip": false,
|
246 |
+
"normalized": false,
|
247 |
+
"rstrip": false,
|
248 |
+
"single_word": false,
|
249 |
+
"special": true
|
250 |
+
},
|
251 |
+
"128031": {
|
252 |
+
"content": "<|reserved_special_token_23|>",
|
253 |
+
"lstrip": false,
|
254 |
+
"normalized": false,
|
255 |
+
"rstrip": false,
|
256 |
+
"single_word": false,
|
257 |
+
"special": true
|
258 |
+
},
|
259 |
+
"128032": {
|
260 |
+
"content": "<|reserved_special_token_24|>",
|
261 |
+
"lstrip": false,
|
262 |
+
"normalized": false,
|
263 |
+
"rstrip": false,
|
264 |
+
"single_word": false,
|
265 |
+
"special": true
|
266 |
+
},
|
267 |
+
"128033": {
|
268 |
+
"content": "<|reserved_special_token_25|>",
|
269 |
+
"lstrip": false,
|
270 |
+
"normalized": false,
|
271 |
+
"rstrip": false,
|
272 |
+
"single_word": false,
|
273 |
+
"special": true
|
274 |
+
},
|
275 |
+
"128034": {
|
276 |
+
"content": "<|reserved_special_token_26|>",
|
277 |
+
"lstrip": false,
|
278 |
+
"normalized": false,
|
279 |
+
"rstrip": false,
|
280 |
+
"single_word": false,
|
281 |
+
"special": true
|
282 |
+
},
|
283 |
+
"128035": {
|
284 |
+
"content": "<|reserved_special_token_27|>",
|
285 |
+
"lstrip": false,
|
286 |
+
"normalized": false,
|
287 |
+
"rstrip": false,
|
288 |
+
"single_word": false,
|
289 |
+
"special": true
|
290 |
+
},
|
291 |
+
"128036": {
|
292 |
+
"content": "<|reserved_special_token_28|>",
|
293 |
+
"lstrip": false,
|
294 |
+
"normalized": false,
|
295 |
+
"rstrip": false,
|
296 |
+
"single_word": false,
|
297 |
+
"special": true
|
298 |
+
},
|
299 |
+
"128037": {
|
300 |
+
"content": "<|reserved_special_token_29|>",
|
301 |
+
"lstrip": false,
|
302 |
+
"normalized": false,
|
303 |
+
"rstrip": false,
|
304 |
+
"single_word": false,
|
305 |
+
"special": true
|
306 |
+
},
|
307 |
+
"128038": {
|
308 |
+
"content": "<|reserved_special_token_30|>",
|
309 |
+
"lstrip": false,
|
310 |
+
"normalized": false,
|
311 |
+
"rstrip": false,
|
312 |
+
"single_word": false,
|
313 |
+
"special": true
|
314 |
+
},
|
315 |
+
"128039": {
|
316 |
+
"content": "<|reserved_special_token_31|>",
|
317 |
+
"lstrip": false,
|
318 |
+
"normalized": false,
|
319 |
+
"rstrip": false,
|
320 |
+
"single_word": false,
|
321 |
+
"special": true
|
322 |
+
},
|
323 |
+
"128040": {
|
324 |
+
"content": "<|reserved_special_token_32|>",
|
325 |
+
"lstrip": false,
|
326 |
+
"normalized": false,
|
327 |
+
"rstrip": false,
|
328 |
+
"single_word": false,
|
329 |
+
"special": true
|
330 |
+
},
|
331 |
+
"128041": {
|
332 |
+
"content": "<|reserved_special_token_33|>",
|
333 |
+
"lstrip": false,
|
334 |
+
"normalized": false,
|
335 |
+
"rstrip": false,
|
336 |
+
"single_word": false,
|
337 |
+
"special": true
|
338 |
+
},
|
339 |
+
"128042": {
|
340 |
+
"content": "<|reserved_special_token_34|>",
|
341 |
+
"lstrip": false,
|
342 |
+
"normalized": false,
|
343 |
+
"rstrip": false,
|
344 |
+
"single_word": false,
|
345 |
+
"special": true
|
346 |
+
},
|
347 |
+
"128043": {
|
348 |
+
"content": "<|reserved_special_token_35|>",
|
349 |
+
"lstrip": false,
|
350 |
+
"normalized": false,
|
351 |
+
"rstrip": false,
|
352 |
+
"single_word": false,
|
353 |
+
"special": true
|
354 |
+
},
|
355 |
+
"128044": {
|
356 |
+
"content": "<|reserved_special_token_36|>",
|
357 |
+
"lstrip": false,
|
358 |
+
"normalized": false,
|
359 |
+
"rstrip": false,
|
360 |
+
"single_word": false,
|
361 |
+
"special": true
|
362 |
+
},
|
363 |
+
"128045": {
|
364 |
+
"content": "<|reserved_special_token_37|>",
|
365 |
+
"lstrip": false,
|
366 |
+
"normalized": false,
|
367 |
+
"rstrip": false,
|
368 |
+
"single_word": false,
|
369 |
+
"special": true
|
370 |
+
},
|
371 |
+
"128046": {
|
372 |
+
"content": "<|reserved_special_token_38|>",
|
373 |
+
"lstrip": false,
|
374 |
+
"normalized": false,
|
375 |
+
"rstrip": false,
|
376 |
+
"single_word": false,
|
377 |
+
"special": true
|
378 |
+
},
|
379 |
+
"128047": {
|
380 |
+
"content": "<|reserved_special_token_39|>",
|
381 |
+
"lstrip": false,
|
382 |
+
"normalized": false,
|
383 |
+
"rstrip": false,
|
384 |
+
"single_word": false,
|
385 |
+
"special": true
|
386 |
+
},
|
387 |
+
"128048": {
|
388 |
+
"content": "<|reserved_special_token_40|>",
|
389 |
+
"lstrip": false,
|
390 |
+
"normalized": false,
|
391 |
+
"rstrip": false,
|
392 |
+
"single_word": false,
|
393 |
+
"special": true
|
394 |
+
},
|
395 |
+
"128049": {
|
396 |
+
"content": "<|reserved_special_token_41|>",
|
397 |
+
"lstrip": false,
|
398 |
+
"normalized": false,
|
399 |
+
"rstrip": false,
|
400 |
+
"single_word": false,
|
401 |
+
"special": true
|
402 |
+
},
|
403 |
+
"128050": {
|
404 |
+
"content": "<|reserved_special_token_42|>",
|
405 |
+
"lstrip": false,
|
406 |
+
"normalized": false,
|
407 |
+
"rstrip": false,
|
408 |
+
"single_word": false,
|
409 |
+
"special": true
|
410 |
+
},
|
411 |
+
"128051": {
|
412 |
+
"content": "<|reserved_special_token_43|>",
|
413 |
+
"lstrip": false,
|
414 |
+
"normalized": false,
|
415 |
+
"rstrip": false,
|
416 |
+
"single_word": false,
|
417 |
+
"special": true
|
418 |
+
},
|
419 |
+
"128052": {
|
420 |
+
"content": "<|reserved_special_token_44|>",
|
421 |
+
"lstrip": false,
|
422 |
+
"normalized": false,
|
423 |
+
"rstrip": false,
|
424 |
+
"single_word": false,
|
425 |
+
"special": true
|
426 |
+
},
|
427 |
+
"128053": {
|
428 |
+
"content": "<|reserved_special_token_45|>",
|
429 |
+
"lstrip": false,
|
430 |
+
"normalized": false,
|
431 |
+
"rstrip": false,
|
432 |
+
"single_word": false,
|
433 |
+
"special": true
|
434 |
+
},
|
435 |
+
"128054": {
|
436 |
+
"content": "<|reserved_special_token_46|>",
|
437 |
+
"lstrip": false,
|
438 |
+
"normalized": false,
|
439 |
+
"rstrip": false,
|
440 |
+
"single_word": false,
|
441 |
+
"special": true
|
442 |
+
},
|
443 |
+
"128055": {
|
444 |
+
"content": "<|reserved_special_token_47|>",
|
445 |
+
"lstrip": false,
|
446 |
+
"normalized": false,
|
447 |
+
"rstrip": false,
|
448 |
+
"single_word": false,
|
449 |
+
"special": true
|
450 |
+
},
|
451 |
+
"128056": {
|
452 |
+
"content": "<|reserved_special_token_48|>",
|
453 |
+
"lstrip": false,
|
454 |
+
"normalized": false,
|
455 |
+
"rstrip": false,
|
456 |
+
"single_word": false,
|
457 |
+
"special": true
|
458 |
+
},
|
459 |
+
"128057": {
|
460 |
+
"content": "<|reserved_special_token_49|>",
|
461 |
+
"lstrip": false,
|
462 |
+
"normalized": false,
|
463 |
+
"rstrip": false,
|
464 |
+
"single_word": false,
|
465 |
+
"special": true
|
466 |
+
},
|
467 |
+
"128058": {
|
468 |
+
"content": "<|reserved_special_token_50|>",
|
469 |
+
"lstrip": false,
|
470 |
+
"normalized": false,
|
471 |
+
"rstrip": false,
|
472 |
+
"single_word": false,
|
473 |
+
"special": true
|
474 |
+
},
|
475 |
+
"128059": {
|
476 |
+
"content": "<|reserved_special_token_51|>",
|
477 |
+
"lstrip": false,
|
478 |
+
"normalized": false,
|
479 |
+
"rstrip": false,
|
480 |
+
"single_word": false,
|
481 |
+
"special": true
|
482 |
+
},
|
483 |
+
"128060": {
|
484 |
+
"content": "<|reserved_special_token_52|>",
|
485 |
+
"lstrip": false,
|
486 |
+
"normalized": false,
|
487 |
+
"rstrip": false,
|
488 |
+
"single_word": false,
|
489 |
+
"special": true
|
490 |
+
},
|
491 |
+
"128061": {
|
492 |
+
"content": "<|reserved_special_token_53|>",
|
493 |
+
"lstrip": false,
|
494 |
+
"normalized": false,
|
495 |
+
"rstrip": false,
|
496 |
+
"single_word": false,
|
497 |
+
"special": true
|
498 |
+
},
|
499 |
+
"128062": {
|
500 |
+
"content": "<|reserved_special_token_54|>",
|
501 |
+
"lstrip": false,
|
502 |
+
"normalized": false,
|
503 |
+
"rstrip": false,
|
504 |
+
"single_word": false,
|
505 |
+
"special": true
|
506 |
+
},
|
507 |
+
"128063": {
|
508 |
+
"content": "<|reserved_special_token_55|>",
|
509 |
+
"lstrip": false,
|
510 |
+
"normalized": false,
|
511 |
+
"rstrip": false,
|
512 |
+
"single_word": false,
|
513 |
+
"special": true
|
514 |
+
},
|
515 |
+
"128064": {
|
516 |
+
"content": "<|reserved_special_token_56|>",
|
517 |
+
"lstrip": false,
|
518 |
+
"normalized": false,
|
519 |
+
"rstrip": false,
|
520 |
+
"single_word": false,
|
521 |
+
"special": true
|
522 |
+
},
|
523 |
+
"128065": {
|
524 |
+
"content": "<|reserved_special_token_57|>",
|
525 |
+
"lstrip": false,
|
526 |
+
"normalized": false,
|
527 |
+
"rstrip": false,
|
528 |
+
"single_word": false,
|
529 |
+
"special": true
|
530 |
+
},
|
531 |
+
"128066": {
|
532 |
+
"content": "<|reserved_special_token_58|>",
|
533 |
+
"lstrip": false,
|
534 |
+
"normalized": false,
|
535 |
+
"rstrip": false,
|
536 |
+
"single_word": false,
|
537 |
+
"special": true
|
538 |
+
},
|
539 |
+
"128067": {
|
540 |
+
"content": "<|reserved_special_token_59|>",
|
541 |
+
"lstrip": false,
|
542 |
+
"normalized": false,
|
543 |
+
"rstrip": false,
|
544 |
+
"single_word": false,
|
545 |
+
"special": true
|
546 |
+
},
|
547 |
+
"128068": {
|
548 |
+
"content": "<|reserved_special_token_60|>",
|
549 |
+
"lstrip": false,
|
550 |
+
"normalized": false,
|
551 |
+
"rstrip": false,
|
552 |
+
"single_word": false,
|
553 |
+
"special": true
|
554 |
+
},
|
555 |
+
"128069": {
|
556 |
+
"content": "<|reserved_special_token_61|>",
|
557 |
+
"lstrip": false,
|
558 |
+
"normalized": false,
|
559 |
+
"rstrip": false,
|
560 |
+
"single_word": false,
|
561 |
+
"special": true
|
562 |
+
},
|
563 |
+
"128070": {
|
564 |
+
"content": "<|reserved_special_token_62|>",
|
565 |
+
"lstrip": false,
|
566 |
+
"normalized": false,
|
567 |
+
"rstrip": false,
|
568 |
+
"single_word": false,
|
569 |
+
"special": true
|
570 |
+
},
|
571 |
+
"128071": {
|
572 |
+
"content": "<|reserved_special_token_63|>",
|
573 |
+
"lstrip": false,
|
574 |
+
"normalized": false,
|
575 |
+
"rstrip": false,
|
576 |
+
"single_word": false,
|
577 |
+
"special": true
|
578 |
+
},
|
579 |
+
"128072": {
|
580 |
+
"content": "<|reserved_special_token_64|>",
|
581 |
+
"lstrip": false,
|
582 |
+
"normalized": false,
|
583 |
+
"rstrip": false,
|
584 |
+
"single_word": false,
|
585 |
+
"special": true
|
586 |
+
},
|
587 |
+
"128073": {
|
588 |
+
"content": "<|reserved_special_token_65|>",
|
589 |
+
"lstrip": false,
|
590 |
+
"normalized": false,
|
591 |
+
"rstrip": false,
|
592 |
+
"single_word": false,
|
593 |
+
"special": true
|
594 |
+
},
|
595 |
+
"128074": {
|
596 |
+
"content": "<|reserved_special_token_66|>",
|
597 |
+
"lstrip": false,
|
598 |
+
"normalized": false,
|
599 |
+
"rstrip": false,
|
600 |
+
"single_word": false,
|
601 |
+
"special": true
|
602 |
+
},
|
603 |
+
"128075": {
|
604 |
+
"content": "<|reserved_special_token_67|>",
|
605 |
+
"lstrip": false,
|
606 |
+
"normalized": false,
|
607 |
+
"rstrip": false,
|
608 |
+
"single_word": false,
|
609 |
+
"special": true
|
610 |
+
},
|
611 |
+
"128076": {
|
612 |
+
"content": "<|reserved_special_token_68|>",
|
613 |
+
"lstrip": false,
|
614 |
+
"normalized": false,
|
615 |
+
"rstrip": false,
|
616 |
+
"single_word": false,
|
617 |
+
"special": true
|
618 |
+
},
|
619 |
+
"128077": {
|
620 |
+
"content": "<|reserved_special_token_69|>",
|
621 |
+
"lstrip": false,
|
622 |
+
"normalized": false,
|
623 |
+
"rstrip": false,
|
624 |
+
"single_word": false,
|
625 |
+
"special": true
|
626 |
+
},
|
627 |
+
"128078": {
|
628 |
+
"content": "<|reserved_special_token_70|>",
|
629 |
+
"lstrip": false,
|
630 |
+
"normalized": false,
|
631 |
+
"rstrip": false,
|
632 |
+
"single_word": false,
|
633 |
+
"special": true
|
634 |
+
},
|
635 |
+
"128079": {
|
636 |
+
"content": "<|reserved_special_token_71|>",
|
637 |
+
"lstrip": false,
|
638 |
+
"normalized": false,
|
639 |
+
"rstrip": false,
|
640 |
+
"single_word": false,
|
641 |
+
"special": true
|
642 |
+
},
|
643 |
+
"128080": {
|
644 |
+
"content": "<|reserved_special_token_72|>",
|
645 |
+
"lstrip": false,
|
646 |
+
"normalized": false,
|
647 |
+
"rstrip": false,
|
648 |
+
"single_word": false,
|
649 |
+
"special": true
|
650 |
+
},
|
651 |
+
"128081": {
|
652 |
+
"content": "<|reserved_special_token_73|>",
|
653 |
+
"lstrip": false,
|
654 |
+
"normalized": false,
|
655 |
+
"rstrip": false,
|
656 |
+
"single_word": false,
|
657 |
+
"special": true
|
658 |
+
},
|
659 |
+
"128082": {
|
660 |
+
"content": "<|reserved_special_token_74|>",
|
661 |
+
"lstrip": false,
|
662 |
+
"normalized": false,
|
663 |
+
"rstrip": false,
|
664 |
+
"single_word": false,
|
665 |
+
"special": true
|
666 |
+
},
|
667 |
+
"128083": {
|
668 |
+
"content": "<|reserved_special_token_75|>",
|
669 |
+
"lstrip": false,
|
670 |
+
"normalized": false,
|
671 |
+
"rstrip": false,
|
672 |
+
"single_word": false,
|
673 |
+
"special": true
|
674 |
+
},
|
675 |
+
"128084": {
|
676 |
+
"content": "<|reserved_special_token_76|>",
|
677 |
+
"lstrip": false,
|
678 |
+
"normalized": false,
|
679 |
+
"rstrip": false,
|
680 |
+
"single_word": false,
|
681 |
+
"special": true
|
682 |
+
},
|
683 |
+
"128085": {
|
684 |
+
"content": "<|reserved_special_token_77|>",
|
685 |
+
"lstrip": false,
|
686 |
+
"normalized": false,
|
687 |
+
"rstrip": false,
|
688 |
+
"single_word": false,
|
689 |
+
"special": true
|
690 |
+
},
|
691 |
+
"128086": {
|
692 |
+
"content": "<|reserved_special_token_78|>",
|
693 |
+
"lstrip": false,
|
694 |
+
"normalized": false,
|
695 |
+
"rstrip": false,
|
696 |
+
"single_word": false,
|
697 |
+
"special": true
|
698 |
+
},
|
699 |
+
"128087": {
|
700 |
+
"content": "<|reserved_special_token_79|>",
|
701 |
+
"lstrip": false,
|
702 |
+
"normalized": false,
|
703 |
+
"rstrip": false,
|
704 |
+
"single_word": false,
|
705 |
+
"special": true
|
706 |
+
},
|
707 |
+
"128088": {
|
708 |
+
"content": "<|reserved_special_token_80|>",
|
709 |
+
"lstrip": false,
|
710 |
+
"normalized": false,
|
711 |
+
"rstrip": false,
|
712 |
+
"single_word": false,
|
713 |
+
"special": true
|
714 |
+
},
|
715 |
+
"128089": {
|
716 |
+
"content": "<|reserved_special_token_81|>",
|
717 |
+
"lstrip": false,
|
718 |
+
"normalized": false,
|
719 |
+
"rstrip": false,
|
720 |
+
"single_word": false,
|
721 |
+
"special": true
|
722 |
+
},
|
723 |
+
"128090": {
|
724 |
+
"content": "<|reserved_special_token_82|>",
|
725 |
+
"lstrip": false,
|
726 |
+
"normalized": false,
|
727 |
+
"rstrip": false,
|
728 |
+
"single_word": false,
|
729 |
+
"special": true
|
730 |
+
},
|
731 |
+
"128091": {
|
732 |
+
"content": "<|reserved_special_token_83|>",
|
733 |
+
"lstrip": false,
|
734 |
+
"normalized": false,
|
735 |
+
"rstrip": false,
|
736 |
+
"single_word": false,
|
737 |
+
"special": true
|
738 |
+
},
|
739 |
+
"128092": {
|
740 |
+
"content": "<|reserved_special_token_84|>",
|
741 |
+
"lstrip": false,
|
742 |
+
"normalized": false,
|
743 |
+
"rstrip": false,
|
744 |
+
"single_word": false,
|
745 |
+
"special": true
|
746 |
+
},
|
747 |
+
"128093": {
|
748 |
+
"content": "<|reserved_special_token_85|>",
|
749 |
+
"lstrip": false,
|
750 |
+
"normalized": false,
|
751 |
+
"rstrip": false,
|
752 |
+
"single_word": false,
|
753 |
+
"special": true
|
754 |
+
},
|
755 |
+
"128094": {
|
756 |
+
"content": "<|reserved_special_token_86|>",
|
757 |
+
"lstrip": false,
|
758 |
+
"normalized": false,
|
759 |
+
"rstrip": false,
|
760 |
+
"single_word": false,
|
761 |
+
"special": true
|
762 |
+
},
|
763 |
+
"128095": {
|
764 |
+
"content": "<|reserved_special_token_87|>",
|
765 |
+
"lstrip": false,
|
766 |
+
"normalized": false,
|
767 |
+
"rstrip": false,
|
768 |
+
"single_word": false,
|
769 |
+
"special": true
|
770 |
+
},
|
771 |
+
"128096": {
|
772 |
+
"content": "<|reserved_special_token_88|>",
|
773 |
+
"lstrip": false,
|
774 |
+
"normalized": false,
|
775 |
+
"rstrip": false,
|
776 |
+
"single_word": false,
|
777 |
+
"special": true
|
778 |
+
},
|
779 |
+
"128097": {
|
780 |
+
"content": "<|reserved_special_token_89|>",
|
781 |
+
"lstrip": false,
|
782 |
+
"normalized": false,
|
783 |
+
"rstrip": false,
|
784 |
+
"single_word": false,
|
785 |
+
"special": true
|
786 |
+
},
|
787 |
+
"128098": {
|
788 |
+
"content": "<|reserved_special_token_90|>",
|
789 |
+
"lstrip": false,
|
790 |
+
"normalized": false,
|
791 |
+
"rstrip": false,
|
792 |
+
"single_word": false,
|
793 |
+
"special": true
|
794 |
+
},
|
795 |
+
"128099": {
|
796 |
+
"content": "<|reserved_special_token_91|>",
|
797 |
+
"lstrip": false,
|
798 |
+
"normalized": false,
|
799 |
+
"rstrip": false,
|
800 |
+
"single_word": false,
|
801 |
+
"special": true
|
802 |
+
},
|
803 |
+
"128100": {
|
804 |
+
"content": "<|reserved_special_token_92|>",
|
805 |
+
"lstrip": false,
|
806 |
+
"normalized": false,
|
807 |
+
"rstrip": false,
|
808 |
+
"single_word": false,
|
809 |
+
"special": true
|
810 |
+
},
|
811 |
+
"128101": {
|
812 |
+
"content": "<|reserved_special_token_93|>",
|
813 |
+
"lstrip": false,
|
814 |
+
"normalized": false,
|
815 |
+
"rstrip": false,
|
816 |
+
"single_word": false,
|
817 |
+
"special": true
|
818 |
+
},
|
819 |
+
"128102": {
|
820 |
+
"content": "<|reserved_special_token_94|>",
|
821 |
+
"lstrip": false,
|
822 |
+
"normalized": false,
|
823 |
+
"rstrip": false,
|
824 |
+
"single_word": false,
|
825 |
+
"special": true
|
826 |
+
},
|
827 |
+
"128103": {
|
828 |
+
"content": "<|reserved_special_token_95|>",
|
829 |
+
"lstrip": false,
|
830 |
+
"normalized": false,
|
831 |
+
"rstrip": false,
|
832 |
+
"single_word": false,
|
833 |
+
"special": true
|
834 |
+
},
|
835 |
+
"128104": {
|
836 |
+
"content": "<|reserved_special_token_96|>",
|
837 |
+
"lstrip": false,
|
838 |
+
"normalized": false,
|
839 |
+
"rstrip": false,
|
840 |
+
"single_word": false,
|
841 |
+
"special": true
|
842 |
+
},
|
843 |
+
"128105": {
|
844 |
+
"content": "<|reserved_special_token_97|>",
|
845 |
+
"lstrip": false,
|
846 |
+
"normalized": false,
|
847 |
+
"rstrip": false,
|
848 |
+
"single_word": false,
|
849 |
+
"special": true
|
850 |
+
},
|
851 |
+
"128106": {
|
852 |
+
"content": "<|reserved_special_token_98|>",
|
853 |
+
"lstrip": false,
|
854 |
+
"normalized": false,
|
855 |
+
"rstrip": false,
|
856 |
+
"single_word": false,
|
857 |
+
"special": true
|
858 |
+
},
|
859 |
+
"128107": {
|
860 |
+
"content": "<|reserved_special_token_99|>",
|
861 |
+
"lstrip": false,
|
862 |
+
"normalized": false,
|
863 |
+
"rstrip": false,
|
864 |
+
"single_word": false,
|
865 |
+
"special": true
|
866 |
+
},
|
867 |
+
"128108": {
|
868 |
+
"content": "<|reserved_special_token_100|>",
|
869 |
+
"lstrip": false,
|
870 |
+
"normalized": false,
|
871 |
+
"rstrip": false,
|
872 |
+
"single_word": false,
|
873 |
+
"special": true
|
874 |
+
},
|
875 |
+
"128109": {
|
876 |
+
"content": "<|reserved_special_token_101|>",
|
877 |
+
"lstrip": false,
|
878 |
+
"normalized": false,
|
879 |
+
"rstrip": false,
|
880 |
+
"single_word": false,
|
881 |
+
"special": true
|
882 |
+
},
|
883 |
+
"128110": {
|
884 |
+
"content": "<|reserved_special_token_102|>",
|
885 |
+
"lstrip": false,
|
886 |
+
"normalized": false,
|
887 |
+
"rstrip": false,
|
888 |
+
"single_word": false,
|
889 |
+
"special": true
|
890 |
+
},
|
891 |
+
"128111": {
|
892 |
+
"content": "<|reserved_special_token_103|>",
|
893 |
+
"lstrip": false,
|
894 |
+
"normalized": false,
|
895 |
+
"rstrip": false,
|
896 |
+
"single_word": false,
|
897 |
+
"special": true
|
898 |
+
},
|
899 |
+
"128112": {
|
900 |
+
"content": "<|reserved_special_token_104|>",
|
901 |
+
"lstrip": false,
|
902 |
+
"normalized": false,
|
903 |
+
"rstrip": false,
|
904 |
+
"single_word": false,
|
905 |
+
"special": true
|
906 |
+
},
|
907 |
+
"128113": {
|
908 |
+
"content": "<|reserved_special_token_105|>",
|
909 |
+
"lstrip": false,
|
910 |
+
"normalized": false,
|
911 |
+
"rstrip": false,
|
912 |
+
"single_word": false,
|
913 |
+
"special": true
|
914 |
+
},
|
915 |
+
"128114": {
|
916 |
+
"content": "<|reserved_special_token_106|>",
|
917 |
+
"lstrip": false,
|
918 |
+
"normalized": false,
|
919 |
+
"rstrip": false,
|
920 |
+
"single_word": false,
|
921 |
+
"special": true
|
922 |
+
},
|
923 |
+
"128115": {
|
924 |
+
"content": "<|reserved_special_token_107|>",
|
925 |
+
"lstrip": false,
|
926 |
+
"normalized": false,
|
927 |
+
"rstrip": false,
|
928 |
+
"single_word": false,
|
929 |
+
"special": true
|
930 |
+
},
|
931 |
+
"128116": {
|
932 |
+
"content": "<|reserved_special_token_108|>",
|
933 |
+
"lstrip": false,
|
934 |
+
"normalized": false,
|
935 |
+
"rstrip": false,
|
936 |
+
"single_word": false,
|
937 |
+
"special": true
|
938 |
+
},
|
939 |
+
"128117": {
|
940 |
+
"content": "<|reserved_special_token_109|>",
|
941 |
+
"lstrip": false,
|
942 |
+
"normalized": false,
|
943 |
+
"rstrip": false,
|
944 |
+
"single_word": false,
|
945 |
+
"special": true
|
946 |
+
},
|
947 |
+
"128118": {
|
948 |
+
"content": "<|reserved_special_token_110|>",
|
949 |
+
"lstrip": false,
|
950 |
+
"normalized": false,
|
951 |
+
"rstrip": false,
|
952 |
+
"single_word": false,
|
953 |
+
"special": true
|
954 |
+
},
|
955 |
+
"128119": {
|
956 |
+
"content": "<|reserved_special_token_111|>",
|
957 |
+
"lstrip": false,
|
958 |
+
"normalized": false,
|
959 |
+
"rstrip": false,
|
960 |
+
"single_word": false,
|
961 |
+
"special": true
|
962 |
+
},
|
963 |
+
"128120": {
|
964 |
+
"content": "<|reserved_special_token_112|>",
|
965 |
+
"lstrip": false,
|
966 |
+
"normalized": false,
|
967 |
+
"rstrip": false,
|
968 |
+
"single_word": false,
|
969 |
+
"special": true
|
970 |
+
},
|
971 |
+
"128121": {
|
972 |
+
"content": "<|reserved_special_token_113|>",
|
973 |
+
"lstrip": false,
|
974 |
+
"normalized": false,
|
975 |
+
"rstrip": false,
|
976 |
+
"single_word": false,
|
977 |
+
"special": true
|
978 |
+
},
|
979 |
+
"128122": {
|
980 |
+
"content": "<|reserved_special_token_114|>",
|
981 |
+
"lstrip": false,
|
982 |
+
"normalized": false,
|
983 |
+
"rstrip": false,
|
984 |
+
"single_word": false,
|
985 |
+
"special": true
|
986 |
+
},
|
987 |
+
"128123": {
|
988 |
+
"content": "<|reserved_special_token_115|>",
|
989 |
+
"lstrip": false,
|
990 |
+
"normalized": false,
|
991 |
+
"rstrip": false,
|
992 |
+
"single_word": false,
|
993 |
+
"special": true
|
994 |
+
},
|
995 |
+
"128124": {
|
996 |
+
"content": "<|reserved_special_token_116|>",
|
997 |
+
"lstrip": false,
|
998 |
+
"normalized": false,
|
999 |
+
"rstrip": false,
|
1000 |
+
"single_word": false,
|
1001 |
+
"special": true
|
1002 |
+
},
|
1003 |
+
"128125": {
|
1004 |
+
"content": "<|reserved_special_token_117|>",
|
1005 |
+
"lstrip": false,
|
1006 |
+
"normalized": false,
|
1007 |
+
"rstrip": false,
|
1008 |
+
"single_word": false,
|
1009 |
+
"special": true
|
1010 |
+
},
|
1011 |
+
"128126": {
|
1012 |
+
"content": "<|reserved_special_token_118|>",
|
1013 |
+
"lstrip": false,
|
1014 |
+
"normalized": false,
|
1015 |
+
"rstrip": false,
|
1016 |
+
"single_word": false,
|
1017 |
+
"special": true
|
1018 |
+
},
|
1019 |
+
"128127": {
|
1020 |
+
"content": "<|reserved_special_token_119|>",
|
1021 |
+
"lstrip": false,
|
1022 |
+
"normalized": false,
|
1023 |
+
"rstrip": false,
|
1024 |
+
"single_word": false,
|
1025 |
+
"special": true
|
1026 |
+
},
|
1027 |
+
"128128": {
|
1028 |
+
"content": "<|reserved_special_token_120|>",
|
1029 |
+
"lstrip": false,
|
1030 |
+
"normalized": false,
|
1031 |
+
"rstrip": false,
|
1032 |
+
"single_word": false,
|
1033 |
+
"special": true
|
1034 |
+
},
|
1035 |
+
"128129": {
|
1036 |
+
"content": "<|reserved_special_token_121|>",
|
1037 |
+
"lstrip": false,
|
1038 |
+
"normalized": false,
|
1039 |
+
"rstrip": false,
|
1040 |
+
"single_word": false,
|
1041 |
+
"special": true
|
1042 |
+
},
|
1043 |
+
"128130": {
|
1044 |
+
"content": "<|reserved_special_token_122|>",
|
1045 |
+
"lstrip": false,
|
1046 |
+
"normalized": false,
|
1047 |
+
"rstrip": false,
|
1048 |
+
"single_word": false,
|
1049 |
+
"special": true
|
1050 |
+
},
|
1051 |
+
"128131": {
|
1052 |
+
"content": "<|reserved_special_token_123|>",
|
1053 |
+
"lstrip": false,
|
1054 |
+
"normalized": false,
|
1055 |
+
"rstrip": false,
|
1056 |
+
"single_word": false,
|
1057 |
+
"special": true
|
1058 |
+
},
|
1059 |
+
"128132": {
|
1060 |
+
"content": "<|reserved_special_token_124|>",
|
1061 |
+
"lstrip": false,
|
1062 |
+
"normalized": false,
|
1063 |
+
"rstrip": false,
|
1064 |
+
"single_word": false,
|
1065 |
+
"special": true
|
1066 |
+
},
|
1067 |
+
"128133": {
|
1068 |
+
"content": "<|reserved_special_token_125|>",
|
1069 |
+
"lstrip": false,
|
1070 |
+
"normalized": false,
|
1071 |
+
"rstrip": false,
|
1072 |
+
"single_word": false,
|
1073 |
+
"special": true
|
1074 |
+
},
|
1075 |
+
"128134": {
|
1076 |
+
"content": "<|reserved_special_token_126|>",
|
1077 |
+
"lstrip": false,
|
1078 |
+
"normalized": false,
|
1079 |
+
"rstrip": false,
|
1080 |
+
"single_word": false,
|
1081 |
+
"special": true
|
1082 |
+
},
|
1083 |
+
"128135": {
|
1084 |
+
"content": "<|reserved_special_token_127|>",
|
1085 |
+
"lstrip": false,
|
1086 |
+
"normalized": false,
|
1087 |
+
"rstrip": false,
|
1088 |
+
"single_word": false,
|
1089 |
+
"special": true
|
1090 |
+
},
|
1091 |
+
"128136": {
|
1092 |
+
"content": "<|reserved_special_token_128|>",
|
1093 |
+
"lstrip": false,
|
1094 |
+
"normalized": false,
|
1095 |
+
"rstrip": false,
|
1096 |
+
"single_word": false,
|
1097 |
+
"special": true
|
1098 |
+
},
|
1099 |
+
"128137": {
|
1100 |
+
"content": "<|reserved_special_token_129|>",
|
1101 |
+
"lstrip": false,
|
1102 |
+
"normalized": false,
|
1103 |
+
"rstrip": false,
|
1104 |
+
"single_word": false,
|
1105 |
+
"special": true
|
1106 |
+
},
|
1107 |
+
"128138": {
|
1108 |
+
"content": "<|reserved_special_token_130|>",
|
1109 |
+
"lstrip": false,
|
1110 |
+
"normalized": false,
|
1111 |
+
"rstrip": false,
|
1112 |
+
"single_word": false,
|
1113 |
+
"special": true
|
1114 |
+
},
|
1115 |
+
"128139": {
|
1116 |
+
"content": "<|reserved_special_token_131|>",
|
1117 |
+
"lstrip": false,
|
1118 |
+
"normalized": false,
|
1119 |
+
"rstrip": false,
|
1120 |
+
"single_word": false,
|
1121 |
+
"special": true
|
1122 |
+
},
|
1123 |
+
"128140": {
|
1124 |
+
"content": "<|reserved_special_token_132|>",
|
1125 |
+
"lstrip": false,
|
1126 |
+
"normalized": false,
|
1127 |
+
"rstrip": false,
|
1128 |
+
"single_word": false,
|
1129 |
+
"special": true
|
1130 |
+
},
|
1131 |
+
"128141": {
|
1132 |
+
"content": "<|reserved_special_token_133|>",
|
1133 |
+
"lstrip": false,
|
1134 |
+
"normalized": false,
|
1135 |
+
"rstrip": false,
|
1136 |
+
"single_word": false,
|
1137 |
+
"special": true
|
1138 |
+
},
|
1139 |
+
"128142": {
|
1140 |
+
"content": "<|reserved_special_token_134|>",
|
1141 |
+
"lstrip": false,
|
1142 |
+
"normalized": false,
|
1143 |
+
"rstrip": false,
|
1144 |
+
"single_word": false,
|
1145 |
+
"special": true
|
1146 |
+
},
|
1147 |
+
"128143": {
|
1148 |
+
"content": "<|reserved_special_token_135|>",
|
1149 |
+
"lstrip": false,
|
1150 |
+
"normalized": false,
|
1151 |
+
"rstrip": false,
|
1152 |
+
"single_word": false,
|
1153 |
+
"special": true
|
1154 |
+
},
|
1155 |
+
"128144": {
|
1156 |
+
"content": "<|reserved_special_token_136|>",
|
1157 |
+
"lstrip": false,
|
1158 |
+
"normalized": false,
|
1159 |
+
"rstrip": false,
|
1160 |
+
"single_word": false,
|
1161 |
+
"special": true
|
1162 |
+
},
|
1163 |
+
"128145": {
|
1164 |
+
"content": "<|reserved_special_token_137|>",
|
1165 |
+
"lstrip": false,
|
1166 |
+
"normalized": false,
|
1167 |
+
"rstrip": false,
|
1168 |
+
"single_word": false,
|
1169 |
+
"special": true
|
1170 |
+
},
|
1171 |
+
"128146": {
|
1172 |
+
"content": "<|reserved_special_token_138|>",
|
1173 |
+
"lstrip": false,
|
1174 |
+
"normalized": false,
|
1175 |
+
"rstrip": false,
|
1176 |
+
"single_word": false,
|
1177 |
+
"special": true
|
1178 |
+
},
|
1179 |
+
"128147": {
|
1180 |
+
"content": "<|reserved_special_token_139|>",
|
1181 |
+
"lstrip": false,
|
1182 |
+
"normalized": false,
|
1183 |
+
"rstrip": false,
|
1184 |
+
"single_word": false,
|
1185 |
+
"special": true
|
1186 |
+
},
|
1187 |
+
"128148": {
|
1188 |
+
"content": "<|reserved_special_token_140|>",
|
1189 |
+
"lstrip": false,
|
1190 |
+
"normalized": false,
|
1191 |
+
"rstrip": false,
|
1192 |
+
"single_word": false,
|
1193 |
+
"special": true
|
1194 |
+
},
|
1195 |
+
"128149": {
|
1196 |
+
"content": "<|reserved_special_token_141|>",
|
1197 |
+
"lstrip": false,
|
1198 |
+
"normalized": false,
|
1199 |
+
"rstrip": false,
|
1200 |
+
"single_word": false,
|
1201 |
+
"special": true
|
1202 |
+
},
|
1203 |
+
"128150": {
|
1204 |
+
"content": "<|reserved_special_token_142|>",
|
1205 |
+
"lstrip": false,
|
1206 |
+
"normalized": false,
|
1207 |
+
"rstrip": false,
|
1208 |
+
"single_word": false,
|
1209 |
+
"special": true
|
1210 |
+
},
|
1211 |
+
"128151": {
|
1212 |
+
"content": "<|reserved_special_token_143|>",
|
1213 |
+
"lstrip": false,
|
1214 |
+
"normalized": false,
|
1215 |
+
"rstrip": false,
|
1216 |
+
"single_word": false,
|
1217 |
+
"special": true
|
1218 |
+
},
|
1219 |
+
"128152": {
|
1220 |
+
"content": "<|reserved_special_token_144|>",
|
1221 |
+
"lstrip": false,
|
1222 |
+
"normalized": false,
|
1223 |
+
"rstrip": false,
|
1224 |
+
"single_word": false,
|
1225 |
+
"special": true
|
1226 |
+
},
|
1227 |
+
"128153": {
|
1228 |
+
"content": "<|reserved_special_token_145|>",
|
1229 |
+
"lstrip": false,
|
1230 |
+
"normalized": false,
|
1231 |
+
"rstrip": false,
|
1232 |
+
"single_word": false,
|
1233 |
+
"special": true
|
1234 |
+
},
|
1235 |
+
"128154": {
|
1236 |
+
"content": "<|reserved_special_token_146|>",
|
1237 |
+
"lstrip": false,
|
1238 |
+
"normalized": false,
|
1239 |
+
"rstrip": false,
|
1240 |
+
"single_word": false,
|
1241 |
+
"special": true
|
1242 |
+
},
|
1243 |
+
"128155": {
|
1244 |
+
"content": "<|reserved_special_token_147|>",
|
1245 |
+
"lstrip": false,
|
1246 |
+
"normalized": false,
|
1247 |
+
"rstrip": false,
|
1248 |
+
"single_word": false,
|
1249 |
+
"special": true
|
1250 |
+
},
|
1251 |
+
"128156": {
|
1252 |
+
"content": "<|reserved_special_token_148|>",
|
1253 |
+
"lstrip": false,
|
1254 |
+
"normalized": false,
|
1255 |
+
"rstrip": false,
|
1256 |
+
"single_word": false,
|
1257 |
+
"special": true
|
1258 |
+
},
|
1259 |
+
"128157": {
|
1260 |
+
"content": "<|reserved_special_token_149|>",
|
1261 |
+
"lstrip": false,
|
1262 |
+
"normalized": false,
|
1263 |
+
"rstrip": false,
|
1264 |
+
"single_word": false,
|
1265 |
+
"special": true
|
1266 |
+
},
|
1267 |
+
"128158": {
|
1268 |
+
"content": "<|reserved_special_token_150|>",
|
1269 |
+
"lstrip": false,
|
1270 |
+
"normalized": false,
|
1271 |
+
"rstrip": false,
|
1272 |
+
"single_word": false,
|
1273 |
+
"special": true
|
1274 |
+
},
|
1275 |
+
"128159": {
|
1276 |
+
"content": "<|reserved_special_token_151|>",
|
1277 |
+
"lstrip": false,
|
1278 |
+
"normalized": false,
|
1279 |
+
"rstrip": false,
|
1280 |
+
"single_word": false,
|
1281 |
+
"special": true
|
1282 |
+
},
|
1283 |
+
"128160": {
|
1284 |
+
"content": "<|reserved_special_token_152|>",
|
1285 |
+
"lstrip": false,
|
1286 |
+
"normalized": false,
|
1287 |
+
"rstrip": false,
|
1288 |
+
"single_word": false,
|
1289 |
+
"special": true
|
1290 |
+
},
|
1291 |
+
"128161": {
|
1292 |
+
"content": "<|reserved_special_token_153|>",
|
1293 |
+
"lstrip": false,
|
1294 |
+
"normalized": false,
|
1295 |
+
"rstrip": false,
|
1296 |
+
"single_word": false,
|
1297 |
+
"special": true
|
1298 |
+
},
|
1299 |
+
"128162": {
|
1300 |
+
"content": "<|reserved_special_token_154|>",
|
1301 |
+
"lstrip": false,
|
1302 |
+
"normalized": false,
|
1303 |
+
"rstrip": false,
|
1304 |
+
"single_word": false,
|
1305 |
+
"special": true
|
1306 |
+
},
|
1307 |
+
"128163": {
|
1308 |
+
"content": "<|reserved_special_token_155|>",
|
1309 |
+
"lstrip": false,
|
1310 |
+
"normalized": false,
|
1311 |
+
"rstrip": false,
|
1312 |
+
"single_word": false,
|
1313 |
+
"special": true
|
1314 |
+
},
|
1315 |
+
"128164": {
|
1316 |
+
"content": "<|reserved_special_token_156|>",
|
1317 |
+
"lstrip": false,
|
1318 |
+
"normalized": false,
|
1319 |
+
"rstrip": false,
|
1320 |
+
"single_word": false,
|
1321 |
+
"special": true
|
1322 |
+
},
|
1323 |
+
"128165": {
|
1324 |
+
"content": "<|reserved_special_token_157|>",
|
1325 |
+
"lstrip": false,
|
1326 |
+
"normalized": false,
|
1327 |
+
"rstrip": false,
|
1328 |
+
"single_word": false,
|
1329 |
+
"special": true
|
1330 |
+
},
|
1331 |
+
"128166": {
|
1332 |
+
"content": "<|reserved_special_token_158|>",
|
1333 |
+
"lstrip": false,
|
1334 |
+
"normalized": false,
|
1335 |
+
"rstrip": false,
|
1336 |
+
"single_word": false,
|
1337 |
+
"special": true
|
1338 |
+
},
|
1339 |
+
"128167": {
|
1340 |
+
"content": "<|reserved_special_token_159|>",
|
1341 |
+
"lstrip": false,
|
1342 |
+
"normalized": false,
|
1343 |
+
"rstrip": false,
|
1344 |
+
"single_word": false,
|
1345 |
+
"special": true
|
1346 |
+
},
|
1347 |
+
"128168": {
|
1348 |
+
"content": "<|reserved_special_token_160|>",
|
1349 |
+
"lstrip": false,
|
1350 |
+
"normalized": false,
|
1351 |
+
"rstrip": false,
|
1352 |
+
"single_word": false,
|
1353 |
+
"special": true
|
1354 |
+
},
|
1355 |
+
"128169": {
|
1356 |
+
"content": "<|reserved_special_token_161|>",
|
1357 |
+
"lstrip": false,
|
1358 |
+
"normalized": false,
|
1359 |
+
"rstrip": false,
|
1360 |
+
"single_word": false,
|
1361 |
+
"special": true
|
1362 |
+
},
|
1363 |
+
"128170": {
|
1364 |
+
"content": "<|reserved_special_token_162|>",
|
1365 |
+
"lstrip": false,
|
1366 |
+
"normalized": false,
|
1367 |
+
"rstrip": false,
|
1368 |
+
"single_word": false,
|
1369 |
+
"special": true
|
1370 |
+
},
|
1371 |
+
"128171": {
|
1372 |
+
"content": "<|reserved_special_token_163|>",
|
1373 |
+
"lstrip": false,
|
1374 |
+
"normalized": false,
|
1375 |
+
"rstrip": false,
|
1376 |
+
"single_word": false,
|
1377 |
+
"special": true
|
1378 |
+
},
|
1379 |
+
"128172": {
|
1380 |
+
"content": "<|reserved_special_token_164|>",
|
1381 |
+
"lstrip": false,
|
1382 |
+
"normalized": false,
|
1383 |
+
"rstrip": false,
|
1384 |
+
"single_word": false,
|
1385 |
+
"special": true
|
1386 |
+
},
|
1387 |
+
"128173": {
|
1388 |
+
"content": "<|reserved_special_token_165|>",
|
1389 |
+
"lstrip": false,
|
1390 |
+
"normalized": false,
|
1391 |
+
"rstrip": false,
|
1392 |
+
"single_word": false,
|
1393 |
+
"special": true
|
1394 |
+
},
|
1395 |
+
"128174": {
|
1396 |
+
"content": "<|reserved_special_token_166|>",
|
1397 |
+
"lstrip": false,
|
1398 |
+
"normalized": false,
|
1399 |
+
"rstrip": false,
|
1400 |
+
"single_word": false,
|
1401 |
+
"special": true
|
1402 |
+
},
|
1403 |
+
"128175": {
|
1404 |
+
"content": "<|reserved_special_token_167|>",
|
1405 |
+
"lstrip": false,
|
1406 |
+
"normalized": false,
|
1407 |
+
"rstrip": false,
|
1408 |
+
"single_word": false,
|
1409 |
+
"special": true
|
1410 |
+
},
|
1411 |
+
"128176": {
|
1412 |
+
"content": "<|reserved_special_token_168|>",
|
1413 |
+
"lstrip": false,
|
1414 |
+
"normalized": false,
|
1415 |
+
"rstrip": false,
|
1416 |
+
"single_word": false,
|
1417 |
+
"special": true
|
1418 |
+
},
|
1419 |
+
"128177": {
|
1420 |
+
"content": "<|reserved_special_token_169|>",
|
1421 |
+
"lstrip": false,
|
1422 |
+
"normalized": false,
|
1423 |
+
"rstrip": false,
|
1424 |
+
"single_word": false,
|
1425 |
+
"special": true
|
1426 |
+
},
|
1427 |
+
"128178": {
|
1428 |
+
"content": "<|reserved_special_token_170|>",
|
1429 |
+
"lstrip": false,
|
1430 |
+
"normalized": false,
|
1431 |
+
"rstrip": false,
|
1432 |
+
"single_word": false,
|
1433 |
+
"special": true
|
1434 |
+
},
|
1435 |
+
"128179": {
|
1436 |
+
"content": "<|reserved_special_token_171|>",
|
1437 |
+
"lstrip": false,
|
1438 |
+
"normalized": false,
|
1439 |
+
"rstrip": false,
|
1440 |
+
"single_word": false,
|
1441 |
+
"special": true
|
1442 |
+
},
|
1443 |
+
"128180": {
|
1444 |
+
"content": "<|reserved_special_token_172|>",
|
1445 |
+
"lstrip": false,
|
1446 |
+
"normalized": false,
|
1447 |
+
"rstrip": false,
|
1448 |
+
"single_word": false,
|
1449 |
+
"special": true
|
1450 |
+
},
|
1451 |
+
"128181": {
|
1452 |
+
"content": "<|reserved_special_token_173|>",
|
1453 |
+
"lstrip": false,
|
1454 |
+
"normalized": false,
|
1455 |
+
"rstrip": false,
|
1456 |
+
"single_word": false,
|
1457 |
+
"special": true
|
1458 |
+
},
|
1459 |
+
"128182": {
|
1460 |
+
"content": "<|reserved_special_token_174|>",
|
1461 |
+
"lstrip": false,
|
1462 |
+
"normalized": false,
|
1463 |
+
"rstrip": false,
|
1464 |
+
"single_word": false,
|
1465 |
+
"special": true
|
1466 |
+
},
|
1467 |
+
"128183": {
|
1468 |
+
"content": "<|reserved_special_token_175|>",
|
1469 |
+
"lstrip": false,
|
1470 |
+
"normalized": false,
|
1471 |
+
"rstrip": false,
|
1472 |
+
"single_word": false,
|
1473 |
+
"special": true
|
1474 |
+
},
|
1475 |
+
"128184": {
|
1476 |
+
"content": "<|reserved_special_token_176|>",
|
1477 |
+
"lstrip": false,
|
1478 |
+
"normalized": false,
|
1479 |
+
"rstrip": false,
|
1480 |
+
"single_word": false,
|
1481 |
+
"special": true
|
1482 |
+
},
|
1483 |
+
"128185": {
|
1484 |
+
"content": "<|reserved_special_token_177|>",
|
1485 |
+
"lstrip": false,
|
1486 |
+
"normalized": false,
|
1487 |
+
"rstrip": false,
|
1488 |
+
"single_word": false,
|
1489 |
+
"special": true
|
1490 |
+
},
|
1491 |
+
"128186": {
|
1492 |
+
"content": "<|reserved_special_token_178|>",
|
1493 |
+
"lstrip": false,
|
1494 |
+
"normalized": false,
|
1495 |
+
"rstrip": false,
|
1496 |
+
"single_word": false,
|
1497 |
+
"special": true
|
1498 |
+
},
|
1499 |
+
"128187": {
|
1500 |
+
"content": "<|reserved_special_token_179|>",
|
1501 |
+
"lstrip": false,
|
1502 |
+
"normalized": false,
|
1503 |
+
"rstrip": false,
|
1504 |
+
"single_word": false,
|
1505 |
+
"special": true
|
1506 |
+
},
|
1507 |
+
"128188": {
|
1508 |
+
"content": "<|reserved_special_token_180|>",
|
1509 |
+
"lstrip": false,
|
1510 |
+
"normalized": false,
|
1511 |
+
"rstrip": false,
|
1512 |
+
"single_word": false,
|
1513 |
+
"special": true
|
1514 |
+
},
|
1515 |
+
"128189": {
|
1516 |
+
"content": "<|reserved_special_token_181|>",
|
1517 |
+
"lstrip": false,
|
1518 |
+
"normalized": false,
|
1519 |
+
"rstrip": false,
|
1520 |
+
"single_word": false,
|
1521 |
+
"special": true
|
1522 |
+
},
|
1523 |
+
"128190": {
|
1524 |
+
"content": "<|reserved_special_token_182|>",
|
1525 |
+
"lstrip": false,
|
1526 |
+
"normalized": false,
|
1527 |
+
"rstrip": false,
|
1528 |
+
"single_word": false,
|
1529 |
+
"special": true
|
1530 |
+
},
|
1531 |
+
"128191": {
|
1532 |
+
"content": "<|reserved_special_token_183|>",
|
1533 |
+
"lstrip": false,
|
1534 |
+
"normalized": false,
|
1535 |
+
"rstrip": false,
|
1536 |
+
"single_word": false,
|
1537 |
+
"special": true
|
1538 |
+
},
|
1539 |
+
"128192": {
|
1540 |
+
"content": "<|reserved_special_token_184|>",
|
1541 |
+
"lstrip": false,
|
1542 |
+
"normalized": false,
|
1543 |
+
"rstrip": false,
|
1544 |
+
"single_word": false,
|
1545 |
+
"special": true
|
1546 |
+
},
|
1547 |
+
"128193": {
|
1548 |
+
"content": "<|reserved_special_token_185|>",
|
1549 |
+
"lstrip": false,
|
1550 |
+
"normalized": false,
|
1551 |
+
"rstrip": false,
|
1552 |
+
"single_word": false,
|
1553 |
+
"special": true
|
1554 |
+
},
|
1555 |
+
"128194": {
|
1556 |
+
"content": "<|reserved_special_token_186|>",
|
1557 |
+
"lstrip": false,
|
1558 |
+
"normalized": false,
|
1559 |
+
"rstrip": false,
|
1560 |
+
"single_word": false,
|
1561 |
+
"special": true
|
1562 |
+
},
|
1563 |
+
"128195": {
|
1564 |
+
"content": "<|reserved_special_token_187|>",
|
1565 |
+
"lstrip": false,
|
1566 |
+
"normalized": false,
|
1567 |
+
"rstrip": false,
|
1568 |
+
"single_word": false,
|
1569 |
+
"special": true
|
1570 |
+
},
|
1571 |
+
"128196": {
|
1572 |
+
"content": "<|reserved_special_token_188|>",
|
1573 |
+
"lstrip": false,
|
1574 |
+
"normalized": false,
|
1575 |
+
"rstrip": false,
|
1576 |
+
"single_word": false,
|
1577 |
+
"special": true
|
1578 |
+
},
|
1579 |
+
"128197": {
|
1580 |
+
"content": "<|reserved_special_token_189|>",
|
1581 |
+
"lstrip": false,
|
1582 |
+
"normalized": false,
|
1583 |
+
"rstrip": false,
|
1584 |
+
"single_word": false,
|
1585 |
+
"special": true
|
1586 |
+
},
|
1587 |
+
"128198": {
|
1588 |
+
"content": "<|reserved_special_token_190|>",
|
1589 |
+
"lstrip": false,
|
1590 |
+
"normalized": false,
|
1591 |
+
"rstrip": false,
|
1592 |
+
"single_word": false,
|
1593 |
+
"special": true
|
1594 |
+
},
|
1595 |
+
"128199": {
|
1596 |
+
"content": "<|reserved_special_token_191|>",
|
1597 |
+
"lstrip": false,
|
1598 |
+
"normalized": false,
|
1599 |
+
"rstrip": false,
|
1600 |
+
"single_word": false,
|
1601 |
+
"special": true
|
1602 |
+
},
|
1603 |
+
"128200": {
|
1604 |
+
"content": "<|reserved_special_token_192|>",
|
1605 |
+
"lstrip": false,
|
1606 |
+
"normalized": false,
|
1607 |
+
"rstrip": false,
|
1608 |
+
"single_word": false,
|
1609 |
+
"special": true
|
1610 |
+
},
|
1611 |
+
"128201": {
|
1612 |
+
"content": "<|reserved_special_token_193|>",
|
1613 |
+
"lstrip": false,
|
1614 |
+
"normalized": false,
|
1615 |
+
"rstrip": false,
|
1616 |
+
"single_word": false,
|
1617 |
+
"special": true
|
1618 |
+
},
|
1619 |
+
"128202": {
|
1620 |
+
"content": "<|reserved_special_token_194|>",
|
1621 |
+
"lstrip": false,
|
1622 |
+
"normalized": false,
|
1623 |
+
"rstrip": false,
|
1624 |
+
"single_word": false,
|
1625 |
+
"special": true
|
1626 |
+
},
|
1627 |
+
"128203": {
|
1628 |
+
"content": "<|reserved_special_token_195|>",
|
1629 |
+
"lstrip": false,
|
1630 |
+
"normalized": false,
|
1631 |
+
"rstrip": false,
|
1632 |
+
"single_word": false,
|
1633 |
+
"special": true
|
1634 |
+
},
|
1635 |
+
"128204": {
|
1636 |
+
"content": "<|reserved_special_token_196|>",
|
1637 |
+
"lstrip": false,
|
1638 |
+
"normalized": false,
|
1639 |
+
"rstrip": false,
|
1640 |
+
"single_word": false,
|
1641 |
+
"special": true
|
1642 |
+
},
|
1643 |
+
"128205": {
|
1644 |
+
"content": "<|reserved_special_token_197|>",
|
1645 |
+
"lstrip": false,
|
1646 |
+
"normalized": false,
|
1647 |
+
"rstrip": false,
|
1648 |
+
"single_word": false,
|
1649 |
+
"special": true
|
1650 |
+
},
|
1651 |
+
"128206": {
|
1652 |
+
"content": "<|reserved_special_token_198|>",
|
1653 |
+
"lstrip": false,
|
1654 |
+
"normalized": false,
|
1655 |
+
"rstrip": false,
|
1656 |
+
"single_word": false,
|
1657 |
+
"special": true
|
1658 |
+
},
|
1659 |
+
"128207": {
|
1660 |
+
"content": "<|reserved_special_token_199|>",
|
1661 |
+
"lstrip": false,
|
1662 |
+
"normalized": false,
|
1663 |
+
"rstrip": false,
|
1664 |
+
"single_word": false,
|
1665 |
+
"special": true
|
1666 |
+
},
|
1667 |
+
"128208": {
|
1668 |
+
"content": "<|reserved_special_token_200|>",
|
1669 |
+
"lstrip": false,
|
1670 |
+
"normalized": false,
|
1671 |
+
"rstrip": false,
|
1672 |
+
"single_word": false,
|
1673 |
+
"special": true
|
1674 |
+
},
|
1675 |
+
"128209": {
|
1676 |
+
"content": "<|reserved_special_token_201|>",
|
1677 |
+
"lstrip": false,
|
1678 |
+
"normalized": false,
|
1679 |
+
"rstrip": false,
|
1680 |
+
"single_word": false,
|
1681 |
+
"special": true
|
1682 |
+
},
|
1683 |
+
"128210": {
|
1684 |
+
"content": "<|reserved_special_token_202|>",
|
1685 |
+
"lstrip": false,
|
1686 |
+
"normalized": false,
|
1687 |
+
"rstrip": false,
|
1688 |
+
"single_word": false,
|
1689 |
+
"special": true
|
1690 |
+
},
|
1691 |
+
"128211": {
|
1692 |
+
"content": "<|reserved_special_token_203|>",
|
1693 |
+
"lstrip": false,
|
1694 |
+
"normalized": false,
|
1695 |
+
"rstrip": false,
|
1696 |
+
"single_word": false,
|
1697 |
+
"special": true
|
1698 |
+
},
|
1699 |
+
"128212": {
|
1700 |
+
"content": "<|reserved_special_token_204|>",
|
1701 |
+
"lstrip": false,
|
1702 |
+
"normalized": false,
|
1703 |
+
"rstrip": false,
|
1704 |
+
"single_word": false,
|
1705 |
+
"special": true
|
1706 |
+
},
|
1707 |
+
"128213": {
|
1708 |
+
"content": "<|reserved_special_token_205|>",
|
1709 |
+
"lstrip": false,
|
1710 |
+
"normalized": false,
|
1711 |
+
"rstrip": false,
|
1712 |
+
"single_word": false,
|
1713 |
+
"special": true
|
1714 |
+
},
|
1715 |
+
"128214": {
|
1716 |
+
"content": "<|reserved_special_token_206|>",
|
1717 |
+
"lstrip": false,
|
1718 |
+
"normalized": false,
|
1719 |
+
"rstrip": false,
|
1720 |
+
"single_word": false,
|
1721 |
+
"special": true
|
1722 |
+
},
|
1723 |
+
"128215": {
|
1724 |
+
"content": "<|reserved_special_token_207|>",
|
1725 |
+
"lstrip": false,
|
1726 |
+
"normalized": false,
|
1727 |
+
"rstrip": false,
|
1728 |
+
"single_word": false,
|
1729 |
+
"special": true
|
1730 |
+
},
|
1731 |
+
"128216": {
|
1732 |
+
"content": "<|reserved_special_token_208|>",
|
1733 |
+
"lstrip": false,
|
1734 |
+
"normalized": false,
|
1735 |
+
"rstrip": false,
|
1736 |
+
"single_word": false,
|
1737 |
+
"special": true
|
1738 |
+
},
|
1739 |
+
"128217": {
|
1740 |
+
"content": "<|reserved_special_token_209|>",
|
1741 |
+
"lstrip": false,
|
1742 |
+
"normalized": false,
|
1743 |
+
"rstrip": false,
|
1744 |
+
"single_word": false,
|
1745 |
+
"special": true
|
1746 |
+
},
|
1747 |
+
"128218": {
|
1748 |
+
"content": "<|reserved_special_token_210|>",
|
1749 |
+
"lstrip": false,
|
1750 |
+
"normalized": false,
|
1751 |
+
"rstrip": false,
|
1752 |
+
"single_word": false,
|
1753 |
+
"special": true
|
1754 |
+
},
|
1755 |
+
"128219": {
|
1756 |
+
"content": "<|reserved_special_token_211|>",
|
1757 |
+
"lstrip": false,
|
1758 |
+
"normalized": false,
|
1759 |
+
"rstrip": false,
|
1760 |
+
"single_word": false,
|
1761 |
+
"special": true
|
1762 |
+
},
|
1763 |
+
"128220": {
|
1764 |
+
"content": "<|reserved_special_token_212|>",
|
1765 |
+
"lstrip": false,
|
1766 |
+
"normalized": false,
|
1767 |
+
"rstrip": false,
|
1768 |
+
"single_word": false,
|
1769 |
+
"special": true
|
1770 |
+
},
|
1771 |
+
"128221": {
|
1772 |
+
"content": "<|reserved_special_token_213|>",
|
1773 |
+
"lstrip": false,
|
1774 |
+
"normalized": false,
|
1775 |
+
"rstrip": false,
|
1776 |
+
"single_word": false,
|
1777 |
+
"special": true
|
1778 |
+
},
|
1779 |
+
"128222": {
|
1780 |
+
"content": "<|reserved_special_token_214|>",
|
1781 |
+
"lstrip": false,
|
1782 |
+
"normalized": false,
|
1783 |
+
"rstrip": false,
|
1784 |
+
"single_word": false,
|
1785 |
+
"special": true
|
1786 |
+
},
|
1787 |
+
"128223": {
|
1788 |
+
"content": "<|reserved_special_token_215|>",
|
1789 |
+
"lstrip": false,
|
1790 |
+
"normalized": false,
|
1791 |
+
"rstrip": false,
|
1792 |
+
"single_word": false,
|
1793 |
+
"special": true
|
1794 |
+
},
|
1795 |
+
"128224": {
|
1796 |
+
"content": "<|reserved_special_token_216|>",
|
1797 |
+
"lstrip": false,
|
1798 |
+
"normalized": false,
|
1799 |
+
"rstrip": false,
|
1800 |
+
"single_word": false,
|
1801 |
+
"special": true
|
1802 |
+
},
|
1803 |
+
"128225": {
|
1804 |
+
"content": "<|reserved_special_token_217|>",
|
1805 |
+
"lstrip": false,
|
1806 |
+
"normalized": false,
|
1807 |
+
"rstrip": false,
|
1808 |
+
"single_word": false,
|
1809 |
+
"special": true
|
1810 |
+
},
|
1811 |
+
"128226": {
|
1812 |
+
"content": "<|reserved_special_token_218|>",
|
1813 |
+
"lstrip": false,
|
1814 |
+
"normalized": false,
|
1815 |
+
"rstrip": false,
|
1816 |
+
"single_word": false,
|
1817 |
+
"special": true
|
1818 |
+
},
|
1819 |
+
"128227": {
|
1820 |
+
"content": "<|reserved_special_token_219|>",
|
1821 |
+
"lstrip": false,
|
1822 |
+
"normalized": false,
|
1823 |
+
"rstrip": false,
|
1824 |
+
"single_word": false,
|
1825 |
+
"special": true
|
1826 |
+
},
|
1827 |
+
"128228": {
|
1828 |
+
"content": "<|reserved_special_token_220|>",
|
1829 |
+
"lstrip": false,
|
1830 |
+
"normalized": false,
|
1831 |
+
"rstrip": false,
|
1832 |
+
"single_word": false,
|
1833 |
+
"special": true
|
1834 |
+
},
|
1835 |
+
"128229": {
|
1836 |
+
"content": "<|reserved_special_token_221|>",
|
1837 |
+
"lstrip": false,
|
1838 |
+
"normalized": false,
|
1839 |
+
"rstrip": false,
|
1840 |
+
"single_word": false,
|
1841 |
+
"special": true
|
1842 |
+
},
|
1843 |
+
"128230": {
|
1844 |
+
"content": "<|reserved_special_token_222|>",
|
1845 |
+
"lstrip": false,
|
1846 |
+
"normalized": false,
|
1847 |
+
"rstrip": false,
|
1848 |
+
"single_word": false,
|
1849 |
+
"special": true
|
1850 |
+
},
|
1851 |
+
"128231": {
|
1852 |
+
"content": "<|reserved_special_token_223|>",
|
1853 |
+
"lstrip": false,
|
1854 |
+
"normalized": false,
|
1855 |
+
"rstrip": false,
|
1856 |
+
"single_word": false,
|
1857 |
+
"special": true
|
1858 |
+
},
|
1859 |
+
"128232": {
|
1860 |
+
"content": "<|reserved_special_token_224|>",
|
1861 |
+
"lstrip": false,
|
1862 |
+
"normalized": false,
|
1863 |
+
"rstrip": false,
|
1864 |
+
"single_word": false,
|
1865 |
+
"special": true
|
1866 |
+
},
|
1867 |
+
"128233": {
|
1868 |
+
"content": "<|reserved_special_token_225|>",
|
1869 |
+
"lstrip": false,
|
1870 |
+
"normalized": false,
|
1871 |
+
"rstrip": false,
|
1872 |
+
"single_word": false,
|
1873 |
+
"special": true
|
1874 |
+
},
|
1875 |
+
"128234": {
|
1876 |
+
"content": "<|reserved_special_token_226|>",
|
1877 |
+
"lstrip": false,
|
1878 |
+
"normalized": false,
|
1879 |
+
"rstrip": false,
|
1880 |
+
"single_word": false,
|
1881 |
+
"special": true
|
1882 |
+
},
|
1883 |
+
"128235": {
|
1884 |
+
"content": "<|reserved_special_token_227|>",
|
1885 |
+
"lstrip": false,
|
1886 |
+
"normalized": false,
|
1887 |
+
"rstrip": false,
|
1888 |
+
"single_word": false,
|
1889 |
+
"special": true
|
1890 |
+
},
|
1891 |
+
"128236": {
|
1892 |
+
"content": "<|reserved_special_token_228|>",
|
1893 |
+
"lstrip": false,
|
1894 |
+
"normalized": false,
|
1895 |
+
"rstrip": false,
|
1896 |
+
"single_word": false,
|
1897 |
+
"special": true
|
1898 |
+
},
|
1899 |
+
"128237": {
|
1900 |
+
"content": "<|reserved_special_token_229|>",
|
1901 |
+
"lstrip": false,
|
1902 |
+
"normalized": false,
|
1903 |
+
"rstrip": false,
|
1904 |
+
"single_word": false,
|
1905 |
+
"special": true
|
1906 |
+
},
|
1907 |
+
"128238": {
|
1908 |
+
"content": "<|reserved_special_token_230|>",
|
1909 |
+
"lstrip": false,
|
1910 |
+
"normalized": false,
|
1911 |
+
"rstrip": false,
|
1912 |
+
"single_word": false,
|
1913 |
+
"special": true
|
1914 |
+
},
|
1915 |
+
"128239": {
|
1916 |
+
"content": "<|reserved_special_token_231|>",
|
1917 |
+
"lstrip": false,
|
1918 |
+
"normalized": false,
|
1919 |
+
"rstrip": false,
|
1920 |
+
"single_word": false,
|
1921 |
+
"special": true
|
1922 |
+
},
|
1923 |
+
"128240": {
|
1924 |
+
"content": "<|reserved_special_token_232|>",
|
1925 |
+
"lstrip": false,
|
1926 |
+
"normalized": false,
|
1927 |
+
"rstrip": false,
|
1928 |
+
"single_word": false,
|
1929 |
+
"special": true
|
1930 |
+
},
|
1931 |
+
"128241": {
|
1932 |
+
"content": "<|reserved_special_token_233|>",
|
1933 |
+
"lstrip": false,
|
1934 |
+
"normalized": false,
|
1935 |
+
"rstrip": false,
|
1936 |
+
"single_word": false,
|
1937 |
+
"special": true
|
1938 |
+
},
|
1939 |
+
"128242": {
|
1940 |
+
"content": "<|reserved_special_token_234|>",
|
1941 |
+
"lstrip": false,
|
1942 |
+
"normalized": false,
|
1943 |
+
"rstrip": false,
|
1944 |
+
"single_word": false,
|
1945 |
+
"special": true
|
1946 |
+
},
|
1947 |
+
"128243": {
|
1948 |
+
"content": "<|reserved_special_token_235|>",
|
1949 |
+
"lstrip": false,
|
1950 |
+
"normalized": false,
|
1951 |
+
"rstrip": false,
|
1952 |
+
"single_word": false,
|
1953 |
+
"special": true
|
1954 |
+
},
|
1955 |
+
"128244": {
|
1956 |
+
"content": "<|reserved_special_token_236|>",
|
1957 |
+
"lstrip": false,
|
1958 |
+
"normalized": false,
|
1959 |
+
"rstrip": false,
|
1960 |
+
"single_word": false,
|
1961 |
+
"special": true
|
1962 |
+
},
|
1963 |
+
"128245": {
|
1964 |
+
"content": "<|reserved_special_token_237|>",
|
1965 |
+
"lstrip": false,
|
1966 |
+
"normalized": false,
|
1967 |
+
"rstrip": false,
|
1968 |
+
"single_word": false,
|
1969 |
+
"special": true
|
1970 |
+
},
|
1971 |
+
"128246": {
|
1972 |
+
"content": "<|reserved_special_token_238|>",
|
1973 |
+
"lstrip": false,
|
1974 |
+
"normalized": false,
|
1975 |
+
"rstrip": false,
|
1976 |
+
"single_word": false,
|
1977 |
+
"special": true
|
1978 |
+
},
|
1979 |
+
"128247": {
|
1980 |
+
"content": "<|reserved_special_token_239|>",
|
1981 |
+
"lstrip": false,
|
1982 |
+
"normalized": false,
|
1983 |
+
"rstrip": false,
|
1984 |
+
"single_word": false,
|
1985 |
+
"special": true
|
1986 |
+
},
|
1987 |
+
"128248": {
|
1988 |
+
"content": "<|reserved_special_token_240|>",
|
1989 |
+
"lstrip": false,
|
1990 |
+
"normalized": false,
|
1991 |
+
"rstrip": false,
|
1992 |
+
"single_word": false,
|
1993 |
+
"special": true
|
1994 |
+
},
|
1995 |
+
"128249": {
|
1996 |
+
"content": "<|reserved_special_token_241|>",
|
1997 |
+
"lstrip": false,
|
1998 |
+
"normalized": false,
|
1999 |
+
"rstrip": false,
|
2000 |
+
"single_word": false,
|
2001 |
+
"special": true
|
2002 |
+
},
|
2003 |
+
"128250": {
|
2004 |
+
"content": "<|reserved_special_token_242|>",
|
2005 |
+
"lstrip": false,
|
2006 |
+
"normalized": false,
|
2007 |
+
"rstrip": false,
|
2008 |
+
"single_word": false,
|
2009 |
+
"special": true
|
2010 |
+
},
|
2011 |
+
"128251": {
|
2012 |
+
"content": "<|reserved_special_token_243|>",
|
2013 |
+
"lstrip": false,
|
2014 |
+
"normalized": false,
|
2015 |
+
"rstrip": false,
|
2016 |
+
"single_word": false,
|
2017 |
+
"special": true
|
2018 |
+
},
|
2019 |
+
"128252": {
|
2020 |
+
"content": "<|reserved_special_token_244|>",
|
2021 |
+
"lstrip": false,
|
2022 |
+
"normalized": false,
|
2023 |
+
"rstrip": false,
|
2024 |
+
"single_word": false,
|
2025 |
+
"special": true
|
2026 |
+
},
|
2027 |
+
"128253": {
|
2028 |
+
"content": "<|reserved_special_token_245|>",
|
2029 |
+
"lstrip": false,
|
2030 |
+
"normalized": false,
|
2031 |
+
"rstrip": false,
|
2032 |
+
"single_word": false,
|
2033 |
+
"special": true
|
2034 |
+
},
|
2035 |
+
"128254": {
|
2036 |
+
"content": "<|reserved_special_token_246|>",
|
2037 |
+
"lstrip": false,
|
2038 |
+
"normalized": false,
|
2039 |
+
"rstrip": false,
|
2040 |
+
"single_word": false,
|
2041 |
+
"special": true
|
2042 |
+
},
|
2043 |
+
"128255": {
|
2044 |
+
"content": "<|reserved_special_token_247|>",
|
2045 |
+
"lstrip": false,
|
2046 |
+
"normalized": false,
|
2047 |
+
"rstrip": false,
|
2048 |
+
"single_word": false,
|
2049 |
+
"special": true
|
2050 |
+
}
|
2051 |
+
},
|
2052 |
+
"bos_token": "<|begin_of_text|>",
|
2053 |
+
"chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n",
|
2054 |
+
"clean_up_tokenization_spaces": true,
|
2055 |
+
"eos_token": "<|eot_id|>",
|
2056 |
+
"model_input_names": [
|
2057 |
+
"input_ids",
|
2058 |
+
"attention_mask"
|
2059 |
+
],
|
2060 |
+
"model_max_length": 131072,
|
2061 |
+
"tokenizer_class": "PreTrainedTokenizerFast"
|
2062 |
+
}
|
transformers_4_44_2__activations.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
from collections import OrderedDict
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from packaging import version
|
20 |
+
from torch import Tensor, nn
|
21 |
+
|
22 |
+
from transformers.utils import logging
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
class PytorchGELUTanh(nn.Module):
|
29 |
+
"""
|
30 |
+
A fast C implementation of the tanh approximation of the GeLU activation function. See
|
31 |
+
https://arxiv.org/abs/1606.08415.
|
32 |
+
|
33 |
+
This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
|
34 |
+
match due to rounding errors.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self):
|
38 |
+
super().__init__()
|
39 |
+
if version.parse(torch.__version__) < version.parse("1.12.0"):
|
40 |
+
raise ImportError(
|
41 |
+
f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
|
42 |
+
"PytorchGELUTanh. Please upgrade torch."
|
43 |
+
)
|
44 |
+
|
45 |
+
def forward(self, input: Tensor) -> Tensor:
|
46 |
+
return nn.functional.gelu(input, approximate="tanh")
|
47 |
+
|
48 |
+
|
49 |
+
class NewGELUActivation(nn.Module):
|
50 |
+
"""
|
51 |
+
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
52 |
+
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
53 |
+
"""
|
54 |
+
|
55 |
+
def forward(self, input: Tensor) -> Tensor:
|
56 |
+
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
|
57 |
+
|
58 |
+
|
59 |
+
class GELUActivation(nn.Module):
|
60 |
+
"""
|
61 |
+
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
|
62 |
+
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
63 |
+
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
|
64 |
+
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, use_gelu_python: bool = False):
|
68 |
+
super().__init__()
|
69 |
+
if use_gelu_python:
|
70 |
+
self.act = self._gelu_python
|
71 |
+
else:
|
72 |
+
self.act = nn.functional.gelu
|
73 |
+
|
74 |
+
def _gelu_python(self, input: Tensor) -> Tensor:
|
75 |
+
return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
|
76 |
+
|
77 |
+
def forward(self, input: Tensor) -> Tensor:
|
78 |
+
return self.act(input)
|
79 |
+
|
80 |
+
|
81 |
+
class FastGELUActivation(nn.Module):
|
82 |
+
"""
|
83 |
+
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
|
84 |
+
"""
|
85 |
+
|
86 |
+
def forward(self, input: Tensor) -> Tensor:
|
87 |
+
return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
|
88 |
+
|
89 |
+
|
90 |
+
class QuickGELUActivation(nn.Module):
|
91 |
+
"""
|
92 |
+
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
|
93 |
+
"""
|
94 |
+
|
95 |
+
def forward(self, input: Tensor) -> Tensor:
|
96 |
+
return input * torch.sigmoid(1.702 * input)
|
97 |
+
|
98 |
+
|
99 |
+
class ClippedGELUActivation(nn.Module):
|
100 |
+
"""
|
101 |
+
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
|
102 |
+
it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
|
103 |
+
https://arxiv.org/abs/2004.09602.
|
104 |
+
|
105 |
+
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
|
106 |
+
initially created.
|
107 |
+
|
108 |
+
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
109 |
+
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(self, min: float, max: float):
|
113 |
+
if min > max:
|
114 |
+
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
|
115 |
+
|
116 |
+
super().__init__()
|
117 |
+
self.min = min
|
118 |
+
self.max = max
|
119 |
+
|
120 |
+
def forward(self, x: Tensor) -> Tensor:
|
121 |
+
return torch.clip(gelu(x), self.min, self.max)
|
122 |
+
|
123 |
+
|
124 |
+
class AccurateGELUActivation(nn.Module):
|
125 |
+
"""
|
126 |
+
Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
|
127 |
+
https://github.com/hendrycks/GELUs
|
128 |
+
|
129 |
+
Implemented along with MEGA (Moving Average Equipped Gated Attention)
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self):
|
133 |
+
super().__init__()
|
134 |
+
self.precomputed_constant = math.sqrt(2 / math.pi)
|
135 |
+
|
136 |
+
def forward(self, input: Tensor) -> Tensor:
|
137 |
+
return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
|
138 |
+
|
139 |
+
|
140 |
+
class MishActivation(nn.Module):
|
141 |
+
"""
|
142 |
+
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
|
143 |
+
visit the official repository for the paper: https://github.com/digantamisra98/Mish
|
144 |
+
"""
|
145 |
+
|
146 |
+
def __init__(self):
|
147 |
+
super().__init__()
|
148 |
+
if version.parse(torch.__version__) < version.parse("1.9.0"):
|
149 |
+
self.act = self._mish_python
|
150 |
+
else:
|
151 |
+
self.act = nn.functional.mish
|
152 |
+
|
153 |
+
def _mish_python(self, input: Tensor) -> Tensor:
|
154 |
+
return input * torch.tanh(nn.functional.softplus(input))
|
155 |
+
|
156 |
+
def forward(self, input: Tensor) -> Tensor:
|
157 |
+
return self.act(input)
|
158 |
+
|
159 |
+
|
160 |
+
class LinearActivation(nn.Module):
|
161 |
+
"""
|
162 |
+
Applies the linear activation function, i.e. forwarding input directly to output.
|
163 |
+
"""
|
164 |
+
|
165 |
+
def forward(self, input: Tensor) -> Tensor:
|
166 |
+
return input
|
167 |
+
|
168 |
+
|
169 |
+
class LaplaceActivation(nn.Module):
|
170 |
+
"""
|
171 |
+
Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
|
172 |
+
https://arxiv.org/abs/2209.10655
|
173 |
+
|
174 |
+
Inspired by squared relu, but with bounded range and gradient for better stability
|
175 |
+
"""
|
176 |
+
|
177 |
+
def forward(self, input, mu=0.707107, sigma=0.282095):
|
178 |
+
input = (input - mu).div(sigma * math.sqrt(2.0))
|
179 |
+
return 0.5 * (1.0 + torch.erf(input))
|
180 |
+
|
181 |
+
|
182 |
+
class ReLUSquaredActivation(nn.Module):
|
183 |
+
"""
|
184 |
+
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
|
185 |
+
"""
|
186 |
+
|
187 |
+
def forward(self, input):
|
188 |
+
relu_applied = nn.functional.relu(input)
|
189 |
+
squared = torch.square(relu_applied)
|
190 |
+
return squared
|
191 |
+
|
192 |
+
|
193 |
+
class ClassInstantier(OrderedDict):
|
194 |
+
def __getitem__(self, key):
|
195 |
+
content = super().__getitem__(key)
|
196 |
+
cls, kwargs = content if isinstance(content, tuple) else (content, {})
|
197 |
+
return cls(**kwargs)
|
198 |
+
|
199 |
+
|
200 |
+
ACT2CLS = {
|
201 |
+
"gelu": GELUActivation,
|
202 |
+
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
|
203 |
+
"gelu_fast": FastGELUActivation,
|
204 |
+
"gelu_new": NewGELUActivation,
|
205 |
+
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
|
206 |
+
"gelu_pytorch_tanh": PytorchGELUTanh,
|
207 |
+
"gelu_accurate": AccurateGELUActivation,
|
208 |
+
"laplace": LaplaceActivation,
|
209 |
+
"leaky_relu": nn.LeakyReLU,
|
210 |
+
"linear": LinearActivation,
|
211 |
+
"mish": MishActivation,
|
212 |
+
"quick_gelu": QuickGELUActivation,
|
213 |
+
"relu": nn.ReLU,
|
214 |
+
"relu2": ReLUSquaredActivation,
|
215 |
+
"relu6": nn.ReLU6,
|
216 |
+
"sigmoid": nn.Sigmoid,
|
217 |
+
"silu": nn.SiLU,
|
218 |
+
"swish": nn.SiLU,
|
219 |
+
"tanh": nn.Tanh,
|
220 |
+
}
|
221 |
+
ACT2FN = ClassInstantier(ACT2CLS)
|
222 |
+
|
223 |
+
|
224 |
+
def get_activation(activation_string):
|
225 |
+
if activation_string in ACT2FN:
|
226 |
+
return ACT2FN[activation_string]
|
227 |
+
else:
|
228 |
+
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
|
229 |
+
|
230 |
+
|
231 |
+
# For backwards compatibility with: from activations import gelu_python
|
232 |
+
gelu_python = get_activation("gelu_python")
|
233 |
+
gelu_new = get_activation("gelu_new")
|
234 |
+
gelu = get_activation("gelu")
|
235 |
+
gelu_fast = get_activation("gelu_fast")
|
236 |
+
quick_gelu = get_activation("quick_gelu")
|
237 |
+
silu = get_activation("silu")
|
238 |
+
mish = get_activation("mish")
|
239 |
+
linear_act = get_activation("linear")
|
transformers_4_44_2__cache_utils.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import importlib.metadata
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from packaging import version
|
10 |
+
|
11 |
+
from transformers.configuration_utils import PretrainedConfig
|
12 |
+
from transformers.utils import is_torchdynamo_compiling, logging
|
13 |
+
|
14 |
+
|
15 |
+
logger = logging.get_logger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class Cache(torch.nn.Module):
|
19 |
+
"""
|
20 |
+
Base, abstract class for all caches. The actual data structure is specific to each subclass.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
def update(
|
27 |
+
self,
|
28 |
+
key_states: torch.Tensor,
|
29 |
+
value_states: torch.Tensor,
|
30 |
+
layer_idx: int,
|
31 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
32 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
33 |
+
"""
|
34 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
35 |
+
|
36 |
+
Parameters:
|
37 |
+
key_states (`torch.Tensor`):
|
38 |
+
The new key states to cache.
|
39 |
+
value_states (`torch.Tensor`):
|
40 |
+
The new value states to cache.
|
41 |
+
layer_idx (`int`):
|
42 |
+
The index of the layer to cache the states for.
|
43 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
44 |
+
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
|
45 |
+
cache to be created.
|
46 |
+
|
47 |
+
Return:
|
48 |
+
A tuple containing the updated key and value states.
|
49 |
+
"""
|
50 |
+
raise NotImplementedError("Make sure to implement `update` in a subclass.")
|
51 |
+
|
52 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
53 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
54 |
+
# TODO: deprecate this function in favor of `cache_position`
|
55 |
+
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
|
56 |
+
|
57 |
+
def get_max_length(self) -> Optional[int]:
|
58 |
+
"""Returns the maximum sequence length of the cached states, if there is any."""
|
59 |
+
raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
|
60 |
+
|
61 |
+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
62 |
+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
63 |
+
# Cache without size limit -> all cache is usable
|
64 |
+
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
65 |
+
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
66 |
+
max_length = self.get_max_length()
|
67 |
+
previous_seq_length = self.get_seq_length(layer_idx)
|
68 |
+
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
69 |
+
return max_length - new_seq_length
|
70 |
+
return previous_seq_length
|
71 |
+
|
72 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
73 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
74 |
+
for layer_idx in range(len(self.key_cache)):
|
75 |
+
device = self.key_cache[layer_idx].device
|
76 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
77 |
+
device = self.value_cache[layer_idx].device
|
78 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
79 |
+
|
80 |
+
@property
|
81 |
+
def seen_tokens(self):
|
82 |
+
logger.warning_once(
|
83 |
+
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
|
84 |
+
"model input instead."
|
85 |
+
)
|
86 |
+
if hasattr(self, "_seen_tokens"):
|
87 |
+
return self._seen_tokens
|
88 |
+
else:
|
89 |
+
return None
|
90 |
+
|
91 |
+
|
92 |
+
@dataclass
|
93 |
+
class CacheConfig:
|
94 |
+
"""
|
95 |
+
Base class for cache configs
|
96 |
+
"""
|
97 |
+
|
98 |
+
cache_implementation: None
|
99 |
+
|
100 |
+
@classmethod
|
101 |
+
def from_dict(cls, config_dict, **kwargs):
|
102 |
+
"""
|
103 |
+
Constructs a CacheConfig instance from a dictionary of parameters.
|
104 |
+
Args:
|
105 |
+
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
|
106 |
+
**kwargs: Additional keyword arguments to override dictionary values.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
CacheConfig: Instance of CacheConfig constructed from the dictionary.
|
110 |
+
"""
|
111 |
+
config = cls(**config_dict)
|
112 |
+
to_remove = []
|
113 |
+
for key, value in kwargs.items():
|
114 |
+
if hasattr(config, key):
|
115 |
+
setattr(config, key, value)
|
116 |
+
to_remove.append(key)
|
117 |
+
for key in to_remove:
|
118 |
+
kwargs.pop(key, None)
|
119 |
+
return config
|
120 |
+
|
121 |
+
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
|
122 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
123 |
+
"""
|
124 |
+
Save this instance to a JSON file.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
json_file_path (`str` or `os.PathLike`):
|
128 |
+
Path to the JSON file in which this configuration instance's parameters will be saved.
|
129 |
+
use_diff (`bool`, *optional*, defaults to `True`):
|
130 |
+
If set to `True`, only the difference between the config instance and the default
|
131 |
+
`QuantizationConfig()` is serialized to JSON file.
|
132 |
+
"""
|
133 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
134 |
+
config_dict = self.to_dict()
|
135 |
+
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
136 |
+
|
137 |
+
writer.write(json_string)
|
138 |
+
|
139 |
+
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
|
140 |
+
def to_dict(self) -> Dict[str, Any]:
|
141 |
+
"""
|
142 |
+
Serializes this instance to a Python dictionary. Returns:
|
143 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
144 |
+
"""
|
145 |
+
return copy.deepcopy(self.__dict__)
|
146 |
+
|
147 |
+
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
|
148 |
+
def __iter__(self):
|
149 |
+
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
|
150 |
+
for attr, value in copy.deepcopy(self.__dict__).items():
|
151 |
+
yield attr, value
|
152 |
+
|
153 |
+
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
|
154 |
+
def __repr__(self):
|
155 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
156 |
+
|
157 |
+
def to_json_string(self):
|
158 |
+
"""
|
159 |
+
Serializes this instance to a JSON formatted string.
|
160 |
+
Returns:
|
161 |
+
str: JSON formatted string representing the configuration instance.
|
162 |
+
"""
|
163 |
+
return json.dumps(self.__dict__, indent=2) + "\n"
|
164 |
+
|
165 |
+
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
|
166 |
+
def update(self, **kwargs):
|
167 |
+
"""
|
168 |
+
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
|
169 |
+
returning all the unused kwargs.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
kwargs (`Dict[str, Any]`):
|
173 |
+
Dictionary of attributes to tentatively update this class.
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
|
177 |
+
"""
|
178 |
+
to_remove = []
|
179 |
+
for key, value in kwargs.items():
|
180 |
+
if hasattr(self, key):
|
181 |
+
setattr(self, key, value)
|
182 |
+
to_remove.append(key)
|
183 |
+
|
184 |
+
# Remove all the attributes that were updated, without modifying the input dict
|
185 |
+
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
|
186 |
+
return unused_kwargs
|
187 |
+
|
188 |
+
|
189 |
+
class StaticCache(Cache):
|
190 |
+
"""
|
191 |
+
Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
|
192 |
+
|
193 |
+
Parameters:
|
194 |
+
config (`PretrainedConfig`):
|
195 |
+
The configuration file defining the shape-related attributes required to initialize the static cache.
|
196 |
+
max_batch_size (`int`):
|
197 |
+
The maximum batch size with which the model will be used.
|
198 |
+
max_cache_len (`int`):
|
199 |
+
The maximum sequence length with which the model will be used.
|
200 |
+
device (`torch.device`):
|
201 |
+
The device on which the cache should be initialized. Should be the same as the layer.
|
202 |
+
dtype (*optional*, defaults to `torch.float32`):
|
203 |
+
The default `dtype` to use when initializing the layer.
|
204 |
+
|
205 |
+
Example:
|
206 |
+
|
207 |
+
```python
|
208 |
+
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
|
209 |
+
|
210 |
+
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
211 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
212 |
+
|
213 |
+
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
|
214 |
+
|
215 |
+
>>> # Prepare a cache class and pass it to model's forward
|
216 |
+
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
217 |
+
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
218 |
+
>>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
219 |
+
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
220 |
+
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
|
221 |
+
```
|
222 |
+
"""
|
223 |
+
|
224 |
+
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
|
225 |
+
super().__init__()
|
226 |
+
self.max_batch_size = max_batch_size
|
227 |
+
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
228 |
+
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
229 |
+
self.head_dim = (
|
230 |
+
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
231 |
+
)
|
232 |
+
|
233 |
+
self.dtype = dtype if dtype is not None else torch.float32
|
234 |
+
self.num_key_value_heads = (
|
235 |
+
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
236 |
+
)
|
237 |
+
|
238 |
+
self.key_cache: List[torch.Tensor] = []
|
239 |
+
self.value_cache: List[torch.Tensor] = []
|
240 |
+
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
|
241 |
+
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
242 |
+
for idx in range(config.num_hidden_layers):
|
243 |
+
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
244 |
+
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
245 |
+
# Notes:
|
246 |
+
# 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
247 |
+
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
|
248 |
+
# it is not needed anyway)
|
249 |
+
# 2. `torch.export()` requires mutations to be registered as buffers.
|
250 |
+
if not is_torchdynamo_compiling():
|
251 |
+
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
|
252 |
+
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
|
253 |
+
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
|
254 |
+
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
|
255 |
+
torch._dynamo.mark_static_address(new_layer_key_cache)
|
256 |
+
torch._dynamo.mark_static_address(new_layer_value_cache)
|
257 |
+
self.key_cache.append(new_layer_key_cache)
|
258 |
+
self.value_cache.append(new_layer_value_cache)
|
259 |
+
|
260 |
+
def update(
|
261 |
+
self,
|
262 |
+
key_states: torch.Tensor,
|
263 |
+
value_states: torch.Tensor,
|
264 |
+
layer_idx: int,
|
265 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
266 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
267 |
+
"""
|
268 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
269 |
+
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
|
270 |
+
|
271 |
+
Parameters:
|
272 |
+
key_states (`torch.Tensor`):
|
273 |
+
The new key states to cache.
|
274 |
+
value_states (`torch.Tensor`):
|
275 |
+
The new value states to cache.
|
276 |
+
layer_idx (`int`):
|
277 |
+
The index of the layer to cache the states for.
|
278 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
279 |
+
Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
|
280 |
+
to know how where to write in the cache.
|
281 |
+
|
282 |
+
Return:
|
283 |
+
A tuple containing the updated key and value states.
|
284 |
+
"""
|
285 |
+
cache_position = cache_kwargs.get("cache_position")
|
286 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
|
287 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
|
288 |
+
k_out = self.key_cache[layer_idx]
|
289 |
+
v_out = self.value_cache[layer_idx]
|
290 |
+
|
291 |
+
if cache_position is None:
|
292 |
+
k_out.copy_(key_states)
|
293 |
+
v_out.copy_(value_states)
|
294 |
+
else:
|
295 |
+
# Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
|
296 |
+
# `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
|
297 |
+
# operation, that avoids copies and uses less memory.
|
298 |
+
try:
|
299 |
+
k_out.index_copy_(2, cache_position, key_states)
|
300 |
+
v_out.index_copy_(2, cache_position, value_states)
|
301 |
+
except NotImplementedError:
|
302 |
+
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
|
303 |
+
k_out[:, :, cache_position] = key_states
|
304 |
+
v_out[:, :, cache_position] = value_states
|
305 |
+
|
306 |
+
return k_out, v_out
|
307 |
+
|
308 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
309 |
+
"""Returns the sequence length of the cached states that were seen by the model."""
|
310 |
+
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
|
311 |
+
# limit the check to the first batch member and head dimension.
|
312 |
+
# TODO: deprecate this function in favor of `cache_position`
|
313 |
+
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
|
314 |
+
|
315 |
+
def get_max_length(self) -> Optional[int]:
|
316 |
+
"""Returns the maximum sequence length of the cached states."""
|
317 |
+
return self.max_cache_len
|
318 |
+
|
319 |
+
def reset(self):
|
320 |
+
"""Resets the cache values while preserving the objects"""
|
321 |
+
for layer_idx in range(len(self.key_cache)):
|
322 |
+
# In-place ops prevent breaking the static address
|
323 |
+
self.key_cache[layer_idx].zero_()
|
324 |
+
self.value_cache[layer_idx].zero_()
|
325 |
+
|
transformers_4_44_2__configuration_llama.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""LLaMA model configuration"""
|
21 |
+
|
22 |
+
from transformers.configuration_utils import PretrainedConfig
|
23 |
+
from .transformers_4_44_2__modeling_rope_utils import rope_config_validation
|
24 |
+
|
25 |
+
|
26 |
+
class LlamaConfig(PretrainedConfig):
|
27 |
+
r"""
|
28 |
+
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
|
29 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
30 |
+
defaults will yield a similar configuration to that of the LLaMA-7B.
|
31 |
+
|
32 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
33 |
+
documentation from [`PretrainedConfig`] for more information.
|
34 |
+
|
35 |
+
|
36 |
+
Args:
|
37 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
38 |
+
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
|
39 |
+
`inputs_ids` passed when calling [`LlamaModel`]
|
40 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
41 |
+
Dimension of the hidden representations.
|
42 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
43 |
+
Dimension of the MLP representations.
|
44 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
45 |
+
Number of hidden layers in the Transformer decoder.
|
46 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
47 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
48 |
+
num_key_value_heads (`int`, *optional*):
|
49 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
50 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
51 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
52 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
53 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
54 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
55 |
+
`num_attention_heads`.
|
56 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
57 |
+
The non-linear activation function (function or string) in the decoder.
|
58 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
59 |
+
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
|
60 |
+
Llama 2 up to 4096, CodeLlama up to 16384.
|
61 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
62 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
63 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
64 |
+
The epsilon used by the rms normalization layers.
|
65 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
66 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
67 |
+
relevant if `config.is_decoder=True`.
|
68 |
+
pad_token_id (`int`, *optional*):
|
69 |
+
Padding token id.
|
70 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
71 |
+
Beginning of stream token id.
|
72 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
73 |
+
End of stream token id.
|
74 |
+
pretraining_tp (`int`, *optional*, defaults to 1):
|
75 |
+
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
76 |
+
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
|
77 |
+
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
|
78 |
+
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
|
79 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
80 |
+
Whether to tie weight embeddings
|
81 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
82 |
+
The base period of the RoPE embeddings.
|
83 |
+
rope_scaling (`Dict`, *optional*):
|
84 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
85 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
86 |
+
accordingly.
|
87 |
+
Expected contents:
|
88 |
+
`rope_type` (`str`):
|
89 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
90 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
91 |
+
`factor` (`float`, *optional*):
|
92 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
93 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
94 |
+
original maximum pre-trained length.
|
95 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
96 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
97 |
+
pretraining.
|
98 |
+
`attention_factor` (`float`, *optional*):
|
99 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
100 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
101 |
+
`factor` field to infer the suggested value.
|
102 |
+
`beta_fast` (`float`, *optional*):
|
103 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
104 |
+
ramp function. If unspecified, it defaults to 32.
|
105 |
+
`beta_slow` (`float`, *optional*):
|
106 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
107 |
+
ramp function. If unspecified, it defaults to 1.
|
108 |
+
`short_factor` (`List[float]`, *optional*):
|
109 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
110 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
111 |
+
size divided by the number of attention heads divided by 2
|
112 |
+
`long_factor` (`List[float]`, *optional*):
|
113 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
114 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
115 |
+
size divided by the number of attention heads divided by 2
|
116 |
+
`low_freq_factor` (`float`, *optional*):
|
117 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
118 |
+
`high_freq_factor` (`float`, *optional*):
|
119 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
120 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
121 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
122 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
123 |
+
The dropout ratio for the attention probabilities.
|
124 |
+
mlp_bias (`bool`, *optional*, defaults to `False`):
|
125 |
+
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
126 |
+
|
127 |
+
```python
|
128 |
+
>>> from transformers import LlamaModel, LlamaConfig
|
129 |
+
|
130 |
+
>>> # Initializing a LLaMA llama-7b style configuration
|
131 |
+
>>> configuration = LlamaConfig()
|
132 |
+
|
133 |
+
>>> # Initializing a model from the llama-7b style configuration
|
134 |
+
>>> model = LlamaModel(configuration)
|
135 |
+
|
136 |
+
>>> # Accessing the model configuration
|
137 |
+
>>> configuration = model.config
|
138 |
+
```"""
|
139 |
+
|
140 |
+
model_type = "llama"
|
141 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
142 |
+
|
143 |
+
def __init__(
|
144 |
+
self,
|
145 |
+
vocab_size=32000,
|
146 |
+
hidden_size=4096,
|
147 |
+
intermediate_size=11008,
|
148 |
+
num_hidden_layers=32,
|
149 |
+
num_attention_heads=32,
|
150 |
+
num_key_value_heads=None,
|
151 |
+
hidden_act="silu",
|
152 |
+
max_position_embeddings=2048,
|
153 |
+
initializer_range=0.02,
|
154 |
+
rms_norm_eps=1e-6,
|
155 |
+
use_cache=True,
|
156 |
+
pad_token_id=None,
|
157 |
+
bos_token_id=1,
|
158 |
+
eos_token_id=2,
|
159 |
+
pretraining_tp=1,
|
160 |
+
tie_word_embeddings=False,
|
161 |
+
rope_theta=10000.0,
|
162 |
+
rope_scaling=None,
|
163 |
+
attention_bias=False,
|
164 |
+
attention_dropout=0.0,
|
165 |
+
mlp_bias=False,
|
166 |
+
**kwargs,
|
167 |
+
):
|
168 |
+
self.vocab_size = vocab_size
|
169 |
+
self.max_position_embeddings = max_position_embeddings
|
170 |
+
self.hidden_size = hidden_size
|
171 |
+
self.intermediate_size = intermediate_size
|
172 |
+
self.num_hidden_layers = num_hidden_layers
|
173 |
+
self.num_attention_heads = num_attention_heads
|
174 |
+
|
175 |
+
# for backward compatibility
|
176 |
+
if num_key_value_heads is None:
|
177 |
+
num_key_value_heads = num_attention_heads
|
178 |
+
|
179 |
+
self.num_key_value_heads = num_key_value_heads
|
180 |
+
self.hidden_act = hidden_act
|
181 |
+
self.initializer_range = initializer_range
|
182 |
+
self.rms_norm_eps = rms_norm_eps
|
183 |
+
self.pretraining_tp = pretraining_tp
|
184 |
+
self.use_cache = use_cache
|
185 |
+
self.rope_theta = rope_theta
|
186 |
+
self.rope_scaling = rope_scaling
|
187 |
+
self.attention_bias = attention_bias
|
188 |
+
self.attention_dropout = attention_dropout
|
189 |
+
self.mlp_bias = mlp_bias
|
190 |
+
|
191 |
+
# Validate the correctness of rotary position embeddings parameters
|
192 |
+
# BC: if there is a 'type' field, move it to 'rope_type'.
|
193 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
194 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
195 |
+
rope_config_validation(self)
|
196 |
+
|
197 |
+
super().__init__(
|
198 |
+
pad_token_id=pad_token_id,
|
199 |
+
bos_token_id=bos_token_id,
|
200 |
+
eos_token_id=eos_token_id,
|
201 |
+
tie_word_embeddings=tie_word_embeddings,
|
202 |
+
**kwargs,
|
203 |
+
)
|
transformers_4_44_2__modeling_attn_mask_utils.py
ADDED
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class AttentionMaskConverter:
|
22 |
+
"""
|
23 |
+
A utility attention mask class that allows one to:
|
24 |
+
- Create a causal 4d mask
|
25 |
+
- Create a causal 4d mask with slided window
|
26 |
+
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
|
27 |
+
key_value_length) that can be multiplied with attention scores
|
28 |
+
|
29 |
+
Examples:
|
30 |
+
|
31 |
+
```python
|
32 |
+
>>> import torch
|
33 |
+
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
34 |
+
|
35 |
+
>>> converter = AttentionMaskConverter(True)
|
36 |
+
>>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
|
37 |
+
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
38 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
39 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
40 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
|
41 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
|
42 |
+
```
|
43 |
+
|
44 |
+
Parameters:
|
45 |
+
is_causal (`bool`):
|
46 |
+
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
|
47 |
+
|
48 |
+
sliding_window (`int`, *optional*):
|
49 |
+
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
|
50 |
+
"""
|
51 |
+
|
52 |
+
is_causal: bool
|
53 |
+
sliding_window: int
|
54 |
+
|
55 |
+
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
|
56 |
+
self.is_causal = is_causal
|
57 |
+
self.sliding_window = sliding_window
|
58 |
+
|
59 |
+
if self.sliding_window is not None and self.sliding_window <= 0:
|
60 |
+
raise ValueError(
|
61 |
+
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
|
62 |
+
)
|
63 |
+
|
64 |
+
def to_causal_4d(
|
65 |
+
self,
|
66 |
+
batch_size: int,
|
67 |
+
query_length: int,
|
68 |
+
key_value_length: int,
|
69 |
+
dtype: torch.dtype,
|
70 |
+
device: Union[torch.device, "str"] = "cpu",
|
71 |
+
) -> Optional[torch.Tensor]:
|
72 |
+
"""
|
73 |
+
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
|
74 |
+
bias to upper right hand triangular matrix (causal mask).
|
75 |
+
"""
|
76 |
+
if not self.is_causal:
|
77 |
+
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
|
78 |
+
|
79 |
+
# If shape is not cached, create a new causal mask and cache it
|
80 |
+
input_shape = (batch_size, query_length)
|
81 |
+
past_key_values_length = key_value_length - query_length
|
82 |
+
|
83 |
+
# create causal mask
|
84 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
85 |
+
causal_4d_mask = None
|
86 |
+
if input_shape[-1] > 1 or self.sliding_window is not None:
|
87 |
+
causal_4d_mask = self._make_causal_mask(
|
88 |
+
input_shape,
|
89 |
+
dtype,
|
90 |
+
device=device,
|
91 |
+
past_key_values_length=past_key_values_length,
|
92 |
+
sliding_window=self.sliding_window,
|
93 |
+
)
|
94 |
+
|
95 |
+
return causal_4d_mask
|
96 |
+
|
97 |
+
def to_4d(
|
98 |
+
self,
|
99 |
+
attention_mask_2d: torch.Tensor,
|
100 |
+
query_length: int,
|
101 |
+
dtype: torch.dtype,
|
102 |
+
key_value_length: Optional[int] = None,
|
103 |
+
) -> torch.Tensor:
|
104 |
+
"""
|
105 |
+
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
|
106 |
+
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
|
107 |
+
causal, a causal mask will be added.
|
108 |
+
"""
|
109 |
+
input_shape = (attention_mask_2d.shape[0], query_length)
|
110 |
+
|
111 |
+
# create causal mask
|
112 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
113 |
+
causal_4d_mask = None
|
114 |
+
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
|
115 |
+
if key_value_length is None:
|
116 |
+
raise ValueError(
|
117 |
+
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
|
118 |
+
)
|
119 |
+
|
120 |
+
past_key_values_length = key_value_length - query_length
|
121 |
+
causal_4d_mask = self._make_causal_mask(
|
122 |
+
input_shape,
|
123 |
+
dtype,
|
124 |
+
device=attention_mask_2d.device,
|
125 |
+
past_key_values_length=past_key_values_length,
|
126 |
+
sliding_window=self.sliding_window,
|
127 |
+
)
|
128 |
+
elif self.sliding_window is not None:
|
129 |
+
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
|
130 |
+
|
131 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
132 |
+
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
|
133 |
+
attention_mask_2d.device
|
134 |
+
)
|
135 |
+
|
136 |
+
if causal_4d_mask is not None:
|
137 |
+
expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
|
138 |
+
|
139 |
+
# expanded_attn_mask + causal_4d_mask can cause some overflow
|
140 |
+
expanded_4d_mask = expanded_attn_mask
|
141 |
+
|
142 |
+
return expanded_4d_mask
|
143 |
+
|
144 |
+
@staticmethod
|
145 |
+
def _make_causal_mask(
|
146 |
+
input_ids_shape: torch.Size,
|
147 |
+
dtype: torch.dtype,
|
148 |
+
device: torch.device,
|
149 |
+
past_key_values_length: int = 0,
|
150 |
+
sliding_window: Optional[int] = None,
|
151 |
+
):
|
152 |
+
"""
|
153 |
+
Make causal mask used for bi-directional self-attention.
|
154 |
+
"""
|
155 |
+
bsz, tgt_len = input_ids_shape
|
156 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
157 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
158 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
159 |
+
|
160 |
+
mask = mask.to(dtype)
|
161 |
+
|
162 |
+
if past_key_values_length > 0:
|
163 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
164 |
+
|
165 |
+
# add lower triangular sliding window mask if necessary
|
166 |
+
if sliding_window is not None:
|
167 |
+
diagonal = past_key_values_length - sliding_window - 1
|
168 |
+
|
169 |
+
context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
|
170 |
+
mask.masked_fill_(context_mask, torch.finfo(dtype).min)
|
171 |
+
|
172 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
173 |
+
|
174 |
+
@staticmethod
|
175 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
176 |
+
"""
|
177 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
178 |
+
"""
|
179 |
+
bsz, src_len = mask.size()
|
180 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
181 |
+
|
182 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
183 |
+
|
184 |
+
inverted_mask = 1.0 - expanded_mask
|
185 |
+
|
186 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
187 |
+
|
188 |
+
@staticmethod
|
189 |
+
def _unmask_unattended(
|
190 |
+
expanded_mask: torch.FloatTensor,
|
191 |
+
min_dtype: float,
|
192 |
+
):
|
193 |
+
# fmt: off
|
194 |
+
"""
|
195 |
+
Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
|
196 |
+
using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
197 |
+
Details: https://github.com/pytorch/pytorch/issues/110213
|
198 |
+
|
199 |
+
`expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
|
200 |
+
`attention_mask` is [bsz, src_seq_len].
|
201 |
+
|
202 |
+
The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
|
203 |
+
|
204 |
+
For example, if `expanded_mask` is (e.g. here left-padding case)
|
205 |
+
```
|
206 |
+
[[[[0, 0, 0],
|
207 |
+
[0, 0, 0],
|
208 |
+
[0, 0, 1]]],
|
209 |
+
[[[1, 0, 0],
|
210 |
+
[1, 1, 0],
|
211 |
+
[1, 1, 1]]],
|
212 |
+
[[[0, 0, 0],
|
213 |
+
[0, 1, 0],
|
214 |
+
[0, 1, 1]]]]
|
215 |
+
```
|
216 |
+
then the modified `expanded_mask` will be
|
217 |
+
```
|
218 |
+
[[[[1, 1, 1], <-- modified
|
219 |
+
[1, 1, 1], <-- modified
|
220 |
+
[0, 0, 1]]],
|
221 |
+
[[[1, 0, 0],
|
222 |
+
[1, 1, 0],
|
223 |
+
[1, 1, 1]]],
|
224 |
+
[[[1, 1, 1], <-- modified
|
225 |
+
[0, 1, 0],
|
226 |
+
[0, 1, 1]]]]
|
227 |
+
```
|
228 |
+
"""
|
229 |
+
# fmt: on
|
230 |
+
if expanded_mask.dtype == torch.bool:
|
231 |
+
raise ValueError(
|
232 |
+
"AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
|
233 |
+
)
|
234 |
+
|
235 |
+
return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
|
236 |
+
|
237 |
+
@staticmethod
|
238 |
+
def _ignore_causal_mask_sdpa(
|
239 |
+
attention_mask: Optional[torch.Tensor],
|
240 |
+
inputs_embeds: torch.Tensor,
|
241 |
+
past_key_values_length: int,
|
242 |
+
sliding_window: Optional[int] = None,
|
243 |
+
is_training: bool = False,
|
244 |
+
) -> bool:
|
245 |
+
"""
|
246 |
+
Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
|
247 |
+
|
248 |
+
In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
|
249 |
+
`key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
|
250 |
+
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
|
251 |
+
"""
|
252 |
+
|
253 |
+
_, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
|
254 |
+
key_value_length = query_length + past_key_values_length
|
255 |
+
|
256 |
+
is_tracing = (
|
257 |
+
torch.jit.is_tracing()
|
258 |
+
or isinstance(inputs_embeds, torch.fx.Proxy)
|
259 |
+
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
260 |
+
)
|
261 |
+
|
262 |
+
ignore_causal_mask = False
|
263 |
+
|
264 |
+
if attention_mask is None:
|
265 |
+
# TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
|
266 |
+
# or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
|
267 |
+
# Thus, we only set `ignore_causal_mask = True` if the model is set to training.
|
268 |
+
#
|
269 |
+
# Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
|
270 |
+
if (
|
271 |
+
(is_training or not is_tracing)
|
272 |
+
and (query_length == 1 or key_value_length == query_length)
|
273 |
+
and (sliding_window is None or key_value_length < sliding_window)
|
274 |
+
):
|
275 |
+
ignore_causal_mask = True
|
276 |
+
elif sliding_window is None or key_value_length < sliding_window:
|
277 |
+
if len(attention_mask.shape) == 4:
|
278 |
+
return False
|
279 |
+
elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
|
280 |
+
if query_length == 1 or key_value_length == query_length:
|
281 |
+
# For query_length == 1, causal attention and bi-directional attention are the same.
|
282 |
+
ignore_causal_mask = True
|
283 |
+
|
284 |
+
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
|
285 |
+
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
|
286 |
+
# Reference: https://github.com/pytorch/pytorch/issues/108108
|
287 |
+
# TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
|
288 |
+
|
289 |
+
return ignore_causal_mask
|
290 |
+
|
291 |
+
|
292 |
+
def _prepare_4d_causal_attention_mask(
|
293 |
+
attention_mask: Optional[torch.Tensor],
|
294 |
+
input_shape: Union[torch.Size, Tuple, List],
|
295 |
+
inputs_embeds: torch.Tensor,
|
296 |
+
past_key_values_length: int,
|
297 |
+
sliding_window: Optional[int] = None,
|
298 |
+
):
|
299 |
+
"""
|
300 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
301 |
+
`(batch_size, key_value_length)`
|
302 |
+
|
303 |
+
Args:
|
304 |
+
attention_mask (`torch.Tensor` or `None`):
|
305 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
306 |
+
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
|
307 |
+
The input shape should be a tuple that defines `(batch_size, query_length)`.
|
308 |
+
inputs_embeds (`torch.Tensor`):
|
309 |
+
The embedded inputs as a torch Tensor.
|
310 |
+
past_key_values_length (`int`):
|
311 |
+
The length of the key value cache.
|
312 |
+
sliding_window (`int`, *optional*):
|
313 |
+
If the model uses windowed attention, a sliding window should be passed.
|
314 |
+
"""
|
315 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
316 |
+
|
317 |
+
key_value_length = input_shape[-1] + past_key_values_length
|
318 |
+
|
319 |
+
# 4d mask is passed through the layers
|
320 |
+
if attention_mask is not None and len(attention_mask.shape) == 2:
|
321 |
+
attention_mask = attn_mask_converter.to_4d(
|
322 |
+
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
|
323 |
+
)
|
324 |
+
elif attention_mask is not None and len(attention_mask.shape) == 4:
|
325 |
+
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
326 |
+
if tuple(attention_mask.shape) != expected_shape:
|
327 |
+
raise ValueError(
|
328 |
+
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
|
329 |
+
)
|
330 |
+
else:
|
331 |
+
# if the 4D mask has correct shape - invert it and fill with negative infinity
|
332 |
+
inverted_mask = 1.0 - attention_mask
|
333 |
+
attention_mask = inverted_mask.masked_fill(
|
334 |
+
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
335 |
+
)
|
336 |
+
else:
|
337 |
+
attention_mask = attn_mask_converter.to_causal_4d(
|
338 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
339 |
+
)
|
340 |
+
|
341 |
+
return attention_mask
|
342 |
+
|
343 |
+
|
344 |
+
# Adapted from _prepare_4d_causal_attention_mask
|
345 |
+
def _prepare_4d_causal_attention_mask_for_sdpa(
|
346 |
+
attention_mask: Optional[torch.Tensor],
|
347 |
+
input_shape: Union[torch.Size, Tuple, List],
|
348 |
+
inputs_embeds: torch.Tensor,
|
349 |
+
past_key_values_length: int,
|
350 |
+
sliding_window: Optional[int] = None,
|
351 |
+
):
|
352 |
+
"""
|
353 |
+
Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
|
354 |
+
|
355 |
+
In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
|
356 |
+
`key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
|
357 |
+
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
|
358 |
+
"""
|
359 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
360 |
+
|
361 |
+
key_value_length = input_shape[-1] + past_key_values_length
|
362 |
+
|
363 |
+
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
|
364 |
+
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
|
365 |
+
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
|
366 |
+
is_tracing = (
|
367 |
+
torch.jit.is_tracing()
|
368 |
+
or isinstance(inputs_embeds, torch.fx.Proxy)
|
369 |
+
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
370 |
+
)
|
371 |
+
|
372 |
+
ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
|
373 |
+
attention_mask=attention_mask,
|
374 |
+
inputs_embeds=inputs_embeds,
|
375 |
+
past_key_values_length=past_key_values_length,
|
376 |
+
sliding_window=sliding_window,
|
377 |
+
)
|
378 |
+
|
379 |
+
if ignore_causal_mask:
|
380 |
+
expanded_4d_mask = None
|
381 |
+
elif attention_mask is None:
|
382 |
+
expanded_4d_mask = attn_mask_converter.to_causal_4d(
|
383 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
384 |
+
)
|
385 |
+
else:
|
386 |
+
if attention_mask.dim() == 4:
|
387 |
+
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
|
388 |
+
if attention_mask.max() != 0:
|
389 |
+
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
|
390 |
+
expanded_4d_mask = attention_mask
|
391 |
+
else:
|
392 |
+
expanded_4d_mask = attn_mask_converter.to_4d(
|
393 |
+
attention_mask,
|
394 |
+
input_shape[-1],
|
395 |
+
dtype=inputs_embeds.dtype,
|
396 |
+
key_value_length=key_value_length,
|
397 |
+
)
|
398 |
+
|
399 |
+
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
|
400 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
401 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
402 |
+
if not is_tracing and expanded_4d_mask.device.type == "cuda":
|
403 |
+
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
|
404 |
+
expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
|
405 |
+
)
|
406 |
+
|
407 |
+
return expanded_4d_mask
|
408 |
+
|
409 |
+
|
410 |
+
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
411 |
+
"""
|
412 |
+
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
413 |
+
`(batch_size, key_value_length)`
|
414 |
+
|
415 |
+
Args:
|
416 |
+
mask (`torch.Tensor`):
|
417 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
418 |
+
dtype (`torch.dtype`):
|
419 |
+
The torch dtype the created mask shall have.
|
420 |
+
tgt_len (`int`):
|
421 |
+
The target length or query length the created mask shall have.
|
422 |
+
"""
|
423 |
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
424 |
+
|
425 |
+
|
426 |
+
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
427 |
+
"""
|
428 |
+
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
429 |
+
`(batch_size, key_value_length)`
|
430 |
+
|
431 |
+
Args:
|
432 |
+
mask (`torch.Tensor`):
|
433 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
434 |
+
dtype (`torch.dtype`):
|
435 |
+
The torch dtype the created mask shall have.
|
436 |
+
tgt_len (`int`):
|
437 |
+
The target length or query length the created mask shall have.
|
438 |
+
"""
|
439 |
+
_, key_value_length = mask.shape
|
440 |
+
tgt_len = tgt_len if tgt_len is not None else key_value_length
|
441 |
+
|
442 |
+
is_tracing = (
|
443 |
+
torch.jit.is_tracing()
|
444 |
+
or isinstance(mask, torch.fx.Proxy)
|
445 |
+
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
446 |
+
)
|
447 |
+
|
448 |
+
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
|
449 |
+
if not is_tracing and torch.all(mask == 1):
|
450 |
+
return None
|
451 |
+
else:
|
452 |
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
453 |
+
|
454 |
+
|
455 |
+
def _create_4d_causal_attention_mask(
|
456 |
+
input_shape: Union[torch.Size, Tuple, List],
|
457 |
+
dtype: torch.dtype,
|
458 |
+
device: torch.device,
|
459 |
+
past_key_values_length: int = 0,
|
460 |
+
sliding_window: Optional[int] = None,
|
461 |
+
) -> Optional[torch.Tensor]:
|
462 |
+
"""
|
463 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
|
464 |
+
|
465 |
+
Args:
|
466 |
+
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
|
467 |
+
The input shape should be a tuple that defines `(batch_size, query_length)`.
|
468 |
+
dtype (`torch.dtype`):
|
469 |
+
The torch dtype the created mask shall have.
|
470 |
+
device (`int`):
|
471 |
+
The torch device the created mask shall have.
|
472 |
+
sliding_window (`int`, *optional*):
|
473 |
+
If the model uses windowed attention, a sliding window should be passed.
|
474 |
+
"""
|
475 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
476 |
+
|
477 |
+
key_value_length = past_key_values_length + input_shape[-1]
|
478 |
+
attention_mask = attn_mask_converter.to_causal_4d(
|
479 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
|
480 |
+
)
|
481 |
+
|
482 |
+
return attention_mask
|
transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import inspect
|
17 |
+
import os
|
18 |
+
from typing import Optional, Tuple, Union
|
19 |
+
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from functools import lru_cache
|
25 |
+
import importlib.metadata
|
26 |
+
import importlib.util
|
27 |
+
from packaging import version
|
28 |
+
|
29 |
+
from transformers.utils import is_flash_attn_2_available
|
30 |
+
|
31 |
+
|
32 |
+
if is_flash_attn_2_available():
|
33 |
+
try:
|
34 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
35 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
36 |
+
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
37 |
+
except ImportError:
|
38 |
+
raise "Unable to import flash_attn"
|
39 |
+
|
40 |
+
|
41 |
+
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
|
42 |
+
# Check if the package spec exists and grab its version to avoid importing a local directory
|
43 |
+
package_exists = importlib.util.find_spec(pkg_name) is not None
|
44 |
+
package_version = "N/A"
|
45 |
+
if package_exists:
|
46 |
+
try:
|
47 |
+
# Primary method to get the package version
|
48 |
+
package_version = importlib.metadata.version(pkg_name)
|
49 |
+
except importlib.metadata.PackageNotFoundError:
|
50 |
+
# Fallback method: Only for "torch" and versions containing "dev"
|
51 |
+
if pkg_name == "torch":
|
52 |
+
try:
|
53 |
+
package = importlib.import_module(pkg_name)
|
54 |
+
temp_version = getattr(package, "__version__", "N/A")
|
55 |
+
# Check if the version contains "dev"
|
56 |
+
if "dev" in temp_version:
|
57 |
+
package_version = temp_version
|
58 |
+
package_exists = True
|
59 |
+
else:
|
60 |
+
package_exists = False
|
61 |
+
except ImportError:
|
62 |
+
# If the package can't be imported, it's not available
|
63 |
+
package_exists = False
|
64 |
+
else:
|
65 |
+
# For packages other than "torch", don't attempt the fallback and set as not available
|
66 |
+
package_exists = False
|
67 |
+
if return_version:
|
68 |
+
return package_exists, package_version
|
69 |
+
else:
|
70 |
+
return package_exists
|
71 |
+
|
72 |
+
|
73 |
+
@lru_cache()
|
74 |
+
def is_flash_attn_greater_or_equal(library_version: str):
|
75 |
+
if not _is_package_available("flash_attn"):
|
76 |
+
return False
|
77 |
+
|
78 |
+
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
|
79 |
+
|
80 |
+
|
81 |
+
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
82 |
+
"""
|
83 |
+
Retrieves indexing data required to repad unpadded (ragged) tensors.
|
84 |
+
|
85 |
+
Arguments:
|
86 |
+
attention_mask (`torch.Tensor`):
|
87 |
+
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
88 |
+
|
89 |
+
Return:
|
90 |
+
indices (`torch.Tensor`):
|
91 |
+
The indices of non-masked tokens from the flattened input sequence.
|
92 |
+
cu_seqlens (`torch.Tensor`):
|
93 |
+
The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
94 |
+
max_seqlen_in_batch (`int`):
|
95 |
+
Maximum sequence length in batch.
|
96 |
+
"""
|
97 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
98 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
99 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
100 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
101 |
+
return (
|
102 |
+
indices,
|
103 |
+
cu_seqlens,
|
104 |
+
max_seqlen_in_batch,
|
105 |
+
)
|
106 |
+
|
107 |
+
|
108 |
+
def _upad_input(
|
109 |
+
query_layer: torch.Tensor,
|
110 |
+
key_layer: torch.Tensor,
|
111 |
+
value_layer: torch.Tensor,
|
112 |
+
attention_mask: torch.Tensor,
|
113 |
+
query_length: int,
|
114 |
+
):
|
115 |
+
"""
|
116 |
+
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
|
117 |
+
|
118 |
+
This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
|
119 |
+
tensors for query, key, value tensors.
|
120 |
+
|
121 |
+
Arguments:
|
122 |
+
query_layer (`torch.Tensor`):
|
123 |
+
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
124 |
+
key_layer (`torch.Tensor`):
|
125 |
+
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
126 |
+
value_layer (`torch.Tensor`):
|
127 |
+
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
128 |
+
attention_mask (`torch.Tensor`):
|
129 |
+
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
130 |
+
query_length (`int`):
|
131 |
+
Target length.
|
132 |
+
|
133 |
+
Return:
|
134 |
+
query_layer (`torch.Tensor`):
|
135 |
+
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
136 |
+
key_layer (`torch.Tensor`):
|
137 |
+
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
138 |
+
value_layer (`torch.Tensor`):
|
139 |
+
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
140 |
+
indices_q (`torch.Tensor`):
|
141 |
+
The indices of non-masked tokens from the flattened input target sequence.
|
142 |
+
(cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
|
143 |
+
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
144 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
|
145 |
+
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
146 |
+
"""
|
147 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
148 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
149 |
+
|
150 |
+
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
|
151 |
+
value_layer = index_first_axis(
|
152 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
153 |
+
)
|
154 |
+
if query_length == kv_seq_len:
|
155 |
+
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
|
156 |
+
cu_seqlens_q = cu_seqlens_k
|
157 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
158 |
+
indices_q = indices_k
|
159 |
+
elif query_length == 1:
|
160 |
+
max_seqlen_in_batch_q = 1
|
161 |
+
cu_seqlens_q = torch.arange(
|
162 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
163 |
+
) # There is a memcpy here, that is very bad.
|
164 |
+
indices_q = cu_seqlens_q[:-1]
|
165 |
+
query_layer = query_layer.squeeze(1)
|
166 |
+
else:
|
167 |
+
# The -q_len: slice assumes left padding.
|
168 |
+
attention_mask = attention_mask[:, -query_length:]
|
169 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
170 |
+
|
171 |
+
return (
|
172 |
+
query_layer,
|
173 |
+
key_layer,
|
174 |
+
value_layer,
|
175 |
+
indices_q,
|
176 |
+
(cu_seqlens_q, cu_seqlens_k),
|
177 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
178 |
+
)
|
179 |
+
|
180 |
+
|
181 |
+
def prepare_fa2_from_position_ids(query, key, value, position_ids):
|
182 |
+
"""
|
183 |
+
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
184 |
+
All three query, key, value states will be flattened.
|
185 |
+
Cummulative lengths of each examples in the batch will be extracted from position_ids.
|
186 |
+
|
187 |
+
NOTE: ideally cummulative lengths should be prepared at the data collator stage
|
188 |
+
|
189 |
+
Arguments:
|
190 |
+
query (`torch.Tensor`):
|
191 |
+
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
192 |
+
key (`torch.Tensor`):
|
193 |
+
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
194 |
+
value (`torch.Tensor`):
|
195 |
+
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
196 |
+
position_ids (`torch.Tensor`):
|
197 |
+
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
198 |
+
|
199 |
+
Return:
|
200 |
+
query (`torch.Tensor`):
|
201 |
+
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
202 |
+
key (`torch.Tensor`):
|
203 |
+
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
204 |
+
value (`torch.Tensor`):
|
205 |
+
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
206 |
+
indices_q (`torch.Tensor`):
|
207 |
+
The indices of non-masked tokens from the flattened input target sequence.
|
208 |
+
(cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
|
209 |
+
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
210 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
|
211 |
+
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
212 |
+
"""
|
213 |
+
query = query.view(-1, query.size(-2), query.size(-1))
|
214 |
+
key = key.view(-1, key.size(-2), key.size(-1))
|
215 |
+
value = value.view(-1, value.size(-2), value.size(-1))
|
216 |
+
position_ids = position_ids.flatten()
|
217 |
+
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
218 |
+
|
219 |
+
cu_seq_lens = torch.cat(
|
220 |
+
(
|
221 |
+
indices_q[position_ids == 0],
|
222 |
+
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
223 |
+
)
|
224 |
+
)
|
225 |
+
|
226 |
+
max_length = position_ids.max() + 1
|
227 |
+
|
228 |
+
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
|
229 |
+
|
230 |
+
|
231 |
+
def _flash_attention_forward(
|
232 |
+
query_states: torch.Tensor,
|
233 |
+
key_states: torch.Tensor,
|
234 |
+
value_states: torch.Tensor,
|
235 |
+
attention_mask: torch.Tensor,
|
236 |
+
query_length: int,
|
237 |
+
is_causal: bool,
|
238 |
+
dropout: float = 0.0,
|
239 |
+
position_ids: Optional[torch.Tensor] = None,
|
240 |
+
softmax_scale: Optional[float] = None,
|
241 |
+
sliding_window: Optional[int] = None,
|
242 |
+
use_top_left_mask: bool = False,
|
243 |
+
softcap: Optional[float] = None,
|
244 |
+
deterministic: bool = None,
|
245 |
+
):
|
246 |
+
"""
|
247 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
248 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
249 |
+
|
250 |
+
Args:
|
251 |
+
query_states (`torch.Tensor`):
|
252 |
+
Input query states to be passed to Flash Attention API
|
253 |
+
key_states (`torch.Tensor`):
|
254 |
+
Input key states to be passed to Flash Attention API
|
255 |
+
value_states (`torch.Tensor`):
|
256 |
+
Input value states to be passed to Flash Attention API
|
257 |
+
attention_mask (`torch.Tensor`):
|
258 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
259 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
260 |
+
dropout (`float`):
|
261 |
+
Attention dropout
|
262 |
+
softmax_scale (`float`, *optional*):
|
263 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
264 |
+
use_top_left_mask (`bool`, defaults to `False`):
|
265 |
+
flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
|
266 |
+
softcap (`float`, *optional*):
|
267 |
+
Softcap for the attention logits, used e.g. in gemma2.
|
268 |
+
deterministic (`bool`, *optional*):
|
269 |
+
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
|
270 |
+
"""
|
271 |
+
if not use_top_left_mask:
|
272 |
+
causal = is_causal
|
273 |
+
else:
|
274 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
|
275 |
+
causal = is_causal and query_length != 1
|
276 |
+
|
277 |
+
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
|
278 |
+
use_sliding_windows = (
|
279 |
+
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
|
280 |
+
)
|
281 |
+
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
|
282 |
+
|
283 |
+
if is_flash_attn_greater_or_equal("2.4.1"):
|
284 |
+
if deterministic is None:
|
285 |
+
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
286 |
+
flash_kwargs["deterministic"] = deterministic
|
287 |
+
|
288 |
+
if softcap is not None:
|
289 |
+
flash_kwargs["softcap"] = softcap
|
290 |
+
|
291 |
+
# Contains at least one padding token in the sequence
|
292 |
+
if attention_mask is not None:
|
293 |
+
batch_size = query_states.shape[0]
|
294 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
|
295 |
+
query_states, key_states, value_states, attention_mask, query_length
|
296 |
+
)
|
297 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
298 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
299 |
+
|
300 |
+
attn_output_unpad = flash_attn_varlen_func(
|
301 |
+
query_states,
|
302 |
+
key_states,
|
303 |
+
value_states,
|
304 |
+
cu_seqlens_q=cu_seqlens_q,
|
305 |
+
cu_seqlens_k=cu_seqlens_k,
|
306 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
307 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
308 |
+
dropout_p=dropout,
|
309 |
+
softmax_scale=softmax_scale,
|
310 |
+
causal=causal,
|
311 |
+
**flash_kwargs,
|
312 |
+
)
|
313 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
314 |
+
|
315 |
+
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
316 |
+
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
317 |
+
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
318 |
+
elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1:
|
319 |
+
batch_size = query_states.size(0)
|
320 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
321 |
+
query_states, key_states, value_states, position_ids
|
322 |
+
)
|
323 |
+
|
324 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
325 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
326 |
+
|
327 |
+
attn_output = flash_attn_varlen_func(
|
328 |
+
query_states,
|
329 |
+
key_states,
|
330 |
+
value_states,
|
331 |
+
cu_seqlens_q=cu_seqlens_q,
|
332 |
+
cu_seqlens_k=cu_seqlens_k,
|
333 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
334 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
335 |
+
dropout_p=dropout,
|
336 |
+
softmax_scale=softmax_scale,
|
337 |
+
causal=causal,
|
338 |
+
**flash_kwargs,
|
339 |
+
)
|
340 |
+
|
341 |
+
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
342 |
+
|
343 |
+
else:
|
344 |
+
attn_output = flash_attn_func(
|
345 |
+
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
|
346 |
+
)
|
347 |
+
|
348 |
+
return attn_output
|
transformers_4_44_2__modeling_outputs.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
transformers_4_44_2__modeling_rope_utils.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
from typing import Optional, Tuple
|
17 |
+
|
18 |
+
from transformers.configuration_utils import PretrainedConfig
|
19 |
+
from transformers.utils import is_torch_available, logging
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
if is_torch_available():
|
26 |
+
import torch
|
27 |
+
|
28 |
+
|
29 |
+
def _compute_default_rope_parameters(
|
30 |
+
config: Optional[PretrainedConfig] = None,
|
31 |
+
device: Optional["torch.device"] = None,
|
32 |
+
seq_len: Optional[int] = None,
|
33 |
+
**rope_kwargs,
|
34 |
+
) -> Tuple["torch.Tensor", float]:
|
35 |
+
"""
|
36 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
37 |
+
Args:
|
38 |
+
config ([`~transformers.PretrainedConfig`]):
|
39 |
+
The model configuration.
|
40 |
+
device (`torch.device`):
|
41 |
+
The device to use for initialization of the inverse frequencies.
|
42 |
+
seq_len (`int`, *optional*):
|
43 |
+
The current sequence length. Unused for this type of RoPE.
|
44 |
+
rope_kwargs (`Dict`, *optional*):
|
45 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
46 |
+
Returns:
|
47 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
48 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
49 |
+
"""
|
50 |
+
if config is not None and len(rope_kwargs) > 0:
|
51 |
+
raise ValueError(
|
52 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
53 |
+
f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
54 |
+
)
|
55 |
+
if len(rope_kwargs) > 0:
|
56 |
+
base = rope_kwargs["base"]
|
57 |
+
dim = rope_kwargs["dim"]
|
58 |
+
elif config is not None:
|
59 |
+
base = config.rope_theta
|
60 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
61 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
62 |
+
dim = int(head_dim * partial_rotary_factor)
|
63 |
+
|
64 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
65 |
+
|
66 |
+
# Compute the inverse frequencies
|
67 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
68 |
+
return inv_freq, attention_factor
|
69 |
+
|
70 |
+
|
71 |
+
def _compute_linear_scaling_rope_parameters(
|
72 |
+
config: Optional[PretrainedConfig] = None,
|
73 |
+
device: Optional["torch.device"] = None,
|
74 |
+
seq_len: Optional[int] = None,
|
75 |
+
**rope_kwargs,
|
76 |
+
) -> Tuple["torch.Tensor", float]:
|
77 |
+
"""
|
78 |
+
Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
|
79 |
+
Args:
|
80 |
+
config ([`~transformers.PretrainedConfig`]):
|
81 |
+
The model configuration.
|
82 |
+
device (`torch.device`):
|
83 |
+
The device to use for initialization of the inverse frequencies.
|
84 |
+
seq_len (`int`, *optional*):
|
85 |
+
The current sequence length. Unused for this type of RoPE.
|
86 |
+
rope_kwargs (`Dict`, *optional*):
|
87 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
88 |
+
Returns:
|
89 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
90 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
91 |
+
"""
|
92 |
+
if config is not None and len(rope_kwargs) > 0:
|
93 |
+
raise ValueError(
|
94 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
95 |
+
f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
96 |
+
)
|
97 |
+
if len(rope_kwargs) > 0:
|
98 |
+
factor = rope_kwargs["factor"]
|
99 |
+
elif config is not None:
|
100 |
+
factor = config.rope_scaling["factor"]
|
101 |
+
|
102 |
+
# Gets the default RoPE parameters
|
103 |
+
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
|
104 |
+
|
105 |
+
# Then applies linear scaling to the frequencies.
|
106 |
+
# NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
|
107 |
+
# applying scaling to the inverse frequencies is equivalent.
|
108 |
+
inv_freq /= factor
|
109 |
+
return inv_freq, attention_factor
|
110 |
+
|
111 |
+
|
112 |
+
def _compute_dynamic_ntk_parameters(
|
113 |
+
config: Optional[PretrainedConfig] = None,
|
114 |
+
device: Optional["torch.device"] = None,
|
115 |
+
seq_len: Optional[int] = None,
|
116 |
+
**rope_kwargs,
|
117 |
+
) -> Tuple["torch.Tensor", float]:
|
118 |
+
"""
|
119 |
+
Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
|
120 |
+
Args:
|
121 |
+
config ([`~transformers.PretrainedConfig`]):
|
122 |
+
The model configuration.
|
123 |
+
device (`torch.device`):
|
124 |
+
The device to use for initialization of the inverse frequencies.
|
125 |
+
seq_len (`int`, *optional*):
|
126 |
+
The current sequence length, used to update the dynamic RoPE at inference time.
|
127 |
+
rope_kwargs (`Dict`, *optional*):
|
128 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
129 |
+
Returns:
|
130 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
131 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
132 |
+
"""
|
133 |
+
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
134 |
+
if config is not None and len(rope_kwargs) > 0:
|
135 |
+
raise ValueError(
|
136 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
137 |
+
f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
138 |
+
)
|
139 |
+
if len(rope_kwargs) > 0:
|
140 |
+
base = rope_kwargs["base"]
|
141 |
+
dim = rope_kwargs["dim"]
|
142 |
+
max_position_embeddings = rope_kwargs["max_position_embeddings"]
|
143 |
+
factor = rope_kwargs["factor"]
|
144 |
+
elif config is not None:
|
145 |
+
base = config.rope_theta
|
146 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
147 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
148 |
+
dim = int(head_dim * partial_rotary_factor)
|
149 |
+
max_position_embeddings = config.max_position_embeddings
|
150 |
+
factor = config.rope_scaling["factor"]
|
151 |
+
|
152 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
153 |
+
|
154 |
+
# seq_len: default to max_position_embeddings, e.g. at init time
|
155 |
+
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
|
156 |
+
|
157 |
+
# Compute the inverse frequencies
|
158 |
+
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
159 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
160 |
+
return inv_freq, attention_factor
|
161 |
+
|
162 |
+
|
163 |
+
def _compute_yarn_parameters(
|
164 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
165 |
+
) -> Tuple["torch.Tensor", float]:
|
166 |
+
"""
|
167 |
+
Computes the inverse frequencies with NTK scaling. Please refer to the
|
168 |
+
[original paper](https://arxiv.org/abs/2309.00071)
|
169 |
+
Args:
|
170 |
+
config ([`~transformers.PretrainedConfig`]):
|
171 |
+
The model configuration.
|
172 |
+
device (`torch.device`):
|
173 |
+
The device to use for initialization of the inverse frequencies.
|
174 |
+
seq_len (`int`, *optional*):
|
175 |
+
The current sequence length. Unused for this type of RoPE.
|
176 |
+
rope_kwargs (`Dict`, *optional*):
|
177 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
178 |
+
Returns:
|
179 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
180 |
+
post-processing scaling factor applied to the computed cos/sin.
|
181 |
+
"""
|
182 |
+
# No need to keep BC with yarn, unreleased when this new pattern was created.
|
183 |
+
if len(rope_kwargs) > 0:
|
184 |
+
raise ValueError(
|
185 |
+
f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
|
186 |
+
)
|
187 |
+
|
188 |
+
base = config.rope_theta
|
189 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
190 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
191 |
+
dim = int(head_dim * partial_rotary_factor)
|
192 |
+
max_position_embeddings = config.max_position_embeddings
|
193 |
+
factor = config.rope_scaling["factor"]
|
194 |
+
|
195 |
+
# Sets the attention factor as suggested in the paper
|
196 |
+
attention_factor = config.rope_scaling.get("attention_factor")
|
197 |
+
if attention_factor is None:
|
198 |
+
attention_factor = 0.1 * math.log(factor) + 1.0
|
199 |
+
|
200 |
+
# Optional config options
|
201 |
+
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
|
202 |
+
beta_fast = config.rope_scaling.get("beta_fast") or 32
|
203 |
+
beta_slow = config.rope_scaling.get("beta_slow") or 1
|
204 |
+
|
205 |
+
# Compute the inverse frequencies
|
206 |
+
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
|
207 |
+
"""Inverse dimension formula to find the dimension based on the number of rotations"""
|
208 |
+
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
209 |
+
|
210 |
+
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
|
211 |
+
"""Find dimension range bounds based on rotations"""
|
212 |
+
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
213 |
+
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
214 |
+
return max(low, 0), min(high, dim - 1)
|
215 |
+
|
216 |
+
def linear_ramp_factor(min, max, dim):
|
217 |
+
if min == max:
|
218 |
+
max += 0.001 # Prevent singularity
|
219 |
+
|
220 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
221 |
+
ramp_func = torch.clamp(linear_func, 0, 1)
|
222 |
+
return ramp_func
|
223 |
+
|
224 |
+
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
|
225 |
+
# to expand the possible context length. In other words, interpolation = apply scaling factor.
|
226 |
+
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
|
227 |
+
inv_freq_extrapolation = 1.0 / pos_freqs
|
228 |
+
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
229 |
+
|
230 |
+
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
|
231 |
+
|
232 |
+
# Get n-dimensional rotational scaling corrected for extrapolation
|
233 |
+
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
|
234 |
+
inv_freq = (
|
235 |
+
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
236 |
+
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
237 |
+
)
|
238 |
+
|
239 |
+
return inv_freq, attention_factor
|
240 |
+
|
241 |
+
|
242 |
+
def _compute_longrope_parameters(
|
243 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
244 |
+
) -> Tuple["torch.Tensor", float]:
|
245 |
+
"""
|
246 |
+
Computes the inverse frequencies with LongRoPE scaling. Please refer to the
|
247 |
+
[original implementation](https://github.com/microsoft/LongRoPE)
|
248 |
+
Args:
|
249 |
+
config ([`~transformers.PretrainedConfig`]):
|
250 |
+
The model configuration.
|
251 |
+
device (`torch.device`):
|
252 |
+
The device to use for initialization of the inverse frequencies.
|
253 |
+
seq_len (`int`, *optional*):
|
254 |
+
The current sequence length. Unused for this type of RoPE.
|
255 |
+
rope_kwargs (`Dict`, *optional*):
|
256 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
257 |
+
Returns:
|
258 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
259 |
+
post-processing scaling factor applied to the computed cos/sin.
|
260 |
+
"""
|
261 |
+
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
262 |
+
# No need to keep BC with longrope, unreleased when this new pattern was created.
|
263 |
+
if len(rope_kwargs) > 0:
|
264 |
+
raise ValueError(
|
265 |
+
"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
|
266 |
+
f"{rope_kwargs}"
|
267 |
+
)
|
268 |
+
|
269 |
+
base = config.rope_theta
|
270 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
271 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
272 |
+
dim = int(head_dim * partial_rotary_factor)
|
273 |
+
long_factor = config.rope_scaling["long_factor"]
|
274 |
+
short_factor = config.rope_scaling["short_factor"]
|
275 |
+
factor = config.rope_scaling.get("factor")
|
276 |
+
attention_factor = config.rope_scaling.get("attention_factor")
|
277 |
+
|
278 |
+
# NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
|
279 |
+
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
280 |
+
# values to compute the default attention scaling factor, instead of using `factor`.
|
281 |
+
if hasattr(config, "original_max_position_embeddings"):
|
282 |
+
max_position_embeddings = config.original_max_position_embeddings
|
283 |
+
expanded_max_position_embeddings = config.max_position_embeddings
|
284 |
+
factor = expanded_max_position_embeddings / max_position_embeddings
|
285 |
+
else:
|
286 |
+
max_position_embeddings = config.max_position_embeddings
|
287 |
+
expanded_max_position_embeddings = max_position_embeddings * factor
|
288 |
+
|
289 |
+
# Sets the attention factor as suggested in the paper
|
290 |
+
if attention_factor is None:
|
291 |
+
if factor <= 1.0:
|
292 |
+
attention_factor = 1.0
|
293 |
+
else:
|
294 |
+
attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
|
295 |
+
|
296 |
+
# Compute the inverse frequencies -- scaled based on the target sequence length
|
297 |
+
if expanded_max_position_embeddings > max_position_embeddings:
|
298 |
+
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
|
299 |
+
else:
|
300 |
+
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
|
301 |
+
inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
|
302 |
+
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
|
303 |
+
|
304 |
+
return inv_freq, attention_factor
|
305 |
+
|
306 |
+
|
307 |
+
def _compute_llama3_parameters(
|
308 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
309 |
+
) -> Tuple["torch.Tensor", float]:
|
310 |
+
"""
|
311 |
+
Computes the inverse frequencies for llama 3.1.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
config ([`~transformers.PretrainedConfig`]):
|
315 |
+
The model configuration.
|
316 |
+
device (`torch.device`):
|
317 |
+
The device to use for initialization of the inverse frequencies.
|
318 |
+
seq_len (`int`, *optional*):
|
319 |
+
The current sequence length. Unused for this type of RoPE.
|
320 |
+
rope_kwargs (`Dict`, *optional*):
|
321 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
322 |
+
Returns:
|
323 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
324 |
+
post-processing scaling factor applied to the computed cos/sin.
|
325 |
+
"""
|
326 |
+
# Gets the default RoPE parameters
|
327 |
+
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
|
328 |
+
|
329 |
+
factor = config.rope_scaling["factor"] # `8` in the original implementation
|
330 |
+
low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
|
331 |
+
high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
|
332 |
+
old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
|
333 |
+
|
334 |
+
low_freq_wavelen = old_context_len / low_freq_factor
|
335 |
+
high_freq_wavelen = old_context_len / high_freq_factor
|
336 |
+
|
337 |
+
wavelen = 2 * math.pi / inv_freq
|
338 |
+
# wavelen < high_freq_wavelen: do nothing
|
339 |
+
# wavelen > low_freq_wavelen: divide by factor
|
340 |
+
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
|
341 |
+
# otherwise: interpolate between the two, using a smooth factor
|
342 |
+
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
343 |
+
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
|
344 |
+
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
|
345 |
+
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
346 |
+
|
347 |
+
return inv_freq_llama, attention_factor
|
348 |
+
|
349 |
+
|
350 |
+
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
351 |
+
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
|
352 |
+
# parameterizations, as long as the callable has the same signature.
|
353 |
+
ROPE_INIT_FUNCTIONS = {
|
354 |
+
"default": _compute_default_rope_parameters,
|
355 |
+
"linear": _compute_linear_scaling_rope_parameters,
|
356 |
+
"dynamic": _compute_dynamic_ntk_parameters,
|
357 |
+
"yarn": _compute_yarn_parameters,
|
358 |
+
"longrope": _compute_longrope_parameters,
|
359 |
+
"llama3": _compute_llama3_parameters,
|
360 |
+
}
|
361 |
+
|
362 |
+
|
363 |
+
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
|
364 |
+
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
365 |
+
# BC: "rope_type" was originally "type" -- let's gracefully handle it
|
366 |
+
if "rope_type" not in received_keys and "type" in received_keys:
|
367 |
+
received_keys -= {"type"}
|
368 |
+
received_keys.add("rope_type")
|
369 |
+
|
370 |
+
missing_keys = required_keys - received_keys
|
371 |
+
if missing_keys:
|
372 |
+
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
|
373 |
+
|
374 |
+
if optional_keys is not None:
|
375 |
+
unused_keys = received_keys - required_keys - optional_keys
|
376 |
+
else:
|
377 |
+
unused_keys = received_keys - required_keys
|
378 |
+
if unused_keys:
|
379 |
+
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
380 |
+
|
381 |
+
|
382 |
+
def _validate_default_rope_parameters(config: PretrainedConfig):
|
383 |
+
rope_scaling = config.rope_scaling
|
384 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
385 |
+
required_keys = {"rope_type"}
|
386 |
+
received_keys = set(rope_scaling.keys())
|
387 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
388 |
+
|
389 |
+
|
390 |
+
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
|
391 |
+
rope_scaling = config.rope_scaling
|
392 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
393 |
+
required_keys = {"rope_type", "factor"}
|
394 |
+
received_keys = set(rope_scaling.keys())
|
395 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
396 |
+
|
397 |
+
factor = rope_scaling["factor"]
|
398 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
399 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
400 |
+
|
401 |
+
|
402 |
+
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
|
403 |
+
rope_scaling = config.rope_scaling
|
404 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
405 |
+
required_keys = {"rope_type", "factor"}
|
406 |
+
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
407 |
+
optional_keys = {"original_max_position_embeddings"}
|
408 |
+
received_keys = set(rope_scaling.keys())
|
409 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
410 |
+
|
411 |
+
factor = rope_scaling["factor"]
|
412 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
413 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
414 |
+
|
415 |
+
|
416 |
+
def _validate_yarn_parameters(config: PretrainedConfig):
|
417 |
+
rope_scaling = config.rope_scaling
|
418 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
419 |
+
required_keys = {"rope_type", "factor"}
|
420 |
+
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
|
421 |
+
received_keys = set(rope_scaling.keys())
|
422 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
423 |
+
|
424 |
+
factor = rope_scaling["factor"]
|
425 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
426 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
427 |
+
|
428 |
+
attention_factor = rope_scaling.get("attention_factor")
|
429 |
+
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
|
430 |
+
logger.warning(
|
431 |
+
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
432 |
+
)
|
433 |
+
beta_fast = rope_scaling.get("beta_fast")
|
434 |
+
if beta_fast is not None and not isinstance(beta_fast, float):
|
435 |
+
logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
|
436 |
+
beta_slow = rope_scaling.get("beta_slow")
|
437 |
+
if beta_slow is not None and not isinstance(beta_slow, float):
|
438 |
+
logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
|
439 |
+
|
440 |
+
if (beta_fast or 32) < (beta_slow or 1):
|
441 |
+
logger.warning(
|
442 |
+
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
|
443 |
+
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
|
444 |
+
)
|
445 |
+
|
446 |
+
|
447 |
+
def _validate_longrope_parameters(config: PretrainedConfig):
|
448 |
+
rope_scaling = config.rope_scaling
|
449 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
450 |
+
required_keys = {"rope_type", "short_factor", "long_factor"}
|
451 |
+
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
452 |
+
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
453 |
+
received_keys = set(rope_scaling.keys())
|
454 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
455 |
+
|
456 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
457 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
458 |
+
dim = int(head_dim * partial_rotary_factor)
|
459 |
+
|
460 |
+
short_factor = rope_scaling.get("short_factor")
|
461 |
+
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
|
462 |
+
logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
|
463 |
+
if not len(short_factor) == dim // 2:
|
464 |
+
logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
|
465 |
+
|
466 |
+
long_factor = rope_scaling.get("long_factor")
|
467 |
+
if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
|
468 |
+
logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
|
469 |
+
if not len(long_factor) == dim // 2:
|
470 |
+
logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
|
471 |
+
|
472 |
+
# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
|
473 |
+
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
|
474 |
+
# unique to longrope (= undesirable)
|
475 |
+
if hasattr(config, "original_max_position_embeddings"):
|
476 |
+
logger.warning_once(
|
477 |
+
"This model has set a `original_max_position_embeddings` field, to be used together with "
|
478 |
+
"`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
|
479 |
+
"with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
|
480 |
+
"as it is compatible with most model architectures."
|
481 |
+
)
|
482 |
+
else:
|
483 |
+
factor = rope_scaling.get("factor")
|
484 |
+
if factor is None:
|
485 |
+
logger.warning("Missing required keys in `rope_scaling`: 'factor'")
|
486 |
+
elif not isinstance(factor, float) or factor < 1.0:
|
487 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
488 |
+
|
489 |
+
attention_factor = rope_scaling.get("attention_factor")
|
490 |
+
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
|
491 |
+
logger.warning(
|
492 |
+
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
493 |
+
)
|
494 |
+
|
495 |
+
|
496 |
+
def _validate_llama3_parameters(config: PretrainedConfig):
|
497 |
+
rope_scaling = config.rope_scaling
|
498 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
499 |
+
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
500 |
+
received_keys = set(rope_scaling.keys())
|
501 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
502 |
+
|
503 |
+
factor = rope_scaling["factor"]
|
504 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
505 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
506 |
+
|
507 |
+
low_freq_factor = rope_scaling["low_freq_factor"]
|
508 |
+
high_freq_factor = rope_scaling["high_freq_factor"]
|
509 |
+
if low_freq_factor is None or not isinstance(low_freq_factor, float):
|
510 |
+
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
|
511 |
+
if high_freq_factor is None or not isinstance(high_freq_factor, float):
|
512 |
+
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
|
513 |
+
if high_freq_factor <= low_freq_factor:
|
514 |
+
logger.warning(
|
515 |
+
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
|
516 |
+
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
|
517 |
+
)
|
518 |
+
|
519 |
+
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
|
520 |
+
if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
|
521 |
+
logger.warning(
|
522 |
+
"`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
|
523 |
+
f"{original_max_position_embeddings}"
|
524 |
+
)
|
525 |
+
if original_max_position_embeddings >= config.max_position_embeddings:
|
526 |
+
logger.warning(
|
527 |
+
"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
|
528 |
+
f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
|
529 |
+
)
|
530 |
+
|
531 |
+
|
532 |
+
# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
|
533 |
+
ROPE_VALIDATION_FUNCTIONS = {
|
534 |
+
"default": _validate_default_rope_parameters,
|
535 |
+
"linear": _validate_linear_scaling_rope_parameters,
|
536 |
+
"dynamic": _validate_dynamic_scaling_rope_parameters,
|
537 |
+
"yarn": _validate_yarn_parameters,
|
538 |
+
"longrope": _validate_longrope_parameters,
|
539 |
+
"llama3": _validate_llama3_parameters,
|
540 |
+
}
|
541 |
+
|
542 |
+
|
543 |
+
def rope_config_validation(config: PretrainedConfig):
|
544 |
+
"""
|
545 |
+
Validate the RoPE config arguments, given a `PretrainedConfig` object
|
546 |
+
"""
|
547 |
+
rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
|
548 |
+
if rope_scaling is None:
|
549 |
+
return
|
550 |
+
|
551 |
+
# BC: "rope_type" was originally "type"
|
552 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
553 |
+
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
554 |
+
if validation_fn is not None:
|
555 |
+
validation_fn(config)
|
556 |
+
else:
|
557 |
+
logger.warning(
|
558 |
+
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
559 |
+
)
|
transformers_4_44_2__pytorch_utils.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from torch import nn
|
16 |
+
|
17 |
+
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
|
variable_cache.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 Nvidia Corporation. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from copy import deepcopy
|
17 |
+
from typing import Optional, Dict, Any, Tuple
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from transformers.cache_utils import Cache # used to let GenerationMixin know that we use a Cache object
|
21 |
+
|
22 |
+
from .configuration_decilm import DeciLMConfig, AttentionConfig
|
23 |
+
from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2, StaticCache
|
24 |
+
|
25 |
+
|
26 |
+
class VariableCache(Cache_4_44_2, Cache):
|
27 |
+
"""
|
28 |
+
A Cache object that supports a different Cache implementation for every layer,
|
29 |
+
including layers without any kv-cache.
|
30 |
+
Implemented using a list of Cache objects, each represents a "model" with 1 layer.
|
31 |
+
The default implementation for the layer caches is StaticCache.
|
32 |
+
The cache of each layer is allocated to the same gpu as the layer itself.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self,
|
36 |
+
config: DeciLMConfig,
|
37 |
+
max_batch_size: int,
|
38 |
+
max_cache_len: int | None,
|
39 |
+
device: torch.device | str | None = None,
|
40 |
+
dtype: torch.dtype | None = None,
|
41 |
+
):
|
42 |
+
Cache_4_44_2.__init__(self)
|
43 |
+
|
44 |
+
self.config = config
|
45 |
+
self.max_batch_size = max_batch_size
|
46 |
+
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
47 |
+
self.dtype = dtype
|
48 |
+
|
49 |
+
self.layer_caches: list[Cache | None] = [None] * config.num_hidden_layers
|
50 |
+
|
51 |
+
def update(
|
52 |
+
self,
|
53 |
+
key_states: torch.Tensor,
|
54 |
+
value_states: torch.Tensor,
|
55 |
+
layer_idx: int,
|
56 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
57 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
58 |
+
layer_cache = self.layer_caches[layer_idx]
|
59 |
+
|
60 |
+
if layer_cache is None:
|
61 |
+
block_config = self.config.block_configs[layer_idx]
|
62 |
+
layer_cache = self._init_layer_cache(attention_config=block_config.attention, device=key_states.device)
|
63 |
+
assert layer_cache is not None, "Trying to update the cache of a cache-less layer"
|
64 |
+
self.layer_caches[layer_idx] = layer_cache
|
65 |
+
|
66 |
+
k_out, v_out = layer_cache.update(key_states=key_states,
|
67 |
+
value_states=value_states,
|
68 |
+
layer_idx=0,
|
69 |
+
cache_kwargs=cache_kwargs)
|
70 |
+
seq_len = self.get_seq_length(layer_idx)
|
71 |
+
k_out = k_out[:, :, :seq_len, :]
|
72 |
+
v_out = v_out[:, :, :seq_len, :]
|
73 |
+
return k_out, v_out
|
74 |
+
|
75 |
+
def _init_layer_cache(self,
|
76 |
+
attention_config: AttentionConfig,
|
77 |
+
device: torch.device,
|
78 |
+
) -> Cache | None:
|
79 |
+
if attention_config.no_op or attention_config.replace_with_linear:
|
80 |
+
return None
|
81 |
+
config = deepcopy(self.config)
|
82 |
+
config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group
|
83 |
+
return StaticCache(config, self.max_batch_size, self.max_cache_len, device, self.dtype)
|
84 |
+
|
85 |
+
def _get_first_real_cache(self) -> Cache:
|
86 |
+
for layer_cache in self.layer_caches:
|
87 |
+
if layer_cache is not None:
|
88 |
+
return layer_cache
|
89 |
+
raise ValueError(f"No real cache found, all layer caches are None.")
|
90 |
+
|
91 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
92 |
+
if layer_idx == 0 and self.layer_caches[0] is None:
|
93 |
+
try:
|
94 |
+
layer_cache = self._get_first_real_cache()
|
95 |
+
except ValueError:
|
96 |
+
return 0
|
97 |
+
else:
|
98 |
+
layer_cache = self.layer_caches[layer_idx]
|
99 |
+
return layer_cache.get_seq_length()
|
100 |
+
|
101 |
+
def get_max_length(self) -> Optional[int]:
|
102 |
+
"""Returns the maximum sequence length of the cached states."""
|
103 |
+
return self.max_cache_len
|
104 |
+
|
105 |
+
def reset(self):
|
106 |
+
for layer_cache in self.layer_caches:
|
107 |
+
if hasattr(layer_cache, "reset"):
|
108 |
+
layer_cache.reset()
|