Adding tflite
#12
by
0xrk
- opened
- CODE_OF_CONDUCT.md +0 -9
- LICENSE +0 -22
- NOTICE.md +0 -38
- README.md +36 -64
- Research License.docx +0 -0
- SECURITY.md +0 -41
- config.json +25 -20
- configuration_mixformer_sequential.py +59 -0
- generation_config.json +1 -1
- modeling_mixformer_sequential.py +778 -0
- model.safetensors → pytorch_model.bin +2 -2
CODE_OF_CONDUCT.md
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
# Microsoft Open Source Code of Conduct
|
2 |
-
|
3 |
-
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
4 |
-
|
5 |
-
Resources:
|
6 |
-
|
7 |
-
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
8 |
-
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
9 |
-
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
Microsoft.
|
2 |
-
Copyright (c) Microsoft Corporation.
|
3 |
-
|
4 |
-
MIT License
|
5 |
-
|
6 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
-
of this software and associated documentation files (the "Software"), to deal
|
8 |
-
in the Software without restriction, including without limitation the rights
|
9 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
-
copies of the Software, and to permit persons to whom the Software is
|
11 |
-
furnished to do so, subject to the following conditions:
|
12 |
-
|
13 |
-
The above copyright notice and this permission notice shall be included in all
|
14 |
-
copies or substantial portions of the Software.
|
15 |
-
|
16 |
-
THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
-
SOFTWARE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
NOTICE.md
DELETED
@@ -1,38 +0,0 @@
|
|
1 |
-
NOTICES AND INFORMATION
|
2 |
-
Do Not Translate or Localize
|
3 |
-
|
4 |
-
This software incorporates material from third parties.
|
5 |
-
|
6 |
-
**Component.** https://github.com/Dao-AILab/flash-attention
|
7 |
-
|
8 |
-
**Open Source License/Copyright Notice.**
|
9 |
-
|
10 |
-
BSD 3-Clause License
|
11 |
-
|
12 |
-
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
|
13 |
-
All rights reserved.
|
14 |
-
|
15 |
-
Redistribution and use in source and binary forms, with or without
|
16 |
-
modification, are permitted provided that the following conditions are met:
|
17 |
-
|
18 |
-
* Redistributions of source code must retain the above copyright notice, this
|
19 |
-
list of conditions and the following disclaimer.
|
20 |
-
|
21 |
-
* Redistributions in binary form must reproduce the above copyright notice,
|
22 |
-
this list of conditions and the following disclaimer in the documentation
|
23 |
-
and/or other materials provided with the distribution.
|
24 |
-
|
25 |
-
* Neither the name of the copyright holder nor the names of its
|
26 |
-
contributors may be used to endorse or promote products derived from
|
27 |
-
this software without specific prior written permission.
|
28 |
-
|
29 |
-
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
30 |
-
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
31 |
-
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
32 |
-
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
33 |
-
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
34 |
-
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
35 |
-
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
36 |
-
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
37 |
-
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
38 |
-
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -1,42 +1,30 @@
|
|
1 |
---
|
2 |
-
license:
|
3 |
-
license_link: https://huggingface.co/microsoft/phi-1_5/resolve/main/LICENSE
|
4 |
language:
|
5 |
- en
|
6 |
pipeline_tag: text-generation
|
7 |
-
tags:
|
8 |
-
- nlp
|
9 |
-
- code
|
10 |
---
|
11 |
## Model Summary
|
12 |
|
13 |
-
The language model
|
14 |
|
15 |
-
We **did not** fine-tune
|
16 |
|
17 |
For a safer model release, we exclude generic web-crawl data sources such as common-crawl from the training. This strategy prevents direct exposure to potentially harmful online content, enhancing the model's safety without RLHF. However, the model is still vulnerable to generating harmful content. We hope the model can help the research community to further study the safety of language models.
|
18 |
|
19 |
-
Phi-1.5 can write poems, draft emails, create stories, summarize texts, write Python code (such as downloading a Hugging Face transformer model), etc.
|
20 |
-
|
21 |
-
## How to Use
|
22 |
-
|
23 |
-
Phi-1.5 has been integrated in the `transformers` version 4.37.0, please ensure that you are using a version equal or higher than it.
|
24 |
-
|
25 |
## Intended Uses
|
|
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
### QA Format:
|
30 |
|
31 |
```markdown
|
32 |
Write a detailed analogy between mathematics and a lighthouse.
|
33 |
|
34 |
Answer: Mathematics is like a lighthouse, guiding us through the vast ocean of numbers and calculations. Just as a lighthouse illuminates the darkness, mathematics provides us with a clear path to navigate through complex problems. It helps us make sense of the world around us, just like a lighthouse helps ships find their way home.
|
35 |
```
|
36 |
-
|
37 |
where the model generates the text after "Answer:".
|
38 |
|
39 |
-
|
40 |
|
41 |
```markdown
|
42 |
Alice: I don't know why, I'm struggling to maintain focus while studying. Any suggestions?
|
@@ -57,11 +45,9 @@ Charlie: No problem, Alice. We're all in this together.
|
|
57 |
|
58 |
Bob: Yeah, and remember that it's okay to ask for help if you need it. We're here to support each other.
|
59 |
```
|
60 |
-
|
61 |
where the model generates the text after the first "Bob:".
|
62 |
|
63 |
-
|
64 |
-
|
65 |
```python
|
66 |
def print_prime(n):
|
67 |
"""
|
@@ -78,54 +64,24 @@ def print_prime(n):
|
|
78 |
primes.append(num)
|
79 |
print(primes)
|
80 |
```
|
81 |
-
|
82 |
where the model generates the text after the comments.
|
83 |
|
84 |
-
**Notes
|
85 |
-
|
86 |
-
*
|
87 |
-
|
88 |
-
* Phi-1.5 has not been tested to ensure that it performs adequately for any production-level application. Please refer to the limitation sections of this document for more details.
|
89 |
-
|
90 |
-
## Sample Code
|
91 |
-
|
92 |
-
```python
|
93 |
-
import torch
|
94 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
95 |
-
|
96 |
-
torch.set_default_device("cuda")
|
97 |
-
|
98 |
-
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", torch_dtype="auto")
|
99 |
-
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
|
100 |
-
|
101 |
-
inputs = tokenizer('''def print_prime(n):
|
102 |
-
"""
|
103 |
-
Print all primes between 1 and n
|
104 |
-
"""''', return_tensors="pt", return_attention_mask=False)
|
105 |
-
|
106 |
-
outputs = model.generate(**inputs, max_length=200)
|
107 |
-
text = tokenizer.batch_decode(outputs)[0]
|
108 |
-
print(text)
|
109 |
-
```
|
110 |
|
111 |
-
## Limitations of
|
112 |
|
113 |
* Generate Inaccurate Code and Facts: The model often produces incorrect code snippets and statements. Users should treat these outputs as suggestions or starting points, not as definitive or accurate solutions.
|
114 |
-
|
115 |
* Limited Scope for code: If the model generates Python scripts that utilize uncommon packages or scripts in other languages, we strongly recommend users manually verify all API uses.
|
116 |
-
|
117 |
* Unreliable Responses to Instruction: The model has not undergone instruction fine-tuning. As a result, it may struggle or fail to adhere to intricate or nuanced instructions provided by users.
|
118 |
-
|
119 |
* Language Limitations: The model is primarily designed to understand standard English. Informal English, slang, or any other language outside of English might pose challenges to its comprehension, leading to potential misinterpretations or errors in response.
|
120 |
-
|
121 |
* Potential Societal Biases: Regardless of the safe data used for its training, the model is not entirely free from societal biases. There's a possibility it may generate content that mirrors these societal biases, particularly if prompted or instructed to do so. We urge users to be aware of this and to exercise caution and critical thinking when interpreting model outputs.
|
122 |
-
|
123 |
* Toxicity: Despite that the model is trained with carefully selected data, the model can still produce harmful content if explicitly prompted or instructed to do so. We chose to release the model for research purposes only -- We hope to help the open-source community develop the most effective ways to reduce the toxicity of a model directly after pretraining.
|
124 |
|
125 |
## Training
|
126 |
|
127 |
### Model
|
128 |
-
|
129 |
* Architecture: a Transformer-based model with next-word prediction objective
|
130 |
* Dataset size: 30B tokens
|
131 |
* Training tokens: 150B tokens
|
@@ -134,18 +90,38 @@ print(text)
|
|
134 |
* Training time: 8 days
|
135 |
|
136 |
### Software
|
137 |
-
|
138 |
* [PyTorch](https://github.com/pytorch/pytorch)
|
139 |
* [DeepSpeed](https://github.com/microsoft/DeepSpeed)
|
140 |
-
* [
|
141 |
|
142 |
### License
|
|
|
143 |
|
144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
### Citation
|
147 |
|
148 |
-
You can find the paper at https://arxiv.org/abs/2309.05463
|
149 |
|
150 |
```bib
|
151 |
@article{textbooks2,
|
@@ -154,8 +130,4 @@ You can find the paper at https://arxiv.org/abs/2309.05463. Please cite as:
|
|
154 |
journal={arXiv preprint arXiv:2309.05463},
|
155 |
year={2023}
|
156 |
}
|
157 |
-
```
|
158 |
-
|
159 |
-
## Trademarks
|
160 |
-
|
161 |
-
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft’s Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies.
|
|
|
1 |
---
|
2 |
+
license: other
|
|
|
3 |
language:
|
4 |
- en
|
5 |
pipeline_tag: text-generation
|
|
|
|
|
|
|
6 |
---
|
7 |
## Model Summary
|
8 |
|
9 |
+
The language model phi-1.5 is a Transformer with 1.3 billion parameters. It was trained using the same data sources as [phi-1](https://huggingface.co/microsoft/phi-1), augmented with a new data source that consists of various NLP synthetic texts. When assessed against benchmarks testing common sense, language understanding, and logical reasoning, phi-1.5 demonstrates a nearly state-of-the-art performance among models with less than 10 billion parameters.
|
10 |
|
11 |
+
We **did not** fine-tune phi-1.5 either for **instruction following or through reinforcement learning from human feedback**. The intention behind crafting this open-source model is to provide the research community with a non-restricted small model to explore vital safety challenges, such as reducing toxicity, understanding societal biases, enhancing controllability, and more.
|
12 |
|
13 |
For a safer model release, we exclude generic web-crawl data sources such as common-crawl from the training. This strategy prevents direct exposure to potentially harmful online content, enhancing the model's safety without RLHF. However, the model is still vulnerable to generating harmful content. We hope the model can help the research community to further study the safety of language models.
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
## Intended Uses
|
16 |
+
Given the nature of the training data, phi-1.5 is best suited for prompts using the QA format, the chat format, and the code format. Note that phi-1.5, being a base model, often produces irrelevant text following the main answer. In the following example, we've truncated the answer for illustrative purposes only.
|
17 |
|
18 |
+
#### QA format:
|
|
|
|
|
19 |
|
20 |
```markdown
|
21 |
Write a detailed analogy between mathematics and a lighthouse.
|
22 |
|
23 |
Answer: Mathematics is like a lighthouse, guiding us through the vast ocean of numbers and calculations. Just as a lighthouse illuminates the darkness, mathematics provides us with a clear path to navigate through complex problems. It helps us make sense of the world around us, just like a lighthouse helps ships find their way home.
|
24 |
```
|
|
|
25 |
where the model generates the text after "Answer:".
|
26 |
|
27 |
+
#### Chat format:
|
28 |
|
29 |
```markdown
|
30 |
Alice: I don't know why, I'm struggling to maintain focus while studying. Any suggestions?
|
|
|
45 |
|
46 |
Bob: Yeah, and remember that it's okay to ask for help if you need it. We're here to support each other.
|
47 |
```
|
|
|
48 |
where the model generates the text after the first "Bob:".
|
49 |
|
50 |
+
#### Code format:
|
|
|
51 |
```python
|
52 |
def print_prime(n):
|
53 |
"""
|
|
|
64 |
primes.append(num)
|
65 |
print(primes)
|
66 |
```
|
|
|
67 |
where the model generates the text after the comments.
|
68 |
|
69 |
+
**Notes**
|
70 |
+
* phi-1.5 is intended for research purposes. The model-generated text/code should be treated as a starting point rather than a definitive solution for potential use cases. Users should be cautious when employing these models in their applications.
|
71 |
+
* Direct adoption for production tasks is out of the scope of this research project. As a result, phi-1.5 has not been tested to ensure that it performs adequately for any production-level application. Please refer to the limitation sections of this document for more details.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
+
## Limitations of phi-1.5
|
74 |
|
75 |
* Generate Inaccurate Code and Facts: The model often produces incorrect code snippets and statements. Users should treat these outputs as suggestions or starting points, not as definitive or accurate solutions.
|
|
|
76 |
* Limited Scope for code: If the model generates Python scripts that utilize uncommon packages or scripts in other languages, we strongly recommend users manually verify all API uses.
|
|
|
77 |
* Unreliable Responses to Instruction: The model has not undergone instruction fine-tuning. As a result, it may struggle or fail to adhere to intricate or nuanced instructions provided by users.
|
|
|
78 |
* Language Limitations: The model is primarily designed to understand standard English. Informal English, slang, or any other language outside of English might pose challenges to its comprehension, leading to potential misinterpretations or errors in response.
|
|
|
79 |
* Potential Societal Biases: Regardless of the safe data used for its training, the model is not entirely free from societal biases. There's a possibility it may generate content that mirrors these societal biases, particularly if prompted or instructed to do so. We urge users to be aware of this and to exercise caution and critical thinking when interpreting model outputs.
|
|
|
80 |
* Toxicity: Despite that the model is trained with carefully selected data, the model can still produce harmful content if explicitly prompted or instructed to do so. We chose to release the model for research purposes only -- We hope to help the open-source community develop the most effective ways to reduce the toxicity of a model directly after pretraining.
|
81 |
|
82 |
## Training
|
83 |
|
84 |
### Model
|
|
|
85 |
* Architecture: a Transformer-based model with next-word prediction objective
|
86 |
* Dataset size: 30B tokens
|
87 |
* Training tokens: 150B tokens
|
|
|
90 |
* Training time: 8 days
|
91 |
|
92 |
### Software
|
|
|
93 |
* [PyTorch](https://github.com/pytorch/pytorch)
|
94 |
* [DeepSpeed](https://github.com/microsoft/DeepSpeed)
|
95 |
+
* [flash-attention](https://github.com/HazyResearch/flash-attention)
|
96 |
|
97 |
### License
|
98 |
+
The model is licensed under the [Research License](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx).
|
99 |
|
100 |
+
### Sample Code
|
101 |
+
```python
|
102 |
+
import torch
|
103 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
104 |
+
|
105 |
+
torch.set_default_device('cuda')
|
106 |
+
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True, torch_dtype="auto")
|
107 |
+
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True, torch_dtype="auto")
|
108 |
+
inputs = tokenizer('''```python
|
109 |
+
def print_prime(n):
|
110 |
+
"""
|
111 |
+
Print all primes between 1 and n
|
112 |
+
"""''', return_tensors="pt", return_attention_mask=False)
|
113 |
+
|
114 |
+
outputs = model.generate(**inputs, max_length=200)
|
115 |
+
text = tokenizer.batch_decode(outputs)[0]
|
116 |
+
print(text)
|
117 |
+
```
|
118 |
+
|
119 |
+
**Remark.** In the generation function, our model currently does not support beam search (`num_beams` >1) and `attention_mask' parameters.
|
120 |
+
Furthermore, in the forward pass of the model, we currently do not support outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
|
121 |
|
122 |
### Citation
|
123 |
|
124 |
+
You can find the paper at https://arxiv.org/abs/2309.05463
|
125 |
|
126 |
```bib
|
127 |
@article{textbooks2,
|
|
|
130 |
journal={arXiv preprint arXiv:2309.05463},
|
131 |
year={2023}
|
132 |
}
|
133 |
+
```
|
|
|
|
|
|
|
|
Research License.docx
ADDED
Binary file (38.9 kB). View file
|
|
SECURITY.md
DELETED
@@ -1,41 +0,0 @@
|
|
1 |
-
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
|
2 |
-
|
3 |
-
## Security
|
4 |
-
|
5 |
-
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
|
6 |
-
|
7 |
-
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
|
8 |
-
|
9 |
-
## Reporting Security Issues
|
10 |
-
|
11 |
-
**Please do not report security vulnerabilities through public GitHub issues.**
|
12 |
-
|
13 |
-
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
|
14 |
-
|
15 |
-
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
|
16 |
-
|
17 |
-
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
18 |
-
|
19 |
-
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
20 |
-
|
21 |
-
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
22 |
-
* Full paths of source file(s) related to the manifestation of the issue
|
23 |
-
* The location of the affected source code (tag/branch/commit or direct URL)
|
24 |
-
* Any special configuration required to reproduce the issue
|
25 |
-
* Step-by-step instructions to reproduce the issue
|
26 |
-
* Proof-of-concept or exploit code (if possible)
|
27 |
-
* Impact of the issue, including how an attacker might exploit the issue
|
28 |
-
|
29 |
-
This information will help us triage your report more quickly.
|
30 |
-
|
31 |
-
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
|
32 |
-
|
33 |
-
## Preferred Languages
|
34 |
-
|
35 |
-
We prefer all communications to be in English.
|
36 |
-
|
37 |
-
## Policy
|
38 |
-
|
39 |
-
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
|
40 |
-
|
41 |
-
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.json
CHANGED
@@ -1,30 +1,35 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
"architectures": [
|
4 |
-
"
|
5 |
],
|
6 |
-
"
|
7 |
-
|
|
|
|
|
|
|
8 |
"embd_pdrop": 0.0,
|
9 |
-
"eos_token_id": null,
|
10 |
-
"hidden_act": "gelu_new",
|
11 |
-
"hidden_size": 2048,
|
12 |
"initializer_range": 0.02,
|
13 |
-
"
|
14 |
-
"
|
15 |
-
"
|
16 |
-
"
|
17 |
-
"
|
18 |
-
"
|
19 |
-
"
|
20 |
-
"
|
21 |
-
"qk_layernorm": false,
|
22 |
"resid_pdrop": 0.0,
|
23 |
-
"
|
24 |
-
"rope_theta": 10000.0,
|
25 |
"tie_word_embeddings": false,
|
26 |
"torch_dtype": "float16",
|
27 |
-
"transformers_version": "4.
|
28 |
-
"use_cache": true,
|
29 |
"vocab_size": 51200
|
30 |
}
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "phi-1.5-half",
|
3 |
+
"activation_function": "gelu_new",
|
4 |
+
"architecture": {
|
5 |
+
"block_cls": "parallel",
|
6 |
+
"mixer": {},
|
7 |
+
"mlp": {
|
8 |
+
"mlp_cls": "mlp"
|
9 |
+
}
|
10 |
+
},
|
11 |
"architectures": [
|
12 |
+
"MixFormerSequentialForCausalLM"
|
13 |
],
|
14 |
+
"auto_map": {
|
15 |
+
"AutoConfig": "configuration_mixformer_sequential.MixFormerSequentialConfig",
|
16 |
+
"AutoModelForCausalLM": "modeling_mixformer_sequential.MixFormerSequentialForCausalLM"
|
17 |
+
},
|
18 |
+
"embd_layer": "default",
|
19 |
"embd_pdrop": 0.0,
|
|
|
|
|
|
|
20 |
"initializer_range": 0.02,
|
21 |
+
"layer_norm_epsilon": 1e-05,
|
22 |
+
"model_type": "mixformer-sequential",
|
23 |
+
"n_embd": 2048,
|
24 |
+
"n_head": 32,
|
25 |
+
"n_inner": null,
|
26 |
+
"n_layer": 24,
|
27 |
+
"n_positions": 2048,
|
28 |
+
"phyagi_version": "0.0.4.dev",
|
|
|
29 |
"resid_pdrop": 0.0,
|
30 |
+
"rotary_dim": 32,
|
|
|
31 |
"tie_word_embeddings": false,
|
32 |
"torch_dtype": "float16",
|
33 |
+
"transformers_version": "4.32.1",
|
|
|
34 |
"vocab_size": 51200
|
35 |
}
|
configuration_mixformer_sequential.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
import math
|
5 |
+
from typing import Any, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
from transformers import PretrainedConfig
|
8 |
+
|
9 |
+
|
10 |
+
class MixFormerSequentialConfig(PretrainedConfig):
|
11 |
+
"""MixFormer (sequential for DeepSpeed) configuration."""
|
12 |
+
|
13 |
+
model_type = "mixformer-sequential"
|
14 |
+
|
15 |
+
attribute_map = {
|
16 |
+
"max_position_embeddings": "n_positions",
|
17 |
+
"hidden_size": "n_embd",
|
18 |
+
"num_attention_heads": "n_head",
|
19 |
+
"num_hidden_layers": "n_layer",
|
20 |
+
"input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility
|
21 |
+
"blocks": "architecture", # `blocks` key is for backward compatibility
|
22 |
+
}
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
vocab_size: Optional[int] = 50304,
|
27 |
+
n_positions: Optional[int] = 2048,
|
28 |
+
n_embd: Optional[int] = 1024,
|
29 |
+
n_layer: Optional[int] = 20,
|
30 |
+
n_inner: Optional[int] = None,
|
31 |
+
n_head: Optional[int] = 16,
|
32 |
+
rotary_dim: Optional[int] = 32,
|
33 |
+
activation_function: Optional[str] = "gelu_new",
|
34 |
+
embd_layer: Optional[str] = "default",
|
35 |
+
architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
|
36 |
+
embd_pdrop: Optional[float] = 0.0,
|
37 |
+
resid_pdrop: Optional[float] = 0.0,
|
38 |
+
layer_norm_epsilon: Optional[float] = 1e-5,
|
39 |
+
initializer_range: Optional[float] = 0.02,
|
40 |
+
tie_word_embeddings: Optional[bool] = False,
|
41 |
+
pad_vocab_size_multiple: Optional[int] = 64,
|
42 |
+
**kwargs
|
43 |
+
) -> None:
|
44 |
+
self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
45 |
+
self.n_positions = n_positions
|
46 |
+
self.n_embd = n_embd
|
47 |
+
self.n_layer = n_layer
|
48 |
+
self.n_inner = n_inner
|
49 |
+
self.n_head = n_head
|
50 |
+
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
51 |
+
self.activation_function = activation_function
|
52 |
+
self.embd_layer = embd_layer
|
53 |
+
self.architecture = architecture
|
54 |
+
self.embd_pdrop = embd_pdrop
|
55 |
+
self.resid_pdrop = resid_pdrop
|
56 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
57 |
+
self.initializer_range = initializer_range
|
58 |
+
|
59 |
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
generation_config.json
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
{
|
2 |
"_from_model_config": true,
|
3 |
-
"transformers_version": "4.
|
4 |
}
|
|
|
1 |
{
|
2 |
"_from_model_config": true,
|
3 |
+
"transformers_version": "4.32.1"
|
4 |
}
|
modeling_mixformer_sequential.py
ADDED
@@ -0,0 +1,778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
# BSD 3-Clause License
|
5 |
+
#
|
6 |
+
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
|
7 |
+
# All rights reserved.
|
8 |
+
#
|
9 |
+
# Redistribution and use in source and binary forms, with or without
|
10 |
+
# modification, are permitted provided that the following conditions are met:
|
11 |
+
#
|
12 |
+
# * Redistributions of source code must retain the above copyright notice, this
|
13 |
+
# list of conditions and the following disclaimer.
|
14 |
+
#
|
15 |
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
16 |
+
# this list of conditions and the following disclaimer in the documentation
|
17 |
+
# and/or other materials provided with the distribution.
|
18 |
+
#
|
19 |
+
# * Neither the name of the copyright holder nor the names of its
|
20 |
+
# contributors may be used to endorse or promote products derived from
|
21 |
+
# this software without specific prior written permission.
|
22 |
+
#
|
23 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
24 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
25 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
26 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
27 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
28 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
29 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
30 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
31 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
32 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
33 |
+
|
34 |
+
from __future__ import annotations
|
35 |
+
|
36 |
+
import math
|
37 |
+
import copy
|
38 |
+
from typing import Any, Dict, Optional, Tuple
|
39 |
+
from dataclasses import dataclass, field
|
40 |
+
|
41 |
+
import torch
|
42 |
+
import torch.nn as nn
|
43 |
+
|
44 |
+
from einops import rearrange
|
45 |
+
from transformers.activations import ACT2FN
|
46 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
47 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
48 |
+
|
49 |
+
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class InferenceParams:
|
53 |
+
"""Inference parameters that are passed to the main model in order
|
54 |
+
to efficienly calculate and store the context during inference.
|
55 |
+
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
56 |
+
max_sequence_len: int
|
57 |
+
max_batch_size: int
|
58 |
+
sequence_len_offset: int = 0
|
59 |
+
batch_size_offset: int = 0
|
60 |
+
key_value_memory_dict: dict = field(default_factory=dict)
|
61 |
+
fused_ft_kernel: bool = False
|
62 |
+
lengths_per_sample: Optional[torch.Tensor] = None
|
63 |
+
|
64 |
+
|
65 |
+
class Embedding(nn.Module):
|
66 |
+
"""Token embedding with dropout."""
|
67 |
+
|
68 |
+
def __init__(self, config: PretrainedConfig) -> None:
|
69 |
+
super().__init__()
|
70 |
+
|
71 |
+
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
72 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
73 |
+
|
74 |
+
def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
|
75 |
+
input_shape = input_ids.size()
|
76 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
77 |
+
|
78 |
+
hidden_states = self.wte(input_ids)
|
79 |
+
hidden_states = self.drop(hidden_states)
|
80 |
+
|
81 |
+
return hidden_states
|
82 |
+
|
83 |
+
class RotaryEmbedding(nn.Module):
|
84 |
+
"""PyTorch implementation of `flash-attn` RotaryEmbedding layer.
|
85 |
+
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
dim: int,
|
90 |
+
base: Optional[int] = 10000,
|
91 |
+
scale_base: Optional[float] = None,
|
92 |
+
device: Optional[str] = None,
|
93 |
+
**kwargs,
|
94 |
+
) -> None:
|
95 |
+
super().__init__()
|
96 |
+
|
97 |
+
if scale_base is not None:
|
98 |
+
raise NotImplementedError
|
99 |
+
|
100 |
+
# Generate and save the inverse frequency buffer (non-trainable)
|
101 |
+
self.dim = dim
|
102 |
+
self.base = base
|
103 |
+
self.scale_base = scale_base
|
104 |
+
self.device = device
|
105 |
+
|
106 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
107 |
+
self.register_buffer("inv_freq", inv_freq)
|
108 |
+
|
109 |
+
scale = (
|
110 |
+
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
111 |
+
if scale_base is not None
|
112 |
+
else None
|
113 |
+
)
|
114 |
+
self.register_buffer("scale", scale)
|
115 |
+
|
116 |
+
self._seq_len_cached = 0
|
117 |
+
self._cos_cached = None
|
118 |
+
self._sin_cached = None
|
119 |
+
self._cos_k_cached = None
|
120 |
+
self._sin_k_cached = None
|
121 |
+
|
122 |
+
def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0) -> None:
|
123 |
+
# Reset the tables if the sequence length has changed,
|
124 |
+
# or if we're on a new device (possibly due to tracing for instance)
|
125 |
+
seqlen = x.shape[1] + seqlen_offset
|
126 |
+
|
127 |
+
# Re-generate the inverse frequency buffer if it's not fp32
|
128 |
+
# (for instance if model.half() was called)
|
129 |
+
if self.inv_freq.dtype != "torch.float32":
|
130 |
+
self.inv_freq = 1.0 / (
|
131 |
+
self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
|
132 |
+
)
|
133 |
+
|
134 |
+
if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
|
135 |
+
self._seq_len_cached = seqlen
|
136 |
+
t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
|
137 |
+
|
138 |
+
# Don't do einsum, it converts fp32 to fp16
|
139 |
+
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
140 |
+
freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
|
141 |
+
if self.scale is None:
|
142 |
+
self._cos_cached = torch.cos(freqs).to(x.dtype)
|
143 |
+
self._sin_cached = torch.sin(freqs).to(x.dtype)
|
144 |
+
else:
|
145 |
+
power = (
|
146 |
+
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
147 |
+
) / self.scale_base
|
148 |
+
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
149 |
+
|
150 |
+
# We want the multiplication by scale to happen in fp32
|
151 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
|
152 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
|
153 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
|
154 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
|
155 |
+
|
156 |
+
def apply_rotary_emb_qkv(
|
157 |
+
self,
|
158 |
+
qkv: torch.FloatTensor,
|
159 |
+
sin: torch.FloatTensor,
|
160 |
+
cos: torch.FloatTensor,
|
161 |
+
sin_k: Optional[torch.FloatTensor] = None,
|
162 |
+
cos_k: Optional[torch.FloatTensor] = None,
|
163 |
+
) -> torch.FloatTensor:
|
164 |
+
_, seqlen, three, _, headdim = qkv.shape
|
165 |
+
assert three == 3
|
166 |
+
|
167 |
+
rotary_seqlen, rotary_dim = cos.shape
|
168 |
+
rotary_dim *= 2
|
169 |
+
assert rotary_dim <= headdim
|
170 |
+
assert seqlen <= rotary_seqlen
|
171 |
+
|
172 |
+
cos_k = cos if cos_k is None else cos_k
|
173 |
+
sin_k = sin if sin_k is None else sin_k
|
174 |
+
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
|
175 |
+
|
176 |
+
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
177 |
+
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
178 |
+
|
179 |
+
k_rot = qkv[:, :, 1, :, :rotary_dim]
|
180 |
+
k_pass = qkv[:, :, 1, :, rotary_dim:]
|
181 |
+
|
182 |
+
# Splits the queries and keys in half
|
183 |
+
q1, q2 = q_rot.chunk(2, dim=-1)
|
184 |
+
k1, k2 = k_rot.chunk(2, dim=-1)
|
185 |
+
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
186 |
+
|
187 |
+
# Casts to fp32 are necessary to prevent fp16 overflow issues
|
188 |
+
q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
|
189 |
+
|
190 |
+
# Computes the new keys and queries, recasting to original dtype
|
191 |
+
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
|
192 |
+
|
193 |
+
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
|
194 |
+
|
195 |
+
return torch.cat(
|
196 |
+
[
|
197 |
+
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
|
198 |
+
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
199 |
+
qkv[:, :, 2:3, :, :],
|
200 |
+
],
|
201 |
+
axis=2,
|
202 |
+
)
|
203 |
+
|
204 |
+
def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
|
205 |
+
"""Perform the forward pass.
|
206 |
+
|
207 |
+
Args:
|
208 |
+
qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
|
209 |
+
seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
|
210 |
+
|
211 |
+
Returns:
|
212 |
+
New `qkv` and the cached sinusoids.
|
213 |
+
|
214 |
+
"""
|
215 |
+
|
216 |
+
self._update_cos_sin_cache(qkv, seqlen_offset)
|
217 |
+
|
218 |
+
return self.apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
|
219 |
+
|
220 |
+
def _update_kv_cache(kv, inference_params, layer_idx):
|
221 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
|
222 |
+
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
223 |
+
# Pre-allocate memory for key-values for inference.
|
224 |
+
num_heads, head_dim = kv.shape[-2:]
|
225 |
+
if layer_idx not in inference_params.key_value_memory_dict:
|
226 |
+
kv_cache = torch.empty(
|
227 |
+
inference_params.max_batch_size, inference_params.max_sequence_len, 2,
|
228 |
+
num_heads, head_dim, dtype=kv.dtype, device=kv.device
|
229 |
+
)
|
230 |
+
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
231 |
+
else:
|
232 |
+
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
233 |
+
|
234 |
+
# Adjust key and value for inference
|
235 |
+
batch_start = inference_params.batch_size_offset
|
236 |
+
batch_end = batch_start + kv.shape[0]
|
237 |
+
sequence_start = inference_params.sequence_len_offset
|
238 |
+
sequence_end = sequence_start + kv.shape[1]
|
239 |
+
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
|
240 |
+
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
|
241 |
+
|
242 |
+
assert kv_cache is not None
|
243 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
244 |
+
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
|
245 |
+
return kv
|
246 |
+
|
247 |
+
|
248 |
+
class MLP(nn.Module):
|
249 |
+
"""Multi-Layer Perceptron.
|
250 |
+
|
251 |
+
Reference:
|
252 |
+
Attention Is All You Need.
|
253 |
+
https://arxiv.org/pdf/1706.03762.pdf.
|
254 |
+
|
255 |
+
"""
|
256 |
+
|
257 |
+
def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None) -> None:
|
258 |
+
super().__init__()
|
259 |
+
|
260 |
+
act_fn = config.activation_function if act_fn is None else act_fn
|
261 |
+
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
|
262 |
+
|
263 |
+
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
|
264 |
+
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
|
265 |
+
|
266 |
+
self.fc1 = nn.Linear(config.n_embd, n_inner)
|
267 |
+
self.fc2 = nn.Linear(n_inner, config.n_embd)
|
268 |
+
self.act = ACT2FN[act_fn]
|
269 |
+
|
270 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
271 |
+
old_keys = [prefix + "fc_in.weight", prefix + "fc_out.weight", prefix + "fc_in.bias", prefix + "fc_out.bias"]
|
272 |
+
new_keys = [prefix + "fc1.weight", prefix + "fc2.weight", prefix + "fc1.bias", prefix + "fc2.bias"]
|
273 |
+
|
274 |
+
if all(k in state_dict for k in old_keys) and not all(k in state_dict for k in new_keys):
|
275 |
+
# Older version of `MLP` saved with different key names.
|
276 |
+
for old_key, new_key in zip(old_keys, new_keys):
|
277 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
278 |
+
|
279 |
+
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
280 |
+
|
281 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
282 |
+
hidden_states = self.fc1(hidden_states)
|
283 |
+
hidden_states = self.act(hidden_states)
|
284 |
+
hidden_states = self.fc2(hidden_states)
|
285 |
+
|
286 |
+
return hidden_states
|
287 |
+
|
288 |
+
|
289 |
+
class FusedMLP(nn.Module):
|
290 |
+
"""Fused Multi-Layer Perceptron from `flash-attn`.
|
291 |
+
|
292 |
+
Reference:
|
293 |
+
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
|
294 |
+
|
295 |
+
"""
|
296 |
+
def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None,
|
297 |
+
raise_on_missing: bool = False) -> None:
|
298 |
+
super().__init__()
|
299 |
+
|
300 |
+
act_fn = config.activation_function if act_fn is None else act_fn
|
301 |
+
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
|
302 |
+
|
303 |
+
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
|
304 |
+
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
|
305 |
+
|
306 |
+
gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"]
|
307 |
+
activation = "gelu_approx" if act_fn in gelu_activations else "relu"
|
308 |
+
|
309 |
+
self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
|
310 |
+
|
311 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
312 |
+
return self.mlp(hidden_states)
|
313 |
+
|
314 |
+
class SelfAttention(nn.Module):
|
315 |
+
"""Implement the scaled dot product attention with softmax.
|
316 |
+
Adapted from https://github.com/Dao-AILab/flash-attention.
|
317 |
+
Arguments
|
318 |
+
---------
|
319 |
+
softmax_scale: The temperature to use for the softmax attention.
|
320 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
321 |
+
runtime)
|
322 |
+
attention_dropout: The dropout rate to apply to the attention
|
323 |
+
(default: 0.0)
|
324 |
+
"""
|
325 |
+
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
326 |
+
super().__init__()
|
327 |
+
self.causal = causal
|
328 |
+
self.softmax_scale = softmax_scale
|
329 |
+
self.drop = nn.Dropout(attention_dropout)
|
330 |
+
|
331 |
+
def forward(self, qkv, causal=None, key_padding_mask=None):
|
332 |
+
"""Implements the multihead softmax attention.
|
333 |
+
Arguments
|
334 |
+
---------
|
335 |
+
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
|
336 |
+
causal: if passed, will override self.causal
|
337 |
+
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
|
338 |
+
False means to mask out. (B, S)
|
339 |
+
"""
|
340 |
+
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
341 |
+
causal = self.causal if causal is None else causal
|
342 |
+
q, k, v = qkv.unbind(dim=2)
|
343 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
344 |
+
scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
|
345 |
+
if key_padding_mask is not None:
|
346 |
+
padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
|
347 |
+
device=scores.device)
|
348 |
+
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
349 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
350 |
+
scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
|
351 |
+
if causal:
|
352 |
+
# "triu_tril_cuda_template" not implemented for 'BFloat16'
|
353 |
+
# So we have to construct the mask in float
|
354 |
+
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
355 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
356 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
357 |
+
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
358 |
+
attention_drop = self.drop(attention)
|
359 |
+
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
|
360 |
+
return output
|
361 |
+
|
362 |
+
|
363 |
+
class CrossAttention(nn.Module):
|
364 |
+
"""Implement the scaled dot product attention with softmax.
|
365 |
+
Adapted from https://github.com/Dao-AILab/flash-attention.
|
366 |
+
Arguments
|
367 |
+
---------
|
368 |
+
softmax_scale: The temperature to use for the softmax attention.
|
369 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
370 |
+
runtime)
|
371 |
+
attention_dropout: The dropout rate to apply to the attention
|
372 |
+
(default: 0.0)
|
373 |
+
"""
|
374 |
+
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
375 |
+
super().__init__()
|
376 |
+
self.causal = causal
|
377 |
+
self.softmax_scale = softmax_scale
|
378 |
+
self.drop = nn.Dropout(attention_dropout)
|
379 |
+
|
380 |
+
def forward(self, q, kv, causal=None, key_padding_mask=None):
|
381 |
+
"""Implements the multihead softmax attention.
|
382 |
+
Arguments
|
383 |
+
---------
|
384 |
+
q: The tensor containing the query. (B, Sq, H, D)
|
385 |
+
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
|
386 |
+
causal: if passed, will override self.causal
|
387 |
+
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
|
388 |
+
False means to mask out. (B, Sk)
|
389 |
+
"""
|
390 |
+
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
391 |
+
causal = self.causal if causal is None else causal
|
392 |
+
seqlen_k = kv.shape[1]
|
393 |
+
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
|
394 |
+
k, v = kv.unbind(dim=2)
|
395 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
396 |
+
scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
|
397 |
+
if key_padding_mask is not None:
|
398 |
+
padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
|
399 |
+
device=scores.device)
|
400 |
+
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
401 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
402 |
+
scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
|
403 |
+
if causal:
|
404 |
+
# "triu_tril_cuda_template" not implemented for 'BFloat16'
|
405 |
+
# So we have to construct the mask in float
|
406 |
+
causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
|
407 |
+
device=scores.device), 1)
|
408 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
409 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
410 |
+
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
411 |
+
attention_drop = self.drop(attention)
|
412 |
+
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
|
413 |
+
return output
|
414 |
+
|
415 |
+
def find_mha_dims(
|
416 |
+
config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
|
417 |
+
) -> Tuple[int, int]:
|
418 |
+
"""Validate and return the number of heads and head dimension for multi-head attention.
|
419 |
+
|
420 |
+
Args:
|
421 |
+
config: Model configuration.
|
422 |
+
n_head: Number of heads.
|
423 |
+
head_dim: Head dimension.
|
424 |
+
|
425 |
+
Returns:
|
426 |
+
Number of heads and head dimension.
|
427 |
+
|
428 |
+
"""
|
429 |
+
|
430 |
+
assert all(
|
431 |
+
hasattr(config, attr) for attr in ["n_embd", "n_head"]
|
432 |
+
), "`config` must have `n_embd` and `n_head` attributes."
|
433 |
+
|
434 |
+
if head_dim is None:
|
435 |
+
assert (
|
436 |
+
config.n_embd % config.n_head == 0
|
437 |
+
), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
|
438 |
+
|
439 |
+
if n_head is None and head_dim is None:
|
440 |
+
head_dim = config.n_embd // config.n_head
|
441 |
+
n_head = config.n_head
|
442 |
+
elif n_head is None or head_dim is None:
|
443 |
+
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
|
444 |
+
|
445 |
+
return n_head, head_dim
|
446 |
+
|
447 |
+
|
448 |
+
class MHA(nn.Module):
|
449 |
+
"""Multi-head attention layer.
|
450 |
+
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
451 |
+
|
452 |
+
def __init__(
|
453 |
+
self,
|
454 |
+
config: PretrainedConfig,
|
455 |
+
rotary_dim: Optional[int] = None,
|
456 |
+
n_head: Optional[int] = None,
|
457 |
+
head_dim: Optional[int] = None,
|
458 |
+
bias: Optional[bool] = True,
|
459 |
+
dropout: Optional[float] = 0.0,
|
460 |
+
softmax_scale: Optional[float] = None,
|
461 |
+
causal: Optional[bool] = True,
|
462 |
+
layer_idx: Optional[int] = None,
|
463 |
+
rotary_emb_scale_base: Optional[float] = None,
|
464 |
+
return_residual: Optional[bool] = False,
|
465 |
+
checkpointing: Optional[bool] = False,
|
466 |
+
device: Optional[str] = None,
|
467 |
+
dtype: Optional[torch.dtype] = None,
|
468 |
+
fused_dense: Optional[bool] = True,
|
469 |
+
flash_attn: Optional[bool] = True,
|
470 |
+
cutlass_attn: Optional[bool] = False,
|
471 |
+
flash_rotary: Optional[bool] = True,
|
472 |
+
raise_on_missing: Optional[bool] = False
|
473 |
+
) -> None:
|
474 |
+
super().__init__()
|
475 |
+
|
476 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
477 |
+
n_head, head_dim = find_mha_dims(config, n_head, head_dim)
|
478 |
+
|
479 |
+
self.hidden_size = config.n_embd
|
480 |
+
self.n_head = n_head
|
481 |
+
self.head_dim = head_dim
|
482 |
+
self.op_size = n_head * head_dim
|
483 |
+
|
484 |
+
self.causal = causal
|
485 |
+
self.layer_idx = layer_idx
|
486 |
+
self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
|
487 |
+
self.fused_dense = fused_dense
|
488 |
+
self.flash_attn = flash_attn
|
489 |
+
self.cutlass_attn = cutlass_attn
|
490 |
+
self.flash_rotary = flash_rotary
|
491 |
+
self.return_residual = return_residual
|
492 |
+
self.checkpointing = checkpointing
|
493 |
+
|
494 |
+
if self.rotary_emb_dim > 0:
|
495 |
+
rotary_kwargs = {"device": device}
|
496 |
+
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
497 |
+
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
498 |
+
|
499 |
+
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
|
500 |
+
else:
|
501 |
+
pass
|
502 |
+
|
503 |
+
self.Wqkv = nn.Linear(self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs)
|
504 |
+
self.out_proj = nn.Linear(self.op_size, self.hidden_size, bias=bias, **factory_kwargs)
|
505 |
+
|
506 |
+
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
507 |
+
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
508 |
+
|
509 |
+
def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
|
510 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
|
511 |
+
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
512 |
+
|
513 |
+
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
514 |
+
|
515 |
+
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
516 |
+
|
517 |
+
def forward(
|
518 |
+
self,
|
519 |
+
x: torch.FloatTensor,
|
520 |
+
x_kv: Optional[torch.FloatTensor] = None,
|
521 |
+
key_padding_mask: Optional[torch.BoolTensor] = None,
|
522 |
+
cu_seqlens: Optional[torch.LongTensor] = None,
|
523 |
+
max_seqlen: Optional[int] = None,
|
524 |
+
mixer_subset: Optional[torch.LongTensor] = None,
|
525 |
+
past_cache: Optional[InferenceParams] = None,
|
526 |
+
**kwargs
|
527 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
528 |
+
"""Perform the forward pass.
|
529 |
+
|
530 |
+
Args:
|
531 |
+
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
|
532 |
+
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
|
533 |
+
is the is the sum of the sequence lengths in the batch.
|
534 |
+
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
|
535 |
+
key_padding_mask: boolean mask, True means to keep, False means to mask out.
|
536 |
+
(batch, seqlen). Only applicable when not using FlashAttention.
|
537 |
+
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
538 |
+
of the sequences in the batch, used to index into x. Only applicable when using
|
539 |
+
FlashAttention.
|
540 |
+
max_seqlen: int. Maximum sequence length in the batch.
|
541 |
+
mixer_subset: for cross-attention only. If not None, will take a subset of x
|
542 |
+
before applying the query projection. Useful for e.g., ViT where we only care
|
543 |
+
about the CLS token in the last layer.
|
544 |
+
past_cache: For generation only.
|
545 |
+
|
546 |
+
Returns:
|
547 |
+
(batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
|
548 |
+
else (total, hidden_dim) where total is the is the sum of the sequence lengths
|
549 |
+
in the batch.
|
550 |
+
|
551 |
+
"""
|
552 |
+
|
553 |
+
if cu_seqlens is not None:
|
554 |
+
assert max_seqlen is not None
|
555 |
+
assert key_padding_mask is None
|
556 |
+
assert self.flash_attn
|
557 |
+
assert self.rotary_emb_dim == 0
|
558 |
+
|
559 |
+
if key_padding_mask is not None:
|
560 |
+
assert cu_seqlens is None
|
561 |
+
assert max_seqlen is None
|
562 |
+
assert not self.flash_attn
|
563 |
+
|
564 |
+
if past_cache is not None:
|
565 |
+
assert key_padding_mask is None
|
566 |
+
assert cu_seqlens is None and max_seqlen is None
|
567 |
+
|
568 |
+
attn_kwargs = {"key_padding_mask": key_padding_mask}
|
569 |
+
|
570 |
+
assert x_kv is None and mixer_subset is None
|
571 |
+
|
572 |
+
qkv = self.Wqkv(x)
|
573 |
+
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
574 |
+
|
575 |
+
if past_cache is None:
|
576 |
+
if self.rotary_emb_dim > 0:
|
577 |
+
qkv = self.rotary_emb(qkv)
|
578 |
+
context = self.inner_attn(qkv, **attn_kwargs)
|
579 |
+
|
580 |
+
else:
|
581 |
+
if self.rotary_emb_dim > 0:
|
582 |
+
qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
|
583 |
+
q = qkv[:, :, 0]
|
584 |
+
kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
|
585 |
+
# If we're processing the prompt, causal=None (use self.causal).
|
586 |
+
# If we're decoding, then causal=False.
|
587 |
+
causal = None if past_cache.sequence_len_offset == 0 else False
|
588 |
+
context = self.inner_cross_attn(q, kv, causal=causal)
|
589 |
+
|
590 |
+
out = rearrange(context, "... h d -> ... (h d)")
|
591 |
+
out = self.out_proj(out)
|
592 |
+
|
593 |
+
return out if not self.return_residual else (out, x)
|
594 |
+
|
595 |
+
class ParallelBlock(nn.Module):
|
596 |
+
"""Parallel block.
|
597 |
+
|
598 |
+
This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
|
599 |
+
|
600 |
+
"""
|
601 |
+
|
602 |
+
def __init__(
|
603 |
+
self,
|
604 |
+
config: PretrainedConfig,
|
605 |
+
mixer: Optional[Dict[str, Any]] = None,
|
606 |
+
mlp: Optional[Dict[str, Any]] = None,
|
607 |
+
block_idx: Optional[int] = None,
|
608 |
+
) -> None:
|
609 |
+
super().__init__()
|
610 |
+
|
611 |
+
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
612 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
613 |
+
self.block_idx = block_idx
|
614 |
+
|
615 |
+
self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
|
616 |
+
mlp_cls = mlp.pop('mlp_cls')
|
617 |
+
if mlp_cls == 'fused_mlp':
|
618 |
+
self.mlp = FusedMLP(config=config, **mlp)
|
619 |
+
else:
|
620 |
+
self.mlp = MLP(config=config, **mlp)
|
621 |
+
|
622 |
+
def forward(self, hidden_states: torch.FloatTensor,
|
623 |
+
past_cache: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
624 |
+
residual = hidden_states
|
625 |
+
hidden_states = self.ln(hidden_states)
|
626 |
+
|
627 |
+
attn_outputs = self.mixer(hidden_states, past_cache=past_cache)
|
628 |
+
if isinstance(attn_outputs, tuple):
|
629 |
+
attn_outputs = attn_outputs[0]
|
630 |
+
|
631 |
+
attn_outputs = self.resid_dropout(attn_outputs)
|
632 |
+
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
|
633 |
+
|
634 |
+
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
635 |
+
|
636 |
+
return hidden_states
|
637 |
+
|
638 |
+
class CausalLMHead(nn.Module):
|
639 |
+
"""Causal Language Modeling head.
|
640 |
+
|
641 |
+
Reference:
|
642 |
+
Improving Language Understanding by Generative Pre-Training.
|
643 |
+
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
|
644 |
+
|
645 |
+
"""
|
646 |
+
|
647 |
+
def __init__(self, config: PretrainedConfig) -> None:
|
648 |
+
super().__init__()
|
649 |
+
|
650 |
+
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
651 |
+
self.linear = nn.Linear(config.n_embd, config.vocab_size)
|
652 |
+
|
653 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
654 |
+
hidden_states = self.ln(hidden_states)
|
655 |
+
logits = self.linear(hidden_states).to(torch.float32)
|
656 |
+
|
657 |
+
return logits
|
658 |
+
|
659 |
+
|
660 |
+
class CausalLMLoss(nn.Module):
|
661 |
+
"""Causal Language Modeling loss.
|
662 |
+
|
663 |
+
Reference:
|
664 |
+
Improving Language Understanding by Generative Pre-Training.
|
665 |
+
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
|
666 |
+
|
667 |
+
"""
|
668 |
+
|
669 |
+
def __init__(self, shift_labels: Optional[bool] = True) -> None:
|
670 |
+
super().__init__()
|
671 |
+
|
672 |
+
self.shift_labels = shift_labels
|
673 |
+
self.loss_fct = nn.CrossEntropyLoss()
|
674 |
+
|
675 |
+
def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
|
676 |
+
if self.shift_labels:
|
677 |
+
logits = logits[..., :-1, :].contiguous()
|
678 |
+
labels = labels[..., 1:].contiguous()
|
679 |
+
|
680 |
+
loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
|
681 |
+
|
682 |
+
return loss
|
683 |
+
|
684 |
+
class MixFormerSequentialPreTrainedModel(PreTrainedModel):
|
685 |
+
"""MixFormer (sequential for DeepSpeed) pre-trained model."""
|
686 |
+
|
687 |
+
config_class = MixFormerSequentialConfig
|
688 |
+
base_model_prefix = "transformer"
|
689 |
+
supports_gradient_checkpointing = True
|
690 |
+
|
691 |
+
def __init__(self, *inputs, **kwargs) -> None:
|
692 |
+
super().__init__(*inputs, **kwargs)
|
693 |
+
|
694 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs) -> Dict[str, Any]:
|
695 |
+
if "use_cache" in kwargs and not kwargs["use_cache"]:
|
696 |
+
return {"input_ids": input_ids}
|
697 |
+
|
698 |
+
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
699 |
+
past_key_values = InferenceParams(
|
700 |
+
max_batch_size=input_ids.shape[0],
|
701 |
+
max_sequence_len=self.config.n_positions,
|
702 |
+
sequence_len_offset=0,
|
703 |
+
batch_size_offset=0,
|
704 |
+
fused_ft_kernel=False,
|
705 |
+
key_value_memory_dict={},
|
706 |
+
)
|
707 |
+
else:
|
708 |
+
# assume past_key_values has cached all but last token in input_ids
|
709 |
+
past_key_values.sequence_len_offset = len(input_ids[0]) - 1
|
710 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
711 |
+
|
712 |
+
return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
|
713 |
+
|
714 |
+
|
715 |
+
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
716 |
+
"""MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
|
717 |
+
|
718 |
+
_keys_to_ignore_on_load_missing = [""]
|
719 |
+
_keys_to_ignore_on_load_unexpected = [r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
|
720 |
+
|
721 |
+
def __init__(self, config: MixFormerSequentialConfig) -> None:
|
722 |
+
super().__init__(config)
|
723 |
+
|
724 |
+
modules = [Embedding(config)]
|
725 |
+
block_config = config.architecture
|
726 |
+
|
727 |
+
if not isinstance(block_config, list):
|
728 |
+
block_config = [block_config for _ in range(config.n_layer)]
|
729 |
+
|
730 |
+
if config.n_layer != len(block_config):
|
731 |
+
config.n_layer = len(block_config)
|
732 |
+
|
733 |
+
for block_idx, block in enumerate(block_config):
|
734 |
+
# `block_cls` with `legacy` value is for backward compatibility
|
735 |
+
# `path` key is for backward compatibility
|
736 |
+
block = copy.deepcopy(block) or {"block_cls": "parallel"}
|
737 |
+
block_cls = block.pop("path", None) or block.pop("block_cls", None)
|
738 |
+
|
739 |
+
block["block_idx"] = block_idx
|
740 |
+
modules.append(ParallelBlock(config, **block))
|
741 |
+
|
742 |
+
modules.append(CausalLMHead(config))
|
743 |
+
|
744 |
+
self.layers = nn.Sequential(*modules)
|
745 |
+
self.loss = CausalLMLoss()
|
746 |
+
|
747 |
+
self.post_init()
|
748 |
+
|
749 |
+
def get_input_embeddings(self) -> nn.Embedding:
|
750 |
+
return self.layers[0].wte
|
751 |
+
|
752 |
+
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
753 |
+
self.layers[0].wte = new_embeddings
|
754 |
+
|
755 |
+
def get_output_embeddings(self) -> nn.Linear:
|
756 |
+
return self.layers[-1].linear
|
757 |
+
|
758 |
+
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
|
759 |
+
self.layers[-1].linear = new_embeddings
|
760 |
+
|
761 |
+
def forward(
|
762 |
+
self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None,
|
763 |
+
past_key_values: Optional[torch.FloatTensor] = None, **kwargs
|
764 |
+
) -> CausalLMOutputWithPast:
|
765 |
+
|
766 |
+
if not past_key_values:
|
767 |
+
lm_logits = self.layers(input_ids)
|
768 |
+
else:
|
769 |
+
hidden_layer = self.layers[0](input_ids)
|
770 |
+
for module in self.layers[1:-1]:
|
771 |
+
hidden_layer = module(hidden_layer, past_cache=past_key_values)
|
772 |
+
lm_logits = self.layers[-1](hidden_layer)
|
773 |
+
|
774 |
+
loss = None
|
775 |
+
if labels is not None:
|
776 |
+
loss = self.loss(lm_logits, labels)
|
777 |
+
|
778 |
+
return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
|
model.safetensors → pytorch_model.bin
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eab6a12a9a2b78cac8f8975aea9f3a5e89ddadcb9e0dad27e40965e57e235a4a
|
3 |
+
size 2836623617
|