LICENSE.md ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DeciLM License TLDR (not to replace the full version):
2
+
3
+ 1. Deci.AI's DeciLM, derivative of the LLAMA 2 model, offers a vast suite of AI tools and code. This license balances user flexibility with Deci.AI's proprietary rights, aiming for collaborative growth in AI.
4
+ 2. License Access: Deci.AI provides a non-exclusive license for DeciLM Materials. Users can modify, use, and distribute these freely.
5
+ 3. Hosting Terms: "Hosting Use" refers to offering DeciLM Materials or derivative work as shared instances or managed services to third party users in an inference or finetuning API form. To engage in this, user must get a separate permission from Deci.
6
+ 4. Redistribution Rules: If users share DeciLM or its derivatives, they must include this license agreement.
7
+
8
+ © 2023 – Deci.AI, Ltd.
9
+
10
+ DeciLM License version 1.0
11
+ September 2023
12
+
13
+
14
+ These foundational large language models and software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Deci.AI, Ltd (“Deci”) at “Documentation” means the specifications, manuals and documentation accompanying DeciLM distributed Deci at [https://huggingface.co/Deci/DeciLM-6b-instruct](https://huggingface.co/Deci/DeciLM-6b-instruct) (the "Software") is licensed to you by Deci under the following terms:
15
+ This license is, in part, based on the LLAMA 2 Community License Agreement (available at https://ai.meta.com/llama/license/), with a series of modifications. Use of DeciLM for hosted services may require a separate license.
16
+
17
+
18
+ Please also note that DeciLM is a derivative of Llama 2, which is licensed under the LLAMA 2 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.
19
+
20
+
21
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the DeciLM set forth herein.
22
+ “Documentation” means the specifications, manuals and documentation accompanying DeciLM distributed Deci at [https://huggingface.co/Deci/DeciLM-6b-instruct](https://huggingface.co/Deci/DeciLM-6b-instruct)
23
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
24
+ “DeciLM” means the foundational large language models and software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Deci at “Documentation” means the specifications, manuals and documentation accompanying DeciLM distributed Deci at [https://huggingface.co/Deci/DeciLM-6b-instruct](https://huggingface.co/Deci/DeciLM-6b-instruct).
25
+ “DeciLM Materials” means, collectively, Deci’s proprietary DeciLM and Documentation (and any portion thereof) made available under this Agreement.
26
+ “Deci” or “we” means Deci.AI Ltd.
27
+ By clicking “I Accept” below or by using or distributing any portion or element of the DeciLM Materials, you agree to be bound by this Agreement.
28
+ “Hosting Use”” means any use of the DeciLM Materials or a derivative work to offer shared instances or managed services based on the DeciLM Materials or a derivative work (including fine-tuned versions of a Work or Derivative Work) to third party users in an inference or finetuning API form.
29
+
30
+
31
+ “Hosting User” means someone who has applied to make Hosting Use of the Work and been granted permission by the Licensor to make such Hosting Use subject to a separate licence agreement.
32
+ 1. License Rights and Redistribution.
33
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and limited license under Deci’s intellectual property or other rights owned by Deci embodied in the DeciLM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the DeciLM Materials. Other than where you are a Hosting User in accordance with Section 6, your copyright license to use the DeciLM Materials be royalty free.
34
+ b. Redistribution and Use.
35
+ i. If you distribute or make the DeciLM Materials, or any derivative works thereof, available to a third party, you shall provide a copy of this Agreement to such third party. Hosting-based restrictions as set out in Section 6 of this license, and which do not otherwise conflict with those provisions, must be included as enforceable provisions by you in any type of legal agreement (e.g. a license) governing the use and/or distribution of the DeciLM Materials or any derivative works that You distribute;
36
+ ii. If you receive DeciLM Materials, or any derivative works thereof, from a Licensee as part of an integrated end user product, then Section 2 of this Agreement will not apply to you.
37
+ iii. You must retain in all copies of the DeciLM Materials that you distribute the following attribution notice within a “Notice” text file distributed as a part of such copies: “DeciLM is licensed under the DeciLM License, Copyright © Deci.AI Ltd,. All Rights Reserved.”
38
+ iv. Your use of the DeciLM Materials must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Llama Materials (available at https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into this Agreement.
39
+ v. You will not use the DeciLM Materials or any output or results of the DeciLM Materials to improve any other large language model (excluding Llama 2 or DeciLM derivative works thereof).
40
+ 2. Additional Commercial Terms. If, on the DeciLM version release date, the monthly active users of the products or services made available by or for Licensee, or Licensee’s affiliates, is greater than 700 million monthly active users in the preceding calendar month, you must request a license from Deci, which Deci may grant to you in its sole discretion, and you are not authorized to exercise any of the rights under this Agreement unless or until Deci otherwise expressly grants you such rights.
41
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE DECILM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE DECILM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE DECILM MATERIALS AND ANY OUTPUT AND RESULTS.
42
+ 4. Limitation of Liability. IN NO EVENT WILL DECI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF DECI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
43
+ 5. Intellectual Property.
44
+ a. No trademark licenses are granted under this Agreement, and in connection with the DeciLM Materials, neither Deci nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the DeciLM Materials.
45
+ b. Subject to Deci’s ownership of Deci Materials and derivatives made by or for Deci, with respect to any derivative works and modifications of the Deci Materials that are made by you, as between you and Deci, and other than where you are a Hosting User in accordance with Section 6, you are and will be the owner of such derivative works and modifications.
46
+ c. If you institute litigation or other proceedings against Deci or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the DeciLM Materials or DeciLM outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Deci from and against any claim by any third party arising out of or related to your use or distribution of the DeciLM Materials.
47
+ 6. Hosting Use
48
+ a. You are not licensed to use the DeciLM Materials or derivative works under this license for Hosting Use. Where You wish to make Hosting Use of DeciLM Materials or derivative works, You must apply to Deci for permission to make Hosting Use of that Work in, providing such information as may be required.
49
+ b. Where Deci grants permission for You to make Hosting Use of the relevant Work, then for that purpose you shall be considered a Hosting User, and your use of DeciLM Materials or derivative works shall be subject to the separate license granted by Deci relating to that use.
50
+ 7. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the DeciLM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Deci may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the DeciLM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
51
+ 8. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of Israel without regard to choice of law principles. The courts of Israel shall have exclusive jurisdiction of any dispute arising out of this Agreement.
52
+
config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "DeciLMForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_decilm.DeciLMConfig",
7
+ "AutoModelForCausalLM": "modeling_decilm.DeciLMForCausalLM"
8
+ },
9
+ "bos_token_id": 1,
10
+ "eos_token_id": 2,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 4096,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 11008,
15
+ "max_position_embeddings": 4096,
16
+ "num_attention_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_key_value_heads_per_layer": [4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 4],
19
+ "pretraining_tp": 1,
20
+ "rms_norm_eps": 1e-05,
21
+ "rope_scaling": {"type": "dynamic", "factor": 2.0},
22
+ "tie_word_embeddings": false,
23
+ "torch_dtype": "bfloat16",
24
+ "use_bfloat16": true,
25
+ "transformers_version": "4.31.0",
26
+ "use_cache": true,
27
+ "vocab_size": 32000
28
+ }
configuration_decilm.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from packaging import version
2
+ import transformers
3
+ if version.parse(transformers.__version__) < version.parse("4.31.0"):
4
+ raise ImportError(
5
+ f"You are using transformers=={transformers.__version__}, but transformers>=4.31.0 is required to use DeciLM. Please upgrade transformers."
6
+ )
7
+ from transformers.models.llama.configuration_llama import LlamaConfig
8
+ from transformers.utils import logging
9
+
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+ LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
14
+
15
+
16
+ class DeciLMConfig(LlamaConfig):
17
+ r"""
18
+
19
+ Args:
20
+ num_key_value_heads_per_layer (`List[int]`):
21
+ The number of key-value heads per layer.
22
+ naive_attention_prefill (`bool`, *optional*, defaults to False):
23
+ Whether to use naive matmul or scaled dot product attention during prefill.
24
+ naive_attention_decode_batched (`bool`, *optional*, defaults to True):
25
+ Whether to use naive matmul or scaled dot product attention during decode for batch_size > 1.
26
+ naive_attention_decode_single (`bool`, *optional*, defaults to False):
27
+ Whether to use naive matmul or scaled dot product attention during decode for batch_size == 1.
28
+
29
+
30
+ ```"""
31
+ keys_to_ignore_at_inference = ["past_key_values"]
32
+
33
+ def __init__(
34
+ self,
35
+ num_key_value_heads_per_layer: list[int] = None,
36
+ naive_attention_prefill: bool = False,
37
+ naive_attention_decode_batched: bool = False,
38
+ naive_attention_decode_single: bool = False,
39
+ **kwargs,
40
+ ):
41
+ self.num_key_value_heads_per_layer = num_key_value_heads_per_layer
42
+ self.naive_attention_prefill = naive_attention_prefill
43
+ self.naive_attention_decode_batched = naive_attention_decode_batched
44
+ self.naive_attention_decode_single = naive_attention_decode_single
45
+ super().__init__(**kwargs, )
46
+
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4d4c218eaa63b65c4d894d370b2ab5dc43646c37b464157abbee7916f84d487
3
+ size 4953700360
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9de4c260b481ba715dc8186a57e45fa4f259355d2afd83139daee05e2d5eaf9d
3
+ size 4915985520
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7adcb6ac52834cb6ffe44ca24ef940857188a51da8e20af3883f405ad472fa0d
3
+ size 1564553032
model.safetensors.index.json ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 11434205184
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00003-of-00003.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
13
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
14
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
15
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
16
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
17
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
18
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
19
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
20
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
21
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
22
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
23
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
24
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
25
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
26
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
27
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
28
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
29
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
30
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
31
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
32
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
33
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
34
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
35
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
36
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
37
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
38
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
39
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
40
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
41
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
42
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
43
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
44
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00003.safetensors",
45
+ "model.layers.12.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
46
+ "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
47
+ "model.layers.12.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
48
+ "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
49
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
50
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
51
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
52
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
53
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
54
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
55
+ "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
56
+ "model.layers.13.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
57
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
58
+ "model.layers.13.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
59
+ "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
60
+ "model.layers.13.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
61
+ "model.layers.13.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
62
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
63
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
64
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
65
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
66
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
67
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
68
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
69
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
70
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
71
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
72
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
73
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
74
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
75
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
76
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
77
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
78
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
79
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
80
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
81
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
82
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
83
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
84
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
85
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
86
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
87
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
88
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
89
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
90
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
91
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
92
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
93
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
94
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
95
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
96
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
97
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
98
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
99
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
100
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
101
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
102
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
103
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
104
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
105
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
106
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
107
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
108
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
109
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
110
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
111
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
112
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
113
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
114
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
115
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
116
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
117
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
118
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
119
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
120
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
121
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
122
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
123
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
124
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
125
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
126
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
127
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
128
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
129
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
130
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
131
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
132
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
133
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
134
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
135
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
136
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
137
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
138
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
139
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
140
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
141
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
142
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
143
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
144
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
145
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
146
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
147
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
148
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
149
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
150
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
151
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
152
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
153
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
154
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
155
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
156
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
157
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
158
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
159
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
160
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
161
+ "model.layers.24.input_layernorm.weight": "model-00002-of-00003.safetensors",
162
+ "model.layers.24.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
163
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
164
+ "model.layers.24.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
165
+ "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
166
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
167
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
168
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
169
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
170
+ "model.layers.25.input_layernorm.weight": "model-00002-of-00003.safetensors",
171
+ "model.layers.25.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
172
+ "model.layers.25.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
173
+ "model.layers.25.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
174
+ "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
175
+ "model.layers.25.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
176
+ "model.layers.25.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
177
+ "model.layers.25.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
178
+ "model.layers.25.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
179
+ "model.layers.26.input_layernorm.weight": "model-00002-of-00003.safetensors",
180
+ "model.layers.26.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
181
+ "model.layers.26.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
182
+ "model.layers.26.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
183
+ "model.layers.26.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
184
+ "model.layers.26.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
185
+ "model.layers.26.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
186
+ "model.layers.26.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
187
+ "model.layers.26.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
188
+ "model.layers.27.input_layernorm.weight": "model-00002-of-00003.safetensors",
189
+ "model.layers.27.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
190
+ "model.layers.27.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
191
+ "model.layers.27.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
192
+ "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
193
+ "model.layers.27.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
194
+ "model.layers.27.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
195
+ "model.layers.27.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
196
+ "model.layers.27.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
197
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
198
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
199
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
200
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
201
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
202
+ "model.layers.28.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
203
+ "model.layers.28.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
204
+ "model.layers.28.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
205
+ "model.layers.28.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
206
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
207
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
208
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
209
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
210
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
211
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
212
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
213
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
214
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
215
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
216
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
217
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
218
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
219
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
220
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
221
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
222
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
223
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
224
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
225
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
226
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
227
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
228
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
229
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
230
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
231
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
232
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
233
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
234
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
235
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
236
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
237
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
238
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
239
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
240
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
241
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
242
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
243
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
244
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
245
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
246
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
247
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
248
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
249
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
250
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
251
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
252
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
253
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
254
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
255
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
256
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
257
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
258
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
259
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
260
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
261
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
262
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
263
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
264
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
265
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
266
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
267
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
268
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
269
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
270
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
271
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
272
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
273
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
274
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
275
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
276
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
277
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
278
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
279
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
280
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
281
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
282
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
283
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
284
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
285
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
286
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
287
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
288
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
289
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
290
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
291
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
292
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
293
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
294
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
295
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
296
+ "model.norm.weight": "model-00003-of-00003.safetensors"
297
+ }
298
+ }
modeling_decilm.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright and license here
3
+ """ PyTorch DeciLM model."""
4
+ import math
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from packaging import version
12
+ import transformers
13
+ if version.parse(transformers.__version__) < version.parse("4.31.0"):
14
+ raise ImportError(
15
+ f"You are using transformers=={transformers.__version__}, but transformers>=4.31.0 is required to use DeciLM. Please upgrade transformers."
16
+ )
17
+ from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, LlamaAttention, apply_rotary_pos_emb, \
18
+ repeat_kv, LlamaPreTrainedModel, LLAMA_START_DOCSTRING, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
19
+ from transformers.utils import add_start_docstrings
20
+
21
+ from .configuration_decilm import DeciLMConfig
22
+
23
+ _CONFIG_FOR_DOC = "DeciLMConfig"
24
+
25
+
26
+ class DeciLMAttention(LlamaAttention):
27
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
28
+
29
+ def __init__(self, config: DeciLMConfig, layer_idx: int):
30
+ nn.Module.__init__(self)
31
+ self.config = config
32
+ self.hidden_size = config.hidden_size
33
+ self.num_heads = config.num_attention_heads
34
+ self.head_dim = self.hidden_size // self.num_heads
35
+ self.layer_idx = layer_idx
36
+ self.num_key_value_heads = config.num_key_value_heads_per_layer[layer_idx]
37
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
38
+ self.pretraining_tp = config.pretraining_tp
39
+ self.max_position_embeddings = config.max_position_embeddings
40
+ self.rope_theta = getattr(config, 'rope_theta', None)
41
+
42
+ if (self.head_dim * self.num_heads) != self.hidden_size:
43
+ raise ValueError(
44
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
45
+ f" and `num_heads`: {self.num_heads})."
46
+ )
47
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
48
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
49
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
50
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
51
+
52
+ self.naive_attention_prefill = config.naive_attention_prefill
53
+ self.naive_attention_decode_batched = config.naive_attention_decode_batched
54
+ self.naive_attention_decode_single = config.naive_attention_decode_single
55
+ self._init_rope()
56
+
57
+ def forward(
58
+ self,
59
+ hidden_states: torch.Tensor,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
63
+ output_attentions: bool = False,
64
+ use_cache: bool = False,
65
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
66
+ bsz, q_len, _ = hidden_states.size()
67
+ if past_key_value is None:
68
+ is_decode = False
69
+ else:
70
+ is_decode = True
71
+ if self.pretraining_tp > 1:
72
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
73
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
74
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
75
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
76
+
77
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
78
+ query_states = torch.cat(query_states, dim=-1)
79
+
80
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
81
+ key_states = torch.cat(key_states, dim=-1)
82
+
83
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
84
+ value_states = torch.cat(value_states, dim=-1)
85
+
86
+ else:
87
+ query_states = self.q_proj(hidden_states)
88
+ key_states = self.k_proj(hidden_states)
89
+ value_states = self.v_proj(hidden_states)
90
+
91
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
92
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
93
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
94
+
95
+ kv_seq_len = key_states.shape[-2]
96
+ if past_key_value is not None:
97
+ kv_seq_len += past_key_value[0].shape[-2]
98
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
99
+
100
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
101
+
102
+ if past_key_value is not None:
103
+ # reuse k, v, self_attention
104
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
105
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
106
+
107
+ past_key_value = (key_states, value_states) if use_cache else None
108
+
109
+ # repeat k/v heads if n_kv_heads < n_heads
110
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
111
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
112
+ if is_decode:
113
+ if self.naive_attention_decode_batched and bsz > 1 or self.naive_attention_decode_single and bsz == 1:
114
+ attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(key_states.size(-1))
115
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
116
+ if attention_mask is not None:
117
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
118
+ raise ValueError(
119
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
120
+ )
121
+ attn_weights = attn_weights + attention_mask
122
+
123
+ attn_output = torch.matmul(attn_weights, value_states)
124
+ else:
125
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=False,
126
+ dropout_p=0.0)
127
+ attn_output = attn_output.contiguous().view(bsz, q_len, self.hidden_size)
128
+
129
+ else:
130
+ if not self.naive_attention_prefill:
131
+ with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
132
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True,
133
+ dropout_p=0.0)
134
+ else:
135
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
136
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
137
+ raise ValueError(
138
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
139
+ f" {attn_weights.size()}"
140
+ )
141
+
142
+ if attention_mask is not None:
143
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
144
+ raise ValueError(
145
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
146
+ )
147
+ attn_weights = attn_weights + attention_mask
148
+
149
+ # upcast attention to fp32
150
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
151
+ attn_output = torch.matmul(attn_weights, value_states)
152
+
153
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
154
+ raise ValueError(
155
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
156
+ f" {attn_output.size()}"
157
+ )
158
+
159
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
160
+
161
+ if self.pretraining_tp > 1:
162
+ attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
163
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
164
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
165
+ else:
166
+ attn_output = self.o_proj(attn_output)
167
+
168
+ if not output_attentions:
169
+ attn_weights = None
170
+
171
+ return attn_output, attn_weights, past_key_value
172
+
173
+
174
+ class DeciLMDecoderLayer(LlamaDecoderLayer):
175
+ def __init__(self, config: DeciLMConfig, layer_idx: int):
176
+ nn.Module.__init__(self)
177
+ self.hidden_size = config.hidden_size
178
+ self.layer_idx = layer_idx
179
+ self.self_attn = DeciLMAttention(config=config, layer_idx=layer_idx)
180
+ self.mlp = LlamaMLP(config)
181
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
182
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
183
+
184
+
185
+ @add_start_docstrings(
186
+ "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
187
+ LLAMA_START_DOCSTRING,
188
+ )
189
+ class DeciLMPreTrainedModel(LlamaPreTrainedModel):
190
+ config_class = DeciLMConfig
191
+ _no_split_modules = ["DeciLMDecoderLayer"]
192
+ _keys_to_ignore_on_load_missing = ["self_attn.rotary_emb.inv_freq"]
193
+
194
+
195
+ @add_start_docstrings(
196
+ "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
197
+ LLAMA_START_DOCSTRING,
198
+ )
199
+ class DeciLMModel(LlamaModel, DeciLMPreTrainedModel):
200
+ """
201
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeciLMDecoderLayer`]
202
+
203
+ Args:
204
+ config: DeciLMConfig
205
+ """
206
+
207
+ def __init__(self, config: DeciLMConfig):
208
+ DeciLMPreTrainedModel.__init__(self, config)
209
+ self.padding_idx = config.pad_token_id
210
+ self.vocab_size = config.vocab_size
211
+
212
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
213
+ self.layers = nn.ModuleList([DeciLMDecoderLayer(config, layer_idx) for layer_idx
214
+ in range(config.num_hidden_layers)])
215
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
216
+
217
+ self.gradient_checkpointing = False
218
+ # Initialize weights and apply final processing
219
+ self.post_init()
220
+
221
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
222
+ self._validate_config_supports_attention_mask(attention_mask, input_shape, past_key_values_length)
223
+ return LlamaModel._prepare_decoder_attention_mask(
224
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length)
225
+
226
+ def _validate_config_supports_attention_mask(self, attention_mask, input_shape, past_key_values_length):
227
+ is_decode = past_key_values_length > 0
228
+ if not torch.all(torch.eq(attention_mask, 1)).item():
229
+ if is_decode:
230
+ if input_shape[0] == 1 and not self.config.naive_attention_decode_single:
231
+ raise ValueError(
232
+ "For support of custom attention masks please set naive_attention_decode_single to True in the "
233
+ "config")
234
+ elif input_shape[0] > 1 and not self.config.naive_attention_decode_batched:
235
+ raise ValueError(
236
+ "For support of custom attention masks please set naive_attention_decode_batched to True in the"
237
+ "config")
238
+ else:
239
+ if not self.config.naive_attention_prefill:
240
+ raise ValueError("For support of custom attention masks please set naive_attention_prefill to "
241
+ "True in the config")
242
+
243
+
244
+ class DeciLMForCausalLM(LlamaForCausalLM, DeciLMPreTrainedModel):
245
+ def __init__(self, config):
246
+ DeciLMPreTrainedModel.__init__(self, config)
247
+ self.model = DeciLMModel(config)
248
+ self.pretraining_tp = config.pretraining_tp
249
+ self.vocab_size = config.vocab_size
250
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
251
+
252
+ # Initialize weights and apply final processing
253
+ self.post_init()
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "</s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": false,
22
+ "model_max_length": 1000000000000000019884624838656,
23
+ "pad_token": null,
24
+ "padding_side": "right",
25
+ "sp_model_kwargs": {},
26
+ "tokenizer_class": "LlamaTokenizer",
27
+ "unk_token": {
28
+ "__type": "AddedToken",
29
+ "content": "<unk>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false
34
+ }
35
+ }