mkshing committed on
Commit
ceadefe
0 Parent(s):

initial commit

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
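These patterns route large binaries (including the `*.safetensors` shards added below) through Git LFS, so a plain `git clone` without LFS support only fetches pointer files. As a minimal, hedged sketch of fetching the fully resolved files with `huggingface_hub` (repo id taken from the README below):

```python
# Sketch: download the repository with LFS-tracked files resolved to real binaries.
from huggingface_hub import snapshot_download

local_dir = snapshot_download("SakanaAI/EvoLLM-JP-v1-10B")  # shards, configs, and remote code
print(local_dir)
```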
LICENSE ADDED
@@ -0,0 +1,71 @@
+ MICROSOFT RESEARCH LICENSE TERMS
+
+ IF YOU LIVE IN THE UNITED STATES, PLEASE READ THE “BINDING ARBITRATION AND CLASS ACTION WAIVER” SECTION BELOW. IT AFFECTS HOW DISPUTES ARE RESOLVED.
+
+ These license terms are an agreement between you and Microsoft Corporation (or one of its affiliates). They apply to the source code, object code, machine learning models, or data (collectively “Materials”) that accompany this license. IF YOU COMPLY WITH THESE LICENSE TERMS, YOU HAVE THE RIGHTS BELOW. BY USING THE MATERIALS, YOU ACCEPT THESE TERMS.
+
+ 1) INSTALLATION AND USE RIGHTS TO THE MATERIALS.
+
+ Subject to the terms of this agreement, you have the below rights, if applicable, to use the Materials solely for non-commercial, non-revenue generating, research purposes:
+
+ a) Source Code. If source code is included, you may use and modify the source code, but you may not distribute the source code.
+
+ b) Object Code. If object code is included, you may use the object code, but you may not distribute the object code.
+
+ c) Models. If machine learning model(s) are included, you may use the model(s), but you may not distribute the models.
+
+ d) Data. If data is included, you may use and modify the data, but your use and modification must be consistent with the consent under which the data was provided and/or gathered and you may not distribute the data or your modifications to the data.
+
+ 2) SCOPE OF LICENSE. The Materials are licensed, not sold. Microsoft reserves all other rights. Unless applicable law gives you more rights despite this limitation, you will not (and have no right to):
+
+ a) work around any technical limitations in the Materials that only allow you to use it in certain ways;
+
+ b) reverse engineer, decompile or disassemble the Materials;
+
+ c) remove, minimize, block, or modify any notices of Microsoft or its suppliers in the Materials;
+
+ d) use the Materials in any way that is against the law or to create or propagate malware; or
+
+ e) share, publish, distribute or lend the Materials, provide the Materials as a stand-alone hosted solution for others to use, or transfer the Materials or this agreement to any third party.
+
+ 3) PERSONAL DATA. If the data (set forth in Section 1(c) above) includes or is found to include any data that enables any ability to identify an individual (“Personal Data”), you will not use such Personal Data for any purpose other than was authorized and consented to by the data subject/research participant. You will not use Personal Data to contact any person. You will keep Personal Data in strict confidence. You will not share any Personal Data that is collected or in your possession with any third party for any reason and as required under the original consent agreement. Further, you will destroy the Personal Data and any backup or copies, immediately upon the completion of your research.
+
+ 4) LICENSE TO MICROSOFT. Notwithstanding the limitations in Section 1, you may distribute your modifications back to Microsoft, and if you do provide Microsoft with modifications of the Materials, you hereby grant Microsoft, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer such modifications and derivatives for any purpose.
+
+ 5) PUBLICATION. You may publish (or present papers or articles) on your results from using the Materials provided that no material or substantial portion of the Materials is included in any such publication or presentation.
+
+ 6) FEEDBACK. Any feedback about the Materials provided by you to us is voluntarily given, and Microsoft shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the
+
+ feedback is designated by you as confidential. Such feedback shall be considered a contribution and licensed to Microsoft under the terms of Section 4 above.
+
+ 7) EXPORT RESTRICTIONS. You must comply with all domestic and international export laws and regulations that apply to the Materials, which include restrictions on destinations, end users, and end use. For further information on export restrictions, visit (aka.ms/exporting).
+
+ 8) SUPPORT SERVICES. Microsoft is not obligated under this agreement to provide any support services for the Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
+
+ 9) BINDING ARBITRATION AND CLASS ACTION WAIVER. This Section applies if you live in (or, if a business, your principal place of business is in) the United States. If you and Microsoft have a dispute, you and Microsoft agree to try for 60 days to resolve it informally. If you and Microsoft can’t, you and Microsoft agree to binding individual arbitration before the American Arbitration Association under the Federal Arbitration Act (“FAA”), and not to sue in court in front of a judge or jury. Instead, a neutral arbitrator will decide. Class action lawsuits, class-wide arbitrations, private attorney-general actions, and any other proceeding where someone acts in a representative capacity are not allowed; nor is combining individual proceedings without the consent of all parties. The complete Arbitration Agreement contains more terms and is at aka.ms/arb-agreement-1. You and Microsoft agree to these terms.
+
+ 10) ENTIRE AGREEMENT. This agreement, and any other terms Microsoft may provide for supplements, updates, or third-party applications, is the entire agreement for the Materials.
+
+ 11) APPLICABLE LAW AND PLACE TO RESOLVE DISPUTES. If you acquired the Materials in the United States or Canada, the laws of the state or province where you live (or, if a business, where your principal place of business is located) govern the interpretation of this agreement, claims for its breach, and all other claims (including consumer protection, unfair competition, and tort claims), regardless of conflict of laws principles, except that the FAA governs everything related to arbitration. If you acquired the Materials in any other country, its laws apply, except that the FAA governs everything related to arbitration. If U.S. federal jurisdiction exists, you and Microsoft consent to exclusive jurisdiction and venue in the federal court in King County, Washington for all disputes heard in court (excluding arbitration). If not, you and Microsoft consent to exclusive jurisdiction and venue in the Superior Court of King County, Washington for all disputes heard in court (excluding arbitration).
+
+ 12) CONSUMER RIGHTS; REGIONAL VARIATIONS. This agreement describes certain legal rights. You may have other rights, including consumer rights, under the laws of your state, province, or country. Separate and apart from your relationship with Microsoft, you may also have rights with respect to the party from which you acquired the Materials. This agreement does not change those other rights if the laws of your state, province, or country do not permit it to do so. For example, if you acquired the Materials in one of the below regions, or mandatory country law applies, then the following provisions apply to you:
+
+ a) Australia. You have statutory guarantees under the Australian Consumer Law and nothing in this agreement is intended to affect those rights.
+
+ b) Canada. If you acquired this software in Canada, you may stop receiving updates by turning off the automatic update feature, disconnecting your device from the Internet (if and when you re-connect to the Internet, however, the Materials will resume checking for and installing updates), or uninstalling the Materials. The product documentation, if any, may also specify how to turn off updates for your specific device or software.
+
+ c) Germany and Austria.
+
+ i. Warranty. The properly licensed software will perform substantially as described in any Microsoft materials that accompany the Materials. However, Microsoft gives no contractual guarantee in relation to the licensed software.
+
+ ii. Limitation of Liability. In case of intentional conduct, gross negligence, claims based on the Product Liability Act, as well as, in case of death or personal or physical injury, Microsoft is liable according to the statutory law.
+
+ Subject to the foregoing clause (ii), Microsoft will only be liable for slight negligence if Microsoft is in breach of such material contractual obligations, the fulfillment of which facilitate the due performance of this agreement, the breach of which would endanger the purpose of this agreement and the compliance with which a party may constantly trust in (so-called "cardinal obligations"). In other cases of slight negligence, Microsoft will not be liable for slight negligence.
+
+ 13) DISCLAIMER OF WARRANTY. THE MATERIALS ARE LICENSED “AS IS.” YOU BEAR THE RISK OF USING THEM. MICROSOFT GIVES NO EXPRESS WARRANTIES, GUARANTEES, OR CONDITIONS. TO THE EXTENT PERMITTED UNDER APPLICABLE LAWS, MICROSOFT EXCLUDES ALL IMPLIED WARRANTIES, INCLUDING MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT.
+
+ 14) LIMITATION ON AND EXCLUSION OF DAMAGES. IF YOU HAVE ANY BASIS FOR RECOVERING DAMAGES DESPITE THE PRECEDING DISCLAIMER OF WARRANTY, YOU CAN RECOVER FROM MICROSOFT AND ITS SUPPLIERS ONLY DIRECT DAMAGES UP TO U.S. $5.00. YOU CANNOT RECOVER ANY OTHER DAMAGES, INCLUDING CONSEQUENTIAL, LOST PROFITS, SPECIAL, INDIRECT OR INCIDENTAL DAMAGES.
+
+ This limitation applies to (a) anything related to the Materials, services, content (including code) on third party Internet sites, or third party applications; and (b) claims for breach of contract, warranty, guarantee, or condition; strict liability, negligence, or other tort; or any other claim; in each case to the extent permitted by applicable law.
+
+ It also applies even if Microsoft knew or should have known about the possibility of the damages. The above limitation or exclusion may not apply to you because your state, province, or country may not allow the exclusion or limitation of incidental, consequential, or other damages.
README.md ADDED
@@ -0,0 +1,97 @@
+ ---
+ language:
+ - ja
+ license: other
+ library_name: transformers
+ ---
+
+ # 🐟 EvoLLM-JP-v1-10B
+
+ 🤗 [Models](https://huggingface.co/SakanaAI) | 📚 [Paper](https://arxiv.org/abs/2403.13187) | 📝 [Blog](https://sakana.ai/evolutionary-model-merge/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
+
+ <!-- Provide a quick summary of what the model is/does. -->
+ **EvoLLM-JP-v1-10B** is an experimental general-purpose Japanese LLM.
+ This model was created using the Evolutionary Model Merge method.
+ Please refer to our [report](https://arxiv.org/abs/2403.13187) and [blog](https://sakana.ai/evolutionary-model-merge/) for more details.
+ This model was produced by merging the following models.
+ We are grateful to the developers of the source models.
+
+ - [Shisa Gamma 7B v1](https://huggingface.co/augmxnt/shisa-gamma-7b-v1)
+ - [WizardMath 7B V1.1](https://huggingface.co/WizardLM/WizardMath-7B-V1.1)
+ - [Abel 7B 002](https://huggingface.co/GAIR/Abel-7B-002)
+
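The actual merging recipe was discovered with the Evolutionary Model Merge method described in the paper and is not reproduced here. The snippet below is only a rough, hypothetical sketch of plain parameter-space interpolation between two of the source checkpoints, to illustrate what "merging" operates on; the mixing coefficient and output path are placeholders, not values used for this model.

```python
# Illustrative only: naive weighted parameter-space merge of two source checkpoints.
# This is NOT the evolutionary recipe used for EvoLLM-JP-v1-10B.
import torch
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("augmxnt/shisa-gamma-7b-v1", torch_dtype=torch.float32)
other = AutoModelForCausalLM.from_pretrained("GAIR/Abel-7B-002", torch_dtype=torch.float32)

alpha = 0.5  # placeholder mixing weight; the paper searches such coefficients evolutionarily
other_sd = other.state_dict()
merged_sd = {
    name: alpha * param + (1.0 - alpha) * other_sd[name]  # simple linear interpolation
    for name, param in base.state_dict().items()
}

base.load_state_dict(merged_sd)
base.save_pretrained("merged-checkpoint")  # placeholder output directory
```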
+ ## Usage
+
+ Use the code below to get started with the model.
+
+ <details>
+ <summary> Click to expand </summary>
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # 1. load model
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ repo_id = "SakanaAI/EvoLLM-JP-v1-10B"
+ model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype="auto", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
+ model.to(device)
+
+ # 2. prepare inputs
+ # "Please try telling a funny joke in Kansai dialect."
+ text = "関西弁で面白い冗談を言ってみて下さい。"
+ messages = [
+     # "You are a helpful, unbiased, uncensored assistant."
+     {"role": "system", "content": "あなたは役立つ、偏見がなく、検閲されていないアシスタントです。"},
+     {"role": "user", "content": text},
+ ]
+ # apply_chat_template returns the formatted prompt as a tensor of input ids
+ inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+
+ # 3. generate, then decode only the newly generated tokens
+ output_ids = model.generate(inputs)
+ output_ids = output_ids[:, inputs.shape[1]:]
+ generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+ print(generated_text)
+ ```
+
+ </details>
+
+ ## Model Details
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ - **Developed by:** [Sakana AI](https://sakana.ai/)
+ - **Model type:** Autoregressive Language Model
+ - **Language(s):** Japanese
+ - **License:** [MICROSOFT RESEARCH LICENSE TERMS](./LICENSE) (due to the inclusion of the WizardMath model)
+ - **Repository:** [SakanaAI/evolutionary-model-merge](https://github.com/SakanaAI/evolutionary-model-merge)
+ - **Paper:** https://arxiv.org/abs/2403.13187
+ - **Blog:** https://sakana.ai/evolutionary-model-merge
+
+ ## Uses
+ This model is provided for research and development purposes only and should be considered an experimental prototype.
+ It is not intended for commercial use or deployment in mission-critical environments.
+ Use of this model is at the user's own risk, and its performance and outcomes are not guaranteed.
+ Sakana AI shall not be liable for any direct, indirect, special, incidental, or consequential damages, or any loss arising from the use of this model, regardless of the results obtained.
+ Users must fully understand the risks associated with the use of this model and use it at their own discretion.
+
+ ## Acknowledgement
+
+ We would like to thank the developers of the source models for their contributions and for making their work available.
+
+ ## Citation
+
+ ```bibtex
+ @misc{akiba2024evomodelmerge,
+     title         = {Evolutionary Optimization of Model Merging Recipes},
+     author        = {Takuya Akiba and Makoto Shing and Yujin Tang and Qi Sun and David Ha},
+     year          = {2024},
+     eprint        = {2403.13187},
+     archivePrefix = {arXiv},
+     primaryClass  = {cs.NE}
+ }
+ ```
config.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "_name_or_path": "SakanaAI/EvoLLM-v1-JP-10B",
+   "architectures": [
+     "EvoMistralForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "SakanaAI/EvoLLM-v1-JP-10B--configuration_evomistral.EvoMistralConfig",
+     "AutoModelForCausalLM": "SakanaAI/EvoLLM-v1-JP-10B--modeling_evomistral.EvoMistralForCausalLM"
+   },
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 14336,
+   "max_position_embeddings": 32768,
+   "model_type": "evomistral",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 44,
+   "num_hops": 65,
+   "num_key_value_heads": 8,
+   "rms_norm_eps": 1e-05,
+   "rope_theta": 10000.0,
+   "sliding_window": 4096,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.2",
+   "use_cache": false,
+   "vocab_size": 32000
+ }
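`num_hops` is the EvoMistral-specific field here; its value (65) exceeds `num_hidden_layers` (44), which is consistent with the merged model revisiting layers along its inference path (see `modeling_evomistral.py` below). A small sketch for inspecting the remote config, assuming Hub access:

```python
# Sketch: load the remote config via the auto_map above and inspect its fields.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("SakanaAI/EvoLLM-JP-v1-10B", trust_remote_code=True)
print(config.model_type)         # "evomistral"
print(config.num_hidden_layers)  # 44
print(config.num_hops)           # 65
```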
configuration_evomistral.py ADDED
@@ -0,0 +1,9 @@
+ from transformers.models.mistral.configuration_mistral import MistralConfig
+
+
+ class EvoMistralConfig(MistralConfig):
+     model_type = "evomistral"
+
+     def __init__(self, num_hops: int = 64, **kwargs):
+         self.num_hops = num_hops
+         super().__init__(**kwargs)
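For local experiments the config class can also be instantiated directly; a minimal sketch, assuming this file is importable from the working directory (note the class default `num_hops=64` differs from the 65 stored in `config.json`):

```python
# Sketch: build the custom config locally instead of pulling it from the Hub.
from configuration_evomistral import EvoMistralConfig

cfg = EvoMistralConfig(num_hops=65, num_hidden_layers=44)
print(cfg.model_type)  # "evomistral"
print(cfg.num_hops)    # 65
```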
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "max_new_tokens": 1024,
+   "pad_token_id": 2,
+   "transformers_version": "4.38.2"
+ }
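These defaults are attached to the model automatically by `from_pretrained`; they can also be loaded and overridden explicitly. A short sketch (the override value is arbitrary):

```python
# Sketch: inspect the generation defaults shipped with the checkpoint.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("SakanaAI/EvoLLM-JP-v1-10B")
print(gen_cfg.max_new_tokens)  # 1024

# Overriding at call time (`model` and `inputs` as in the README usage example):
# output_ids = model.generate(inputs, max_new_tokens=256)
```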
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0eb2cb18f8988039b650ffb45521542e37370d99db71825a3d48b289376e26ef
+ size 4943163008
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df681d4220f14d4a77a6af7b069e59eadee4b48703695f016f8f7bf63b9a9442
+ size 4999819336
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a8f1f60e6243cd8dfa1d43777c2e0c56dd7bbe2e2078e40d29121b98d29de53d
+ size 4915916184
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0be58a55d55d6560d2f924ae563d3e1cc32a83c0295f3be82230040bfd88aa3f
+ size 4859300760
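The four files above are Git LFS pointers rather than the weights themselves; after downloading, each shard can be checked against the SHA-256 digest recorded in its pointer. A minimal sketch, assuming the shard sits in the current directory:

```python
# Sketch: verify a downloaded shard against the oid recorded in its LFS pointer.
import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

digest = sha256_of("model-00001-of-00004.safetensors")  # placeholder local path
print(digest == "0eb2cb18f8988039b650ffb45521542e37370d99db71825a3d48b289376e26ef")
```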
model.safetensors.index.json ADDED
@@ -0,0 +1,408 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 19718152712
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00004-of-00004.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
8
+ "model.input_layers": "model-00001-of-00004.safetensors",
9
+ "model.input_scales": "model-00001-of-00004.safetensors",
10
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
11
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
12
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
13
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
14
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
15
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
16
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
17
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
18
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
19
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
20
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
21
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
22
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
23
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
24
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
25
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
26
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
27
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
28
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
29
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
30
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
31
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
32
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
33
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
34
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
35
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
36
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
37
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
38
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
39
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
40
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
41
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
42
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
43
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
44
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
45
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
46
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
47
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
48
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
49
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
50
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
51
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
52
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
53
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
54
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
55
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
56
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
57
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
58
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
59
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
60
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
61
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
62
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
63
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
64
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
65
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
66
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
67
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
68
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
69
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
70
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
71
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
72
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
73
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
74
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
75
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
76
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
77
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
78
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
79
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
80
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
81
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
82
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
83
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
84
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
85
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
86
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
87
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
88
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
89
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
90
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
91
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
92
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
93
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
94
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
95
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
96
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
97
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
98
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
99
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
100
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
101
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
102
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
103
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
104
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
105
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
106
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
107
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
108
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
109
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
110
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
111
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
112
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
113
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
114
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
115
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
116
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
117
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
118
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
119
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
120
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
121
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
122
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
123
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
124
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
125
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
126
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
127
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00004.safetensors",
128
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
129
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
130
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
131
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
132
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
133
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
134
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
135
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
136
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00004.safetensors",
137
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
138
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
139
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
140
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
141
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
142
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
143
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
144
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
145
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
146
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
147
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
148
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
149
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
150
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
151
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
152
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
153
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
154
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
155
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
156
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
157
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
158
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
159
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
160
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
161
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
162
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
163
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
164
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
165
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
166
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
167
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
168
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
169
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
170
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
171
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
172
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
173
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
174
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
175
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
176
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
177
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
178
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
179
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
180
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
181
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
182
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
183
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
184
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
185
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
186
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
187
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
188
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
189
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
190
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
191
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
192
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
193
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
194
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
195
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
196
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
197
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
198
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
199
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
200
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
201
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
202
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
203
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
204
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
205
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
206
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
207
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
208
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
209
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
210
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
211
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
212
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
213
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
214
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
215
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
216
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
217
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
218
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
219
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
220
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
221
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
222
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
223
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
224
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
225
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
226
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
227
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
228
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
229
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
230
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
231
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
232
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
233
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
234
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
235
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
236
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
237
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
238
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
239
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
240
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
241
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
242
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
243
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
244
+ "model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
245
+ "model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
246
+ "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
247
+ "model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
248
+ "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
249
+ "model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
250
+ "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
251
+ "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
252
+ "model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
253
+ "model.layers.33.input_layernorm.weight": "model-00004-of-00004.safetensors",
254
+ "model.layers.33.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
255
+ "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
256
+ "model.layers.33.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
257
+ "model.layers.33.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
258
+ "model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
259
+ "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
260
+ "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
261
+ "model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
262
+ "model.layers.34.input_layernorm.weight": "model-00004-of-00004.safetensors",
263
+ "model.layers.34.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
264
+ "model.layers.34.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
265
+ "model.layers.34.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
266
+ "model.layers.34.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
267
+ "model.layers.34.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
268
+ "model.layers.34.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
269
+ "model.layers.34.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
270
+ "model.layers.34.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
271
+ "model.layers.35.input_layernorm.weight": "model-00004-of-00004.safetensors",
272
+ "model.layers.35.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
273
+ "model.layers.35.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
274
+ "model.layers.35.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
275
+ "model.layers.35.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
276
+ "model.layers.35.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
277
+ "model.layers.35.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
278
+ "model.layers.35.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
279
+ "model.layers.35.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
280
+ "model.layers.36.input_layernorm.weight": "model-00004-of-00004.safetensors",
281
+ "model.layers.36.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
282
+ "model.layers.36.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
283
+ "model.layers.36.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
284
+ "model.layers.36.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
285
+ "model.layers.36.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
286
+ "model.layers.36.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
287
+ "model.layers.36.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
288
+ "model.layers.36.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
289
+ "model.layers.37.input_layernorm.weight": "model-00004-of-00004.safetensors",
290
+ "model.layers.37.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
291
+ "model.layers.37.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
292
+ "model.layers.37.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
293
+ "model.layers.37.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
294
+ "model.layers.37.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
295
+ "model.layers.37.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
296
+ "model.layers.37.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
297
+ "model.layers.37.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
298
+ "model.layers.38.input_layernorm.weight": "model-00004-of-00004.safetensors",
299
+ "model.layers.38.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
300
+ "model.layers.38.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
301
+ "model.layers.38.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
302
+ "model.layers.38.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
303
+ "model.layers.38.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
304
+ "model.layers.38.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
305
+ "model.layers.38.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
306
+ "model.layers.38.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
307
+ "model.layers.39.input_layernorm.weight": "model-00004-of-00004.safetensors",
308
+ "model.layers.39.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
309
+ "model.layers.39.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
310
+ "model.layers.39.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
311
+ "model.layers.39.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
312
+ "model.layers.39.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
313
+ "model.layers.39.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
314
+ "model.layers.39.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
315
+ "model.layers.39.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
316
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
317
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
318
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
319
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
320
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
321
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
322
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
323
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
324
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
325
+ "model.layers.40.input_layernorm.weight": "model-00004-of-00004.safetensors",
326
+ "model.layers.40.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
327
+ "model.layers.40.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
328
+ "model.layers.40.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
329
+ "model.layers.40.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
330
+ "model.layers.40.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
331
+ "model.layers.40.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
332
+ "model.layers.40.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
333
+ "model.layers.40.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
334
+ "model.layers.41.input_layernorm.weight": "model-00004-of-00004.safetensors",
335
+ "model.layers.41.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
336
+ "model.layers.41.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
337
+ "model.layers.41.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
338
+ "model.layers.41.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
339
+ "model.layers.41.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
340
+ "model.layers.41.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
341
+ "model.layers.41.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
342
+ "model.layers.41.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
343
+ "model.layers.42.input_layernorm.weight": "model-00004-of-00004.safetensors",
344
+ "model.layers.42.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
345
+ "model.layers.42.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
346
+ "model.layers.42.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
347
+ "model.layers.42.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
348
+ "model.layers.42.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
349
+ "model.layers.42.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
350
+ "model.layers.42.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
351
+ "model.layers.42.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
352
+ "model.layers.43.input_layernorm.weight": "model-00004-of-00004.safetensors",
353
+ "model.layers.43.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
354
+ "model.layers.43.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
355
+ "model.layers.43.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
356
+ "model.layers.43.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
357
+ "model.layers.43.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
358
+ "model.layers.43.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
359
+ "model.layers.43.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
360
+ "model.layers.43.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
361
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
362
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
363
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
364
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
365
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
366
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
367
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
368
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
369
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
370
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
371
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
372
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
373
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
374
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
375
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
376
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
377
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
378
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
379
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
380
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
381
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
382
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
383
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
384
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
385
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
386
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
387
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
388
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
389
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
390
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
391
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
392
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
393
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
394
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
395
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
396
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
397
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00004.safetensors",
398
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
399
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
400
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
401
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
402
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
403
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
404
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
405
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
406
+ "model.norm.weight": "model-00004-of-00004.safetensors"
407
+ }
408
+ }
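The index file maps every tensor name to the shard that stores it; `transformers` consumes it automatically when loading the model, but it can also be read directly. A small sketch, assuming the file has been downloaded locally:

```python
# Sketch: find which shard holds a given tensor using the weight map above.
import json

with open("model.safetensors.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])                   # 19718152712 (bytes across all shards)
print(index["weight_map"]["model.embed_tokens.weight"])  # model-00001-of-00004.safetensors
print(index["weight_map"]["lm_head.weight"])             # model-00004-of-00004.safetensors
```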
modeling_evomistral.py ADDED
@@ -0,0 +1,1379 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch Mistral model."""
21
+ import inspect
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPast,
36
+ CausalLMOutputWithPast,
37
+ SequenceClassifierOutputWithPast,
38
+ )
39
+ from transformers.modeling_utils import PreTrainedModel
40
+ from transformers.utils import (
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ is_flash_attn_2_available,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_evomistral import EvoMistralConfig
48
+
49
+ if is_flash_attn_2_available():
50
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
51
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
52
+
53
+ _flash_supports_window_size = "window_size" in list(
54
+ inspect.signature(flash_attn_func).parameters
55
+ )
56
+
57
+
58
+ logger = logging.get_logger(__name__)
59
+
60
+ _CONFIG_FOR_DOC = "EvoMistralConfig"
61
+
62
+
63
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
64
+ def _get_unpad_data(attention_mask):
65
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
66
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
67
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
68
+ cu_seqlens = F.pad(
69
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
70
+ )
71
+ return (
72
+ indices,
73
+ cu_seqlens,
74
+ max_seqlen_in_batch,
75
+ )
76
+
77
+
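A quick illustrative check of what `_get_unpad_data` produces for a small left-padded batch (not part of the committed file; the toy mask values below are made up):

```python
# Illustrative only: reproduces the arithmetic of _get_unpad_data above on a toy mask.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[0, 1, 1, 1],
                               [1, 1, 1, 1]], dtype=torch.int32)  # row 0 is left-padded
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)  # tensor([3, 4])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)                    # tensor([0, 3, 7]) -> cumulative lengths for flash_attn_varlen_func
print(int(seqlens_in_batch.max()))   # 4 -> max_seqlen_in_batch
```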
78
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
79
+ class MistralRMSNorm(nn.Module):
80
+ def __init__(self, hidden_size, eps=1e-6):
81
+ """
82
+ MistralRMSNorm is equivalent to T5LayerNorm
83
+ """
84
+ super().__init__()
85
+ self.weight = nn.Parameter(torch.ones(hidden_size))
86
+ self.variance_epsilon = eps
87
+
88
+ def forward(self, hidden_states, residual=None):
89
+ input_dtype = hidden_states.dtype
90
+ hidden_states = hidden_states.to(torch.float32)
91
+ if residual is not None:
92
+ hidden_states = hidden_states + residual.to(torch.float32)
93
+ residual = hidden_states.to(input_dtype)
94
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
95
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
96
+ hidden_states = self.weight * hidden_states.to(input_dtype)
97
+ if residual is None:
98
+ return hidden_states
99
+ else:
100
+ return hidden_states, residual
101
+
102
+
103
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
104
+ class MistralRotaryEmbedding(nn.Module):
105
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
106
+ super().__init__()
107
+
108
+ self.dim = dim
109
+ self.max_position_embeddings = max_position_embeddings
110
+ self.base = base
111
+ inv_freq = 1.0 / (
112
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
113
+ )
114
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
115
+
116
+ # Build here to make `torch.jit.trace` work.
117
+ self._set_cos_sin_cache(
118
+ seq_len=max_position_embeddings,
119
+ device=self.inv_freq.device,
120
+ dtype=torch.get_default_dtype(),
121
+ )
122
+
123
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
124
+ self.max_seq_len_cached = seq_len
125
+ t = torch.arange(
126
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
127
+ )
128
+
129
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
130
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
131
+ emb = torch.cat((freqs, freqs), dim=-1)
132
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
133
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
134
+
135
+ def forward(self, x, seq_len=None):
136
+ # x: [bs, num_attention_heads, seq_len, head_size]
137
+ if seq_len > self.max_seq_len_cached:
138
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
139
+
140
+ return (
141
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
142
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
143
+ )
144
+
145
+
146
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
147
+ def rotate_half(x):
148
+ """Rotates half the hidden dims of the input."""
149
+ x1 = x[..., : x.shape[-1] // 2]
150
+ x2 = x[..., x.shape[-1] // 2 :]
151
+ return torch.cat((-x2, x1), dim=-1)
152
+
153
+
154
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
155
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
156
+ """Applies Rotary Position Embedding to the query and key tensors.
157
+
158
+ Args:
159
+ q (`torch.Tensor`): The query tensor.
160
+ k (`torch.Tensor`): The key tensor.
161
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
162
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
163
+ position_ids (`torch.Tensor`):
164
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
165
+ used to pass offsetted position ids when working with a KV-cache.
166
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
167
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
168
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
169
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
170
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
171
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
172
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
173
+ Returns:
174
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
175
+ """
176
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
177
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
178
+ q_embed = (q * cos) + (rotate_half(q) * sin)
179
+ k_embed = (k * cos) + (rotate_half(k) * sin)
180
+ return q_embed, k_embed
181
+
182
+
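To make the `unsqueeze_dim` broadcasting described in the docstring above concrete, here is a minimal shape check (illustrative only; the batch, head, and sequence sizes are assumptions, not values from this commit):

```python
# Illustrative shape check for apply_rotary_pos_emb with unsqueeze_dim=1.
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bsz, heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(bsz, heads, seq_len, head_dim)
cos = torch.randn(seq_len, head_dim)  # per-position values, as returned by MistralRotaryEmbedding
sin = torch.randn(seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)

cos_b = cos[position_ids].unsqueeze(1)  # (bsz, 1, seq_len, head_dim): broadcasts over the heads dim
sin_b = sin[position_ids].unsqueeze(1)
q_embed = (q * cos_b) + (rotate_half(q) * sin_b)
assert q_embed.shape == (bsz, heads, seq_len, head_dim)
```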
183
+ class MistralMLP(nn.Module):
184
+ def __init__(self, config):
185
+ super().__init__()
186
+ self.config = config
187
+ self.hidden_size = config.hidden_size
188
+ self.intermediate_size = config.intermediate_size
189
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
190
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
191
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
192
+ self.act_fn = ACT2FN[config.hidden_act]
193
+
194
+ def forward(self, x):
195
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
196
+
197
+
198
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
199
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
200
+ """
201
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
202
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
203
+ """
204
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
205
+ if n_rep == 1:
206
+ return hidden_states
207
+ hidden_states = hidden_states[:, :, None, :, :].expand(
208
+ batch, num_key_value_heads, n_rep, slen, head_dim
209
+ )
210
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
211
+
212
+
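The docstring of `repeat_kv` claims equivalence with `torch.repeat_interleave(x, dim=1, repeats=n_rep)`; a tiny sanity check of that claim, using a local copy of the function (illustrative only, not part of the committed file):

```python
# Illustrative only: the expand+reshape trick matches repeat_interleave along the head dim.
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 10, 64)  # 4 key/value heads (grouped-query attention)
out = repeat_kv(kv, n_rep=8)    # expanded to 32 attention heads
assert torch.equal(out, torch.repeat_interleave(kv, repeats=8, dim=1))
```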
213
+ class MistralAttention(nn.Module):
214
+ """
215
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
216
+ and "Generating Long Sequences with Sparse Transformers".
217
+ """
218
+
219
+ def __init__(self, config: EvoMistralConfig):
220
+ super().__init__()
221
+ self.config = config
222
+ self.hidden_size = config.hidden_size
223
+ self.num_heads = config.num_attention_heads
224
+ self.head_dim = self.hidden_size // self.num_heads
225
+ self.num_key_value_heads = config.num_key_value_heads
226
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
227
+ self.max_position_embeddings = config.max_position_embeddings
228
+ self.rope_theta = config.rope_theta
229
+ self.is_causal = True
230
+
231
+ if (self.head_dim * self.num_heads) != self.hidden_size:
232
+ raise ValueError(
233
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
234
+ f" and `num_heads`: {self.num_heads})."
235
+ )
236
+ self.q_proj = nn.Linear(
237
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
238
+ )
239
+ self.k_proj = nn.Linear(
240
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
241
+ )
242
+ self.v_proj = nn.Linear(
243
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
244
+ )
245
+ self.o_proj = nn.Linear(
246
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
247
+ )
248
+
249
+ self.rotary_emb = MistralRotaryEmbedding(
250
+ self.head_dim,
251
+ max_position_embeddings=self.max_position_embeddings,
252
+ base=self.rope_theta,
253
+ )
254
+
255
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
256
+ return (
257
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
258
+ .transpose(1, 2)
259
+ .contiguous()
260
+ )
261
+
262
+ def forward(
263
+ self,
264
+ hidden_states: torch.Tensor,
265
+ attention_mask: Optional[torch.Tensor] = None,
266
+ position_ids: Optional[torch.LongTensor] = None,
267
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
268
+ output_attentions: bool = False,
269
+ use_cache: bool = False,
270
+ **kwargs,
271
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
272
+ if "padding_mask" in kwargs:
273
+ warnings.warn(
274
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
275
+ )
276
+ bsz, q_len, _ = hidden_states.size()
277
+
278
+ query_states = self.q_proj(hidden_states)
279
+ key_states = self.k_proj(hidden_states)
280
+ value_states = self.v_proj(hidden_states)
281
+
282
+ query_states = query_states.view(
283
+ bsz, q_len, self.num_heads, self.head_dim
284
+ ).transpose(1, 2)
285
+ key_states = key_states.view(
286
+ bsz, q_len, self.num_key_value_heads, self.head_dim
287
+ ).transpose(1, 2)
288
+ value_states = value_states.view(
289
+ bsz, q_len, self.num_key_value_heads, self.head_dim
290
+ ).transpose(1, 2)
291
+
292
+ kv_seq_len = key_states.shape[-2]
293
+ if past_key_value is not None:
294
+ kv_seq_len += past_key_value[0].shape[-2]
295
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
296
+ query_states, key_states = apply_rotary_pos_emb(
297
+ query_states, key_states, cos, sin, position_ids
298
+ )
299
+
300
+ if past_key_value is not None:
301
+ # reuse k, v, self_attention
302
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
303
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
304
+
305
+ past_key_value = (key_states, value_states) if use_cache else None
306
+
307
+ # repeat k/v heads if n_kv_heads < n_heads
308
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
309
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
310
+
311
+ attn_weights = torch.matmul(
312
+ query_states, key_states.transpose(2, 3)
313
+ ) / math.sqrt(self.head_dim)
314
+
315
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
316
+ raise ValueError(
317
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
318
+ f" {attn_weights.size()}"
319
+ )
320
+
321
+ if attention_mask is not None:
322
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
323
+ raise ValueError(
324
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
325
+ )
326
+
327
+ attn_weights = attn_weights + attention_mask
328
+
329
+ # upcast attention to fp32
330
+ attn_weights = nn.functional.softmax(
331
+ attn_weights, dim=-1, dtype=torch.float32
332
+ ).to(query_states.dtype)
333
+ attn_output = torch.matmul(attn_weights, value_states)
334
+
335
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
336
+ raise ValueError(
337
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
338
+ f" {attn_output.size()}"
339
+ )
340
+
341
+ attn_output = attn_output.transpose(1, 2).contiguous()
342
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
343
+
344
+ attn_output = self.o_proj(attn_output)
345
+
346
+ if not output_attentions:
347
+ attn_weights = None
348
+
349
+ return attn_output, attn_weights, past_key_value
350
+
351
+
352
+ class MistralFlashAttention2(MistralAttention):
353
+ """
354
+ Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stay
355
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
356
+ flash attention and deal with padding tokens in case the input contains any of them.
357
+ """
358
+
359
+ def forward(
360
+ self,
361
+ hidden_states: torch.Tensor,
362
+ attention_mask: Optional[torch.Tensor] = None,
363
+ position_ids: Optional[torch.LongTensor] = None,
364
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
365
+ output_attentions: bool = False,
366
+ use_cache: bool = False,
367
+ **kwargs,
368
+ ):
369
+ if "padding_mask" in kwargs:
370
+ warnings.warn(
371
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
372
+ )
373
+
374
+ # overwrite attention_mask with padding_mask
375
+ attention_mask = kwargs.pop("padding_mask")
376
+ bsz, q_len, _ = hidden_states.size()
377
+
378
+ query_states = self.q_proj(hidden_states)
379
+ key_states = self.k_proj(hidden_states)
380
+ value_states = self.v_proj(hidden_states)
381
+
382
+ query_states = query_states.view(
383
+ bsz, q_len, self.num_heads, self.head_dim
384
+ ).transpose(1, 2)
385
+ key_states = key_states.view(
386
+ bsz, q_len, self.num_key_value_heads, self.head_dim
387
+ ).transpose(1, 2)
388
+ value_states = value_states.view(
389
+ bsz, q_len, self.num_key_value_heads, self.head_dim
390
+ ).transpose(1, 2)
391
+
392
+ kv_seq_len = key_states.shape[-2]
393
+ if past_key_value is not None:
394
+ kv_seq_len += past_key_value[0].shape[-2]
395
+
396
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
397
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
398
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
399
+
400
+ query_states, key_states = apply_rotary_pos_emb(
401
+ query_states, key_states, cos, sin, position_ids
402
+ )
403
+
404
+ use_sliding_windows = (
405
+ _flash_supports_window_size
406
+ and getattr(self.config, "sliding_window", None) is not None
407
+ and kv_seq_len > self.config.sliding_window
408
+ )
409
+
410
+ if not _flash_supports_window_size:
411
+ logger.warning_once(
412
+ "The current flash attention version does not support sliding window attention; for a more memory-efficient implementation,"
413
+ " make sure to upgrade the flash-attn library."
414
+ )
415
+
416
+ if past_key_value is not None:
417
+ # Activate cache slicing only if the config has a `sliding_window` attribute
418
+ if (
419
+ hasattr(self.config, "sliding_window")
420
+ and kv_seq_len > self.config.sliding_window
421
+ ):
422
+ slicing_tokens = kv_seq_len - self.config.sliding_window
423
+
424
+ past_key = past_key_value[0]
425
+ past_value = past_key_value[1]
426
+
427
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
428
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
429
+
430
+ if past_key.shape[-2] != self.config.sliding_window - 1:
431
+ raise ValueError(
432
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
433
+ f" {past_key.shape}"
434
+ )
435
+
436
+ past_key_value = (past_key, past_value)
437
+
438
+ if attention_mask is not None:
439
+ attention_mask = attention_mask[:, slicing_tokens:]
440
+ attention_mask = torch.cat(
441
+ [attention_mask, torch.ones_like(attention_mask[:, -1:])],
442
+ dim=-1,
443
+ )
444
+
445
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
446
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
447
+
448
+ past_key_value = (key_states, value_states) if use_cache else None
449
+
450
+ # repeat k/v heads if n_kv_heads < n_heads
451
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
452
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
453
+
454
+ # TODO: Mistral does not have dropout in the config??
455
+ # It is recommended to use dropout with FA according to the docs
456
+ # when training.
457
+ dropout_rate = 0.0 # if not self.training else self.attn_dropout
458
+
459
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
460
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
461
+ # cast them back in float16 just to be sure everything works as expected.
462
+ input_dtype = query_states.dtype
463
+ if input_dtype == torch.float32:
464
+ # Handle the case where the model is quantized
465
+ if hasattr(self.config, "_pre_quantization_dtype"):
466
+ target_dtype = self.config._pre_quantization_dtype
467
+ else:
468
+ target_dtype = self.q_proj.weight.dtype
469
+
470
+ logger.warning_once(
471
+ f"The input hidden states seem to be silently cast to float32; this might be related to"
472
+ f" the fact you have upcast embedding or layer norm layers to float32. We will cast the input back to"
473
+ f" {target_dtype}."
474
+ )
475
+
476
+ query_states = query_states.to(target_dtype)
477
+ key_states = key_states.to(target_dtype)
478
+ value_states = value_states.to(target_dtype)
479
+
480
+ # Reshape to the expected shape for Flash Attention
481
+ query_states = query_states.transpose(1, 2)
482
+ key_states = key_states.transpose(1, 2)
483
+ value_states = value_states.transpose(1, 2)
484
+
485
+ attn_output = self._flash_attention_forward(
486
+ query_states,
487
+ key_states,
488
+ value_states,
489
+ attention_mask,
490
+ q_len,
491
+ dropout=dropout_rate,
492
+ use_sliding_windows=use_sliding_windows,
493
+ )
494
+
495
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
496
+ attn_output = self.o_proj(attn_output)
497
+
498
+ if not output_attentions:
499
+ attn_weights = None
500
+
501
+ return attn_output, attn_weights, past_key_value
502
+
503
+ def _flash_attention_forward(
504
+ self,
505
+ query_states,
506
+ key_states,
507
+ value_states,
508
+ attention_mask,
509
+ query_length,
510
+ dropout=0.0,
511
+ softmax_scale=None,
512
+ use_sliding_windows=False,
513
+ ):
514
+ """
515
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
516
+ it first unpads the input, then computes the attention scores and pads the final attention scores.
517
+
518
+ Args:
519
+ query_states (`torch.Tensor`):
520
+ Input query states to be passed to Flash Attention API
521
+ key_states (`torch.Tensor`):
522
+ Input key states to be passed to Flash Attention API
523
+ value_states (`torch.Tensor`):
524
+ Input value states to be passed to Flash Attention API
525
+ attention_mask (`torch.Tensor`):
526
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
527
+ position of padding tokens and 1 for the position of non-padding tokens.
528
+ dropout (`int`, *optional*):
529
+ Attention dropout
530
+ softmax_scale (`float`, *optional*):
531
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
532
+ use_sliding_windows (`bool`, *optional*):
533
+ Whether to activate sliding window attention.
534
+ """
535
+ # Contains at least one padding token in the sequence
536
+ if attention_mask is not None:
537
+ batch_size = query_states.shape[0]
538
+ (
539
+ query_states,
540
+ key_states,
541
+ value_states,
542
+ indices_q,
543
+ cu_seq_lens,
544
+ max_seq_lens,
545
+ ) = self._upad_input(
546
+ query_states, key_states, value_states, attention_mask, query_length
547
+ )
548
+
549
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
550
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
551
+
552
+ if not use_sliding_windows:
553
+ attn_output_unpad = flash_attn_varlen_func(
554
+ query_states,
555
+ key_states,
556
+ value_states,
557
+ cu_seqlens_q=cu_seqlens_q,
558
+ cu_seqlens_k=cu_seqlens_k,
559
+ max_seqlen_q=max_seqlen_in_batch_q,
560
+ max_seqlen_k=max_seqlen_in_batch_k,
561
+ dropout_p=dropout,
562
+ softmax_scale=softmax_scale,
563
+ causal=self.is_causal,
564
+ )
565
+ else:
566
+ attn_output_unpad = flash_attn_varlen_func(
567
+ query_states,
568
+ key_states,
569
+ value_states,
570
+ cu_seqlens_q=cu_seqlens_q,
571
+ cu_seqlens_k=cu_seqlens_k,
572
+ max_seqlen_q=max_seqlen_in_batch_q,
573
+ max_seqlen_k=max_seqlen_in_batch_k,
574
+ dropout_p=dropout,
575
+ softmax_scale=softmax_scale,
576
+ causal=self.is_causal,
577
+ window_size=(
578
+ self.config.sliding_window,
579
+ self.config.sliding_window,
580
+ ),
581
+ )
582
+
583
+ attn_output = pad_input(
584
+ attn_output_unpad, indices_q, batch_size, query_length
585
+ )
586
+ else:
587
+ if not use_sliding_windows:
588
+ attn_output = flash_attn_func(
589
+ query_states,
590
+ key_states,
591
+ value_states,
592
+ dropout,
593
+ softmax_scale=softmax_scale,
594
+ causal=self.is_causal,
595
+ )
596
+ else:
597
+ attn_output = flash_attn_func(
598
+ query_states,
599
+ key_states,
600
+ value_states,
601
+ dropout,
602
+ softmax_scale=softmax_scale,
603
+ causal=self.is_causal,
604
+ window_size=(
605
+ self.config.sliding_window,
606
+ self.config.sliding_window,
607
+ ),
608
+ )
609
+
610
+ return attn_output
611
+
612
+ def _upad_input(
613
+ self, query_layer, key_layer, value_layer, attention_mask, query_length
614
+ ):
615
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
616
+
617
+ # On the first iteration we need to properly re-create the padding mask
618
+ # by slicing it on the proper place
619
+ if kv_seq_len != attention_mask.shape[-1]:
620
+ attention_mask_num_tokens = attention_mask.shape[-1]
621
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
622
+
623
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
624
+
625
+ key_layer = index_first_axis(
626
+ key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
627
+ )
628
+ value_layer = index_first_axis(
629
+ value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
630
+ )
631
+
632
+ if query_length == kv_seq_len:
633
+ query_layer = index_first_axis(
634
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
635
+ indices_k,
636
+ )
637
+ cu_seqlens_q = cu_seqlens_k
638
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
639
+ indices_q = indices_k
640
+ elif query_length == 1:
641
+ max_seqlen_in_batch_q = 1
642
+ cu_seqlens_q = torch.arange(
643
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
644
+ ) # There is a memcpy here, that is very bad.
645
+ indices_q = cu_seqlens_q[:-1]
646
+ query_layer = query_layer.squeeze(1)
647
+ else:
648
+ # The -q_len: slice assumes left padding.
649
+ attention_mask = attention_mask[:, -query_length:]
650
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
651
+ query_layer, attention_mask
652
+ )
653
+
654
+ return (
655
+ query_layer,
656
+ key_layer,
657
+ value_layer,
658
+ indices_q,
659
+ (cu_seqlens_q, cu_seqlens_k),
660
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
661
+ )
662
+
663
+
664
+ class MistralDecoderLayer(nn.Module):
665
+ def __init__(self, config: EvoMistralConfig):
666
+ super().__init__()
667
+ self.hidden_size = config.hidden_size
668
+ self.self_attn = (
669
+ MistralAttention(config=config)
670
+ if not getattr(config, "_flash_attn_2_enabled", False)
671
+ else MistralFlashAttention2(config)
672
+ )
673
+ self.mlp = MistralMLP(config)
674
+ self.input_layernorm = MistralRMSNorm(
675
+ config.hidden_size, eps=config.rms_norm_eps
676
+ )
677
+ self.post_attention_layernorm = MistralRMSNorm(
678
+ config.hidden_size, eps=config.rms_norm_eps
679
+ )
680
+
681
+ def forward(
682
+ self,
683
+ hidden_states: torch.Tensor,
684
+ attention_mask: Optional[torch.Tensor] = None,
685
+ position_ids: Optional[torch.LongTensor] = None,
686
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
687
+ output_attentions: Optional[bool] = False,
688
+ use_cache: Optional[bool] = False,
689
+ residual: Optional[torch.Tensor] = None,
690
+ **kwargs,
691
+ ) -> Tuple[
692
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
693
+ ]:
694
+ if "padding_mask" in kwargs:
695
+ warnings.warn(
696
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
697
+ )
698
+ """
699
+ Args:
700
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
701
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
702
+ `(batch, sequence_length)` where padding elements are indicated by 0.
703
+ output_attentions (`bool`, *optional*):
704
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
705
+ returned tensors for more detail.
706
+ use_cache (`bool`, *optional*):
707
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
708
+ (see `past_key_values`).
709
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
710
+ """
711
+ if residual is None:
712
+ residual = hidden_states
713
+ hidden_states = self.input_layernorm(hidden_states)
714
+ else:
715
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
716
+
717
+ # Self Attention
718
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
719
+ hidden_states=hidden_states,
720
+ attention_mask=attention_mask,
721
+ position_ids=position_ids,
722
+ past_key_value=past_key_value,
723
+ output_attentions=output_attentions,
724
+ use_cache=use_cache,
725
+ )
726
+
727
+ # Fully Connected
728
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
729
+ hidden_states = self.mlp(hidden_states)
730
+
731
+ outputs = ((hidden_states, residual),)
732
+
733
+ if output_attentions:
734
+ outputs += (self_attn_weights,)
735
+
736
+ if use_cache:
737
+ outputs += (present_key_value,)
738
+
739
+ return outputs
740
+
741
+
742
+ MISTRAL_START_DOCSTRING = r"""
743
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
744
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
745
+ etc.)
746
+
747
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
748
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
749
+ and behavior.
750
+
751
+ Parameters:
752
+ config ([`MistralConfig`]):
753
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
754
+ load the weights associated with the model, only the configuration. Check out the
755
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
756
+ """
757
+
758
+
759
+ @add_start_docstrings(
760
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
761
+ MISTRAL_START_DOCSTRING,
762
+ )
763
+ class MistralPreTrainedModel(PreTrainedModel):
764
+ config_class = EvoMistralConfig
765
+ base_model_prefix = "model"
766
+ supports_gradient_checkpointing = True
767
+ _no_split_modules = ["MistralDecoderLayer"]
768
+ _skip_keys_device_placement = "past_key_values"
769
+ _supports_flash_attn_2 = True
770
+
771
+ def _init_weights(self, module):
772
+ std = self.config.initializer_range
773
+ if isinstance(module, nn.Linear):
774
+ module.weight.data.normal_(mean=0.0, std=std)
775
+ if module.bias is not None:
776
+ module.bias.data.zero_()
777
+ elif isinstance(module, nn.Embedding):
778
+ module.weight.data.normal_(mean=0.0, std=std)
779
+ if module.padding_idx is not None:
780
+ module.weight.data[module.padding_idx].zero_()
781
+
782
+
783
+ MISTRAL_INPUTS_DOCSTRING = r"""
784
+ Args:
785
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
786
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
787
+ it.
788
+
789
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
790
+ [`PreTrainedTokenizer.__call__`] for details.
791
+
792
+ [What are input IDs?](../glossary#input-ids)
793
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
794
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
795
+
796
+ - 1 for tokens that are **not masked**,
797
+ - 0 for tokens that are **masked**.
798
+
799
+ [What are attention masks?](../glossary#attention-mask)
800
+
801
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
802
+ [`PreTrainedTokenizer.__call__`] for details.
803
+
804
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
805
+ `past_key_values`).
806
+
807
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
808
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
809
+ information on the default strategy.
810
+
811
+ - 1 indicates the head is **not masked**,
812
+ - 0 indicates the head is **masked**.
813
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
814
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
815
+ config.n_positions - 1]`.
816
+
817
+ [What are position IDs?](../glossary#position-ids)
818
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
819
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
820
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
821
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
822
+
823
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
824
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
825
+
826
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
827
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
828
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
829
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
830
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
831
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
832
+ model's internal embedding lookup matrix.
833
+ use_cache (`bool`, *optional*):
834
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
835
+ `past_key_values`).
836
+ output_attentions (`bool`, *optional*):
837
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
838
+ tensors for more detail.
839
+ output_hidden_states (`bool`, *optional*):
840
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
841
+ more detail.
842
+ return_dict (`bool`, *optional*):
843
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
844
+ """
845
+
846
+
847
+ @add_start_docstrings(
848
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
849
+ MISTRAL_START_DOCSTRING,
850
+ )
851
+ class EvoMistralModel(MistralPreTrainedModel):
852
+ """
853
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
854
+
855
+ Args:
856
+ config: MistralConfig
857
+ """
858
+
859
+ def __init__(self, config: EvoMistralConfig):
860
+ super().__init__(config)
861
+ self.padding_idx = config.pad_token_id
862
+ self.vocab_size = config.vocab_size
863
+
864
+ self.embed_tokens = nn.Embedding(
865
+ config.vocab_size, config.hidden_size, self.padding_idx
866
+ )
867
+ self.layers = nn.ModuleList(
868
+ [MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]
869
+ )
870
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
871
+
872
+ self.input_scales = nn.Parameter(
873
+ data=torch.zeros(self.config.num_hops).float(), requires_grad=False
874
+ )
875
+ self.input_layers = nn.Parameter(
876
+ data=torch.zeros(self.config.num_hops).int(), requires_grad=False
877
+ )
878
+
879
+ self.gradient_checkpointing = False
880
+ # Initialize weights and apply final processing
881
+ self.post_init()
882
+
883
+ def get_input_embeddings(self):
884
+ return self.embed_tokens
885
+
886
+ def set_input_embeddings(self, value):
887
+ self.embed_tokens = value
888
+
889
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
890
+ def forward(
891
+ self,
892
+ input_ids: torch.LongTensor = None,
893
+ attention_mask: Optional[torch.Tensor] = None,
894
+ position_ids: Optional[torch.LongTensor] = None,
895
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
896
+ inputs_embeds: Optional[torch.FloatTensor] = None,
897
+ use_cache: Optional[bool] = None,
898
+ output_attentions: Optional[bool] = None,
899
+ output_hidden_states: Optional[bool] = None,
900
+ return_dict: Optional[bool] = None,
901
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
902
+ output_attentions = (
903
+ output_attentions
904
+ if output_attentions is not None
905
+ else self.config.output_attentions
906
+ )
907
+ output_hidden_states = (
908
+ output_hidden_states
909
+ if output_hidden_states is not None
910
+ else self.config.output_hidden_states
911
+ )
912
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
913
+
914
+ return_dict = (
915
+ return_dict if return_dict is not None else self.config.use_return_dict
916
+ )
917
+
918
+ # retrieve input_ids and inputs_embeds
919
+ if input_ids is not None and inputs_embeds is not None:
920
+ raise ValueError(
921
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
922
+ )
923
+ elif input_ids is not None:
924
+ batch_size, seq_length = input_ids.shape
925
+ elif inputs_embeds is not None:
926
+ batch_size, seq_length, _ = inputs_embeds.shape
927
+ else:
928
+ raise ValueError(
929
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
930
+ )
931
+
932
+ seq_length_with_past = seq_length
933
+ past_key_values_length = 0
934
+
935
+ if past_key_values is not None:
936
+ past_key_values_length = past_key_values[0][0].shape[2]
937
+ seq_length_with_past = seq_length_with_past + past_key_values_length
938
+
939
+ if position_ids is None:
940
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
941
+ position_ids = torch.arange(
942
+ past_key_values_length,
943
+ seq_length + past_key_values_length,
944
+ dtype=torch.long,
945
+ device=device,
946
+ )
947
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
948
+ else:
949
+ position_ids = position_ids.view(-1, seq_length).long()
950
+
951
+ if inputs_embeds is None:
952
+ inputs_embeds = self.embed_tokens(input_ids)
953
+
954
+ if (
955
+ attention_mask is not None
956
+ and hasattr(self.config, "_flash_attn_2_enabled")
957
+ and self.config._flash_attn_2_enabled
958
+ and past_key_values is not None
959
+ ):
960
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
961
+ if is_padding_right:
962
+ raise ValueError(
963
+ "You are attempting to perform batched generation with padding_side='right';"
964
+ " this may lead to unexpected behaviour for the Flash Attention version of Mistral. Make sure to "
965
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
966
+ )
967
+
968
+ if getattr(self.config, "_flash_attn_2_enabled", False):
969
+ # 2d mask is passed through the layers
970
+ attention_mask = (
971
+ attention_mask
972
+ if (attention_mask is not None and 0 in attention_mask)
973
+ else None
974
+ )
975
+ else:
976
+ # 4d mask is passed through the layers
977
+ attention_mask = _prepare_4d_causal_attention_mask(
978
+ attention_mask,
979
+ (batch_size, seq_length),
980
+ inputs_embeds,
981
+ past_key_values_length,
982
+ sliding_window=self.config.sliding_window,
983
+ )
984
+
985
+ hidden_states = inputs_embeds
986
+
987
+ if self.gradient_checkpointing and self.training:
988
+ if use_cache:
989
+ logger.warning_once(
990
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
991
+ )
992
+ use_cache = False
993
+
994
+ # decoder layers
995
+ all_hidden_states = () if output_hidden_states else None
996
+ all_self_attns = () if output_attentions else None
997
+ next_decoder_cache = () if use_cache else None
998
+ residual = None
999
+
1000
+ for idx, layer_ix in enumerate(self.input_layers):
1001
+ decoder_layer = self.layers[layer_ix]
1002
+ scale = self.input_scales[idx].to(hidden_states.device)
1003
+
1004
+ if output_hidden_states:
1005
+ all_hidden_states += (hidden_states,)
1006
+
1007
+ past_key_value = (
1008
+ past_key_values[idx] if past_key_values is not None else None
1009
+ )
1010
+
1011
+ if self.gradient_checkpointing and self.training:
1012
+ layer_outputs = self._gradient_checkpointing_func(
1013
+ decoder_layer.__call__,
1014
+ hidden_states * scale,
1015
+ attention_mask,
1016
+ position_ids,
1017
+ past_key_value,
1018
+ output_attentions,
1019
+ use_cache,
1020
+ residual,
1021
+ )
1022
+ else:
1023
+ layer_outputs = decoder_layer(
1024
+ hidden_states * scale,
1025
+ attention_mask=attention_mask,
1026
+ position_ids=position_ids,
1027
+ past_key_value=past_key_value,
1028
+ output_attentions=output_attentions,
1029
+ use_cache=use_cache,
1030
+ residual=residual,
1031
+ )
1032
+
1033
+ hidden_states, residual = layer_outputs[0]
1034
+
1035
+ if use_cache:
1036
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1037
+
1038
+ if output_attentions:
1039
+ all_self_attns += (layer_outputs[1],)
1040
+
1041
+ hidden_states, _ = self.norm(hidden_states, residual)
1042
+
1043
+ # add hidden states from the last decoder layer
1044
+ if output_hidden_states:
1045
+ all_hidden_states += (hidden_states,)
1046
+
1047
+ next_cache = next_decoder_cache if use_cache else None
1048
+ if not return_dict:
1049
+ return tuple(
1050
+ v
1051
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1052
+ if v is not None
1053
+ )
1054
+ return BaseModelOutputWithPast(
1055
+ last_hidden_state=hidden_states,
1056
+ past_key_values=next_cache,
1057
+ hidden_states=all_hidden_states,
1058
+ attentions=all_self_attns,
1059
+ )
1060
+
1061
+
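The distinctive part of `EvoMistralModel.forward` above is the loop over `self.input_layers` and `self.input_scales`: instead of running every decoder layer exactly once in order, the forward pass follows a stored hop sequence that may reorder or revisit layers, scaling each hop's input. A minimal control-flow sketch of that routing (illustrative; the layer stand-ins, indices, and scales below are made up, not values from this repository):

```python
# Illustrative sketch of the input_layers/input_scales routing, not the real model.
import torch

layers = [lambda h, i=i: h + i for i in range(4)]       # stand-ins for MistralDecoderLayer modules
input_layers = torch.tensor([0, 1, 1, 2, 3])            # hop sequence: layer 1 is visited twice
input_scales = torch.tensor([1.0, 0.9, 1.1, 1.0, 0.8])  # per-hop scaling of the layer input

hidden = torch.zeros(1)
for hop_idx, layer_ix in enumerate(input_layers):
    scale = input_scales[hop_idx]
    hidden = layers[int(layer_ix)](hidden * scale)      # mirrors: decoder_layer(hidden_states * scale, ...)
print(hidden)
```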
1062
+ class EvoMistralForCausalLM(MistralPreTrainedModel):
1063
+ _tied_weights_keys = ["lm_head.weight"]
1064
+
1065
+ def __init__(self, config):
1066
+ super().__init__(config)
1067
+ self.model = EvoMistralModel(config)
1068
+ self.vocab_size = config.vocab_size
1069
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1070
+
1071
+ # Initialize weights and apply final processing
1072
+ self.post_init()
1073
+
1074
+ def get_input_embeddings(self):
1075
+ return self.model.embed_tokens
1076
+
1077
+ def set_input_embeddings(self, value):
1078
+ self.model.embed_tokens = value
1079
+
1080
+ def get_output_embeddings(self):
1081
+ return self.lm_head
1082
+
1083
+ def set_output_embeddings(self, new_embeddings):
1084
+ self.lm_head = new_embeddings
1085
+
1086
+ def set_decoder(self, decoder):
1087
+ self.model = decoder
1088
+
1089
+ def get_decoder(self):
1090
+ return self.model
1091
+
1092
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1093
+ @replace_return_docstrings(
1094
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1095
+ )
1096
+ def forward(
1097
+ self,
1098
+ input_ids: torch.LongTensor = None,
1099
+ attention_mask: Optional[torch.Tensor] = None,
1100
+ position_ids: Optional[torch.LongTensor] = None,
1101
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1102
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1103
+ labels: Optional[torch.LongTensor] = None,
1104
+ use_cache: Optional[bool] = None,
1105
+ output_attentions: Optional[bool] = None,
1106
+ output_hidden_states: Optional[bool] = None,
1107
+ return_dict: Optional[bool] = None,
1108
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1109
+ r"""
1110
+ Args:
1111
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1112
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1113
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1114
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1115
+
1116
+ Returns:
1117
+
1118
+ Example:
1119
+
1120
+ ```python
1121
+ >>> from transformers import AutoTokenizer, MistralForCausalLM
1122
+
1123
+ >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1124
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1125
+
1126
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1127
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1128
+
1129
+ >>> # Generate
1130
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1131
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1132
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1133
+ ```"""
1134
+
1135
+ output_attentions = (
1136
+ output_attentions
1137
+ if output_attentions is not None
1138
+ else self.config.output_attentions
1139
+ )
1140
+ output_hidden_states = (
1141
+ output_hidden_states
1142
+ if output_hidden_states is not None
1143
+ else self.config.output_hidden_states
1144
+ )
1145
+ return_dict = (
1146
+ return_dict if return_dict is not None else self.config.use_return_dict
1147
+ )
1148
+
1149
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1150
+ outputs = self.model(
1151
+ input_ids=input_ids,
1152
+ attention_mask=attention_mask,
1153
+ position_ids=position_ids,
1154
+ past_key_values=past_key_values,
1155
+ inputs_embeds=inputs_embeds,
1156
+ use_cache=use_cache,
1157
+ output_attentions=output_attentions,
1158
+ output_hidden_states=output_hidden_states,
1159
+ return_dict=return_dict,
1160
+ )
1161
+
1162
+ hidden_states = outputs[0]
1163
+ logits = self.lm_head(hidden_states)
1164
+ logits = logits.float()
1165
+
1166
+ loss = None
1167
+ if labels is not None:
1168
+ # Shift so that tokens < n predict n
1169
+ shift_logits = logits[..., :-1, :].contiguous()
1170
+ shift_labels = labels[..., 1:].contiguous()
1171
+ # Flatten the tokens
1172
+ loss_fct = CrossEntropyLoss()
1173
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1174
+ shift_labels = shift_labels.view(-1)
1175
+ # Enable model parallelism
1176
+ shift_labels = shift_labels.to(shift_logits.device)
1177
+ loss = loss_fct(shift_logits, shift_labels)
1178
+
1179
+ if not return_dict:
1180
+ output = (logits,) + outputs[1:]
1181
+ return (loss,) + output if loss is not None else output
1182
+
1183
+ return CausalLMOutputWithPast(
1184
+ loss=loss,
1185
+ logits=logits,
1186
+ past_key_values=outputs.past_key_values,
1187
+ hidden_states=outputs.hidden_states,
1188
+ attentions=outputs.attentions,
1189
+ )
1190
+
1191
+ def prepare_inputs_for_generation(
1192
+ self,
1193
+ input_ids,
1194
+ past_key_values=None,
1195
+ attention_mask=None,
1196
+ inputs_embeds=None,
1197
+ **kwargs,
1198
+ ):
1199
+ # Omit tokens covered by past_key_values
1200
+ if past_key_values:
1201
+ past_length = past_key_values[0][0].shape[2]
1202
+
1203
+ # Some generation methods already pass only the last input ID
1204
+ if input_ids.shape[1] > past_length:
1205
+ remove_prefix_length = past_length
1206
+ else:
1207
+ # Default to old behavior: keep only final ID
1208
+ remove_prefix_length = input_ids.shape[1] - 1
1209
+
1210
+ input_ids = input_ids[:, remove_prefix_length:]
1211
+
1212
+ position_ids = kwargs.get("position_ids", None)
1213
+ if attention_mask is not None and position_ids is None:
1214
+ # create position_ids on the fly for batch generation
1215
+ position_ids = attention_mask.long().cumsum(-1) - 1
1216
+ position_ids.masked_fill_(attention_mask == 0, 1)
1217
+ if past_key_values:
1218
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1219
+
1220
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1221
+ if inputs_embeds is not None and past_key_values is None:
1222
+ model_inputs = {"inputs_embeds": inputs_embeds}
1223
+ else:
1224
+ model_inputs = {"input_ids": input_ids}
1225
+
1226
+ model_inputs.update(
1227
+ {
1228
+ "position_ids": position_ids,
1229
+ "past_key_values": past_key_values,
1230
+ "use_cache": kwargs.get("use_cache"),
1231
+ "attention_mask": attention_mask,
1232
+ }
1233
+ )
1234
+ return model_inputs
1235
+
1236
+ @staticmethod
1237
+ def _reorder_cache(past_key_values, beam_idx):
1238
+ reordered_past = ()
1239
+ for layer_past in past_key_values:
1240
+ reordered_past += (
1241
+ tuple(
1242
+ past_state.index_select(0, beam_idx.to(past_state.device))
1243
+ for past_state in layer_past
1244
+ ),
1245
+ )
1246
+ return reordered_past
1247
+
1248
+
1249
+ @add_start_docstrings(
1250
+ """
1251
+ The Mistral Model transformer with a sequence classification head on top (linear layer).
1252
+
1253
+ [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1254
+ (e.g. GPT-2) do.
1255
+
1256
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1257
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1258
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1259
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1260
+ each row of the batch).
1261
+ """,
1262
+ MISTRAL_START_DOCSTRING,
1263
+ )
1264
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
1265
+ class EvoMistralForSequenceClassification(MistralPreTrainedModel):
1266
+ def __init__(self, config):
1267
+ super().__init__(config)
1268
+ self.num_labels = config.num_labels
1269
+ self.model = EvoMistralModel(config)
1270
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1271
+
1272
+ # Initialize weights and apply final processing
1273
+ self.post_init()
1274
+
1275
+ def get_input_embeddings(self):
1276
+ return self.model.embed_tokens
1277
+
1278
+ def set_input_embeddings(self, value):
1279
+ self.model.embed_tokens = value
1280
+
1281
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1282
+ def forward(
1283
+ self,
1284
+ input_ids: torch.LongTensor = None,
1285
+ attention_mask: Optional[torch.Tensor] = None,
1286
+ position_ids: Optional[torch.LongTensor] = None,
1287
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1288
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1289
+ labels: Optional[torch.LongTensor] = None,
1290
+ use_cache: Optional[bool] = None,
1291
+ output_attentions: Optional[bool] = None,
1292
+ output_hidden_states: Optional[bool] = None,
1293
+ return_dict: Optional[bool] = None,
1294
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1295
+ r"""
1296
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1297
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1298
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1299
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1300
+ """
1301
+ return_dict = (
1302
+ return_dict if return_dict is not None else self.config.use_return_dict
1303
+ )
1304
+
1305
+ transformer_outputs = self.model(
1306
+ input_ids,
1307
+ attention_mask=attention_mask,
1308
+ position_ids=position_ids,
1309
+ past_key_values=past_key_values,
1310
+ inputs_embeds=inputs_embeds,
1311
+ use_cache=use_cache,
1312
+ output_attentions=output_attentions,
1313
+ output_hidden_states=output_hidden_states,
1314
+ return_dict=return_dict,
1315
+ )
1316
+ hidden_states = transformer_outputs[0]
1317
+ logits = self.score(hidden_states)
1318
+
1319
+ if input_ids is not None:
1320
+ batch_size = input_ids.shape[0]
1321
+ else:
1322
+ batch_size = inputs_embeds.shape[0]
1323
+
1324
+ if self.config.pad_token_id is None and batch_size != 1:
1325
+ raise ValueError(
1326
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1327
+ )
1328
+ if self.config.pad_token_id is None:
1329
+ sequence_lengths = -1
1330
+ else:
1331
+ if input_ids is not None:
1332
+ sequence_lengths = (
1333
+ torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
1334
+ ).to(logits.device)
1335
+ else:
1336
+ sequence_lengths = -1
1337
+
1338
+ pooled_logits = logits[
1339
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1340
+ ]
1341
+
1342
+ loss = None
1343
+ if labels is not None:
1344
+ labels = labels.to(logits.device)
1345
+ if self.config.problem_type is None:
1346
+ if self.num_labels == 1:
1347
+ self.config.problem_type = "regression"
1348
+ elif self.num_labels > 1 and (
1349
+ labels.dtype == torch.long or labels.dtype == torch.int
1350
+ ):
1351
+ self.config.problem_type = "single_label_classification"
1352
+ else:
1353
+ self.config.problem_type = "multi_label_classification"
1354
+
1355
+ if self.config.problem_type == "regression":
1356
+ loss_fct = MSELoss()
1357
+ if self.num_labels == 1:
1358
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1359
+ else:
1360
+ loss = loss_fct(pooled_logits, labels)
1361
+ elif self.config.problem_type == "single_label_classification":
1362
+ loss_fct = CrossEntropyLoss()
1363
+ loss = loss_fct(
1364
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1365
+ )
1366
+ elif self.config.problem_type == "multi_label_classification":
1367
+ loss_fct = BCEWithLogitsLoss()
1368
+ loss = loss_fct(pooled_logits, labels)
1369
+ if not return_dict:
1370
+ output = (pooled_logits,) + transformer_outputs[1:]
1371
+ return ((loss,) + output) if loss is not None else output
1372
+
1373
+ return SequenceClassifierOutputWithPast(
1374
+ loss=loss,
1375
+ logits=pooled_logits,
1376
+ past_key_values=transformer_outputs.past_key_values,
1377
+ hidden_states=transformer_outputs.hidden_states,
1378
+ attentions=transformer_outputs.attentions,
1379
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
3
+ size 493443
tokenizer_config.json ADDED
@@ -0,0 +1,42 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ }
29
+ },
30
+ "additional_special_tokens": [],
31
+ "bos_token": "<s>",
32
+ "clean_up_tokenization_spaces": false,
33
+ "eos_token": "</s>",
34
+ "legacy": true,
35
+ "model_max_length": 1000000000000000019884624838656,
36
+ "pad_token": null,
37
+ "sp_model_kwargs": {},
38
+ "spaces_between_special_tokens": false,
39
+ "tokenizer_class": "LlamaTokenizer",
40
+ "unk_token": "<unk>",
41
+ "use_default_system_prompt": false
42
+ }
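Since the commit ships custom modeling code (`modeling_evomistral.py`, with its config in `configuration_evomistral.py`), loading it through `transformers` would typically go through the auto classes with `trust_remote_code=True`. A hedged usage sketch follows: the repository id is a placeholder, and it is assumed that the repo's `config.json` (not shown in this commit) maps the Evo* classes for auto-loading.

```python
# Hedged sketch: repo_id is a placeholder; assumes config.json registers the custom Evo* classes.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "path/to/this-evomistral-repo"  # placeholder, not a real model id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # required so the custom EvoMistralForCausalLM class is used
)

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```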