mayank-mishra commited on
Commit
afc3169
1 Parent(s): 2263de3
config.json CHANGED
@@ -1,40 +1,28 @@
1
  {
2
- "activation_function": "swiglu",
3
- "add_bias": true,
4
- "apply_residual_connection_post_layernorm": false,
5
  "architectures": [
6
- "GraniteForCausalLM"
7
  ],
8
- "attention_head_type": "mha",
9
- "attention_multiplier": null,
10
- "attention_softmax_in_fp32": true,
11
- "attn_pdrop": 0.1,
12
- "auto_map": {
13
- "AutoConfig": "configuration_granite.GraniteConfig",
14
- "AutoModel": "modeling_granite.GraniteModel",
15
- "AutoModelForCausalLM": "modeling_granite.GraniteForCausalLM"
16
- },
17
- "bos_token_id": 0,
18
- "embd_pdrop": 0.1,
19
- "eos_token_id": 0,
20
  "initializer_range": 0.02,
21
- "layer_norm_epsilon": 1e-05,
22
- "model_type": "granite",
23
- "n_embd": 2560,
24
- "n_head": 32,
25
- "n_inner": 10240,
26
- "n_layer": 32,
27
- "n_positions": 2048,
28
- "normalization_function": "rmsnorm",
29
  "num_key_value_heads": 32,
30
- "pad_token_id": 0,
31
- "position_embedding_type": "rope",
32
- "resid_pdrop": 0.1,
33
  "rope_theta": 10000,
34
- "scale_attention_softmax_in_fp32": true,
35
- "scale_attn_weights": true,
36
  "torch_dtype": "bfloat16",
37
- "transformers_version": "4.38.1",
38
  "use_cache": true,
39
  "vocab_size": 49152
40
  }
 
1
  {
 
 
 
2
  "architectures": [
3
+ "LlamaForCausalLM"
4
  ],
5
+ "attention_bias": true,
6
+ "attention_dropout": 0.1,
7
+ "bos_token_id": 1,
8
+ "eos_token_id": 2,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 2560,
 
 
 
 
 
 
11
  "initializer_range": 0.02,
12
+ "intermediate_size": 10240,
13
+ "max_position_embeddings": 2048,
14
+ "mlp_bias": true,
15
+ "model_type": "llama",
16
+ "num_attention_heads": 32,
17
+ "num_hidden_layers": 32,
 
 
18
  "num_key_value_heads": 32,
19
+ "pretraining_tp": 1,
20
+ "rms_norm_eps": 1e-05,
21
+ "rope_scaling": null,
22
  "rope_theta": 10000,
23
+ "tie_word_embeddings": true,
 
24
  "torch_dtype": "bfloat16",
25
+ "transformers_version": "4.41.0.dev0",
26
  "use_cache": true,
27
  "vocab_size": 49152
28
  }
configuration_granite.py DELETED
@@ -1,98 +0,0 @@
1
- from transformers import PretrainedConfig
2
-
3
-
4
- class GraniteConfig(PretrainedConfig):
5
- model_type = "granite"
6
-
7
- keys_to_ignore_at_inference = ["past_key_values"]
8
- attribute_map = {
9
- "hidden_size": "n_embd",
10
- "max_position_embeddings": "n_positions",
11
- "num_attention_heads": "n_head",
12
- "num_hidden_layers": "n_layer",
13
- }
14
-
15
- def __init__(
16
- self,
17
- vocab_size: int = 50257,
18
- n_positions: int = 1024,
19
- n_embd: int = 768,
20
- n_layer: int = 12,
21
- n_head: int = 12,
22
- num_key_value_heads: int = None,
23
- n_inner: int = None,
24
- activation_function: str = "gelu_pytorch_tanh",
25
- attention_head_type: str = "mqa",
26
- resid_pdrop: float = 0.1,
27
- embd_pdrop: float = 0.1,
28
- attn_pdrop: float = 0.1,
29
- normalization_function: str = "layernorm",
30
- layer_norm_epsilon: float = 1e-5,
31
- initializer_range: float = 0.02,
32
- scale_attn_weights: bool = True,
33
- attention_multiplier: float = None,
34
- use_cache: bool = True,
35
- bos_token_id: int = 50256,
36
- eos_token_id: int = 50256,
37
- pad_token_id: int = 50256,
38
- attention_softmax_in_fp32: bool = True,
39
- scale_attention_softmax_in_fp32: bool = True,
40
- add_bias: bool = True,
41
- position_embedding_type: str = "learned_absolute",
42
- rope_theta: int = 10000,
43
- **kwargs,
44
- ) -> None:
45
- self.vocab_size = vocab_size
46
- self.n_positions = n_positions
47
- self.n_embd = n_embd
48
- self.n_layer = n_layer
49
- self.n_head = n_head
50
- self.num_key_value_heads = num_key_value_heads
51
- self.n_inner = 4 * n_embd if n_inner is None else n_inner
52
- self.activation_function = activation_function
53
- self.attention_head_type = attention_head_type
54
- self.resid_pdrop = resid_pdrop
55
- self.embd_pdrop = embd_pdrop
56
- self.attn_pdrop = attn_pdrop
57
- self.normalization_function = normalization_function
58
- self.layer_norm_epsilon = layer_norm_epsilon
59
- self.initializer_range = initializer_range
60
- self.scale_attn_weights = scale_attn_weights
61
- self.attention_multiplier = attention_multiplier
62
- self.use_cache = use_cache
63
- self.attention_softmax_in_fp32 = attention_softmax_in_fp32
64
- self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
65
- self.position_embedding_type = position_embedding_type
66
- self.add_bias = add_bias
67
- self.rope_theta = rope_theta
68
-
69
- if self.attention_multiplier is not None:
70
- assert self.scale_attn_weights
71
-
72
- # for compatibility with some features
73
- self.multi_query = attention_head_type == "mqa"
74
-
75
- if attention_head_type == "mha":
76
- if self.num_key_value_heads is None:
77
- self.num_key_value_heads = self.n_head
78
-
79
- assert (
80
- self.n_head == self.num_key_value_heads
81
- ), "MultiHeadAttention should have same number of heads for query, keys and values"
82
- elif attention_head_type == "mqa":
83
- if self.num_key_value_heads is None:
84
- self.num_key_value_heads = 1
85
-
86
- assert self.num_key_value_heads == 1, "MultiQueryAttention should have 1 head for keys and values"
87
- elif attention_head_type == "gqa":
88
- assert (
89
- self.num_key_value_heads is not None
90
- ), "`num_key_value_heads` needs to be specified with GroupedQueryAttention"
91
-
92
- assert (
93
- self.n_head % self.num_key_value_heads == 0
94
- ), "GroupedQueryAttention should have more than 1 head for keys and values"
95
- else:
96
- raise ValueError(f"unexpected attention_head_type ({attention_head_type})")
97
-
98
- super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
generation_config.json CHANGED
@@ -1,7 +1,6 @@
1
  {
2
  "_from_model_config": true,
3
- "bos_token_id": 0,
4
- "eos_token_id": 0,
5
- "pad_token_id": 0,
6
- "transformers_version": "4.38.1"
7
  }
 
1
  {
2
  "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.41.0.dev0"
 
6
  }
model-00001-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c68c29b3f0a88f6a60a53e2ba2f0c8c94c16f7adaf412981297bd54ed58abed4
3
- size 4919566152
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7127432f1ffff5b25875ce4e287c9dfa458655ad1b6183a2d274c061bf9e4b1b
3
+ size 4972021832
model-00002-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9c342afb52cc250f98581d9b2d079c273fdfdf7ab629f9c63ef11c140d6a2a9f
3
- size 2045475816
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:344a03bb26fe17bacac63b616c45397758aeb6dfb0f36b4495139dc4530b885b
3
+ size 1993043568
model.safetensors.index.json CHANGED
@@ -3,327 +3,519 @@
3
  "total_size": 6965007360
4
  },
5
  "weight_map": {
6
- "transformer.h.0.attn.c_attn.bias": "model-00001-of-00002.safetensors",
7
- "transformer.h.0.attn.c_attn.weight": "model-00001-of-00002.safetensors",
8
- "transformer.h.0.attn.c_proj.bias": "model-00001-of-00002.safetensors",
9
- "transformer.h.0.attn.c_proj.weight": "model-00001-of-00002.safetensors",
10
- "transformer.h.0.ln_1.weight": "model-00001-of-00002.safetensors",
11
- "transformer.h.0.ln_2.weight": "model-00001-of-00002.safetensors",
12
- "transformer.h.0.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
13
- "transformer.h.0.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
14
- "transformer.h.0.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
15
- "transformer.h.0.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
16
- "transformer.h.1.attn.c_attn.bias": "model-00001-of-00002.safetensors",
17
- "transformer.h.1.attn.c_attn.weight": "model-00001-of-00002.safetensors",
18
- "transformer.h.1.attn.c_proj.bias": "model-00001-of-00002.safetensors",
19
- "transformer.h.1.attn.c_proj.weight": "model-00001-of-00002.safetensors",
20
- "transformer.h.1.ln_1.weight": "model-00001-of-00002.safetensors",
21
- "transformer.h.1.ln_2.weight": "model-00001-of-00002.safetensors",
22
- "transformer.h.1.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
23
- "transformer.h.1.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
24
- "transformer.h.1.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
25
- "transformer.h.1.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
26
- "transformer.h.10.attn.c_attn.bias": "model-00001-of-00002.safetensors",
27
- "transformer.h.10.attn.c_attn.weight": "model-00001-of-00002.safetensors",
28
- "transformer.h.10.attn.c_proj.bias": "model-00001-of-00002.safetensors",
29
- "transformer.h.10.attn.c_proj.weight": "model-00001-of-00002.safetensors",
30
- "transformer.h.10.ln_1.weight": "model-00001-of-00002.safetensors",
31
- "transformer.h.10.ln_2.weight": "model-00001-of-00002.safetensors",
32
- "transformer.h.10.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
33
- "transformer.h.10.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
34
- "transformer.h.10.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
35
- "transformer.h.10.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
36
- "transformer.h.11.attn.c_attn.bias": "model-00001-of-00002.safetensors",
37
- "transformer.h.11.attn.c_attn.weight": "model-00001-of-00002.safetensors",
38
- "transformer.h.11.attn.c_proj.bias": "model-00001-of-00002.safetensors",
39
- "transformer.h.11.attn.c_proj.weight": "model-00001-of-00002.safetensors",
40
- "transformer.h.11.ln_1.weight": "model-00001-of-00002.safetensors",
41
- "transformer.h.11.ln_2.weight": "model-00001-of-00002.safetensors",
42
- "transformer.h.11.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
43
- "transformer.h.11.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
44
- "transformer.h.11.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
45
- "transformer.h.11.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
46
- "transformer.h.12.attn.c_attn.bias": "model-00001-of-00002.safetensors",
47
- "transformer.h.12.attn.c_attn.weight": "model-00001-of-00002.safetensors",
48
- "transformer.h.12.attn.c_proj.bias": "model-00001-of-00002.safetensors",
49
- "transformer.h.12.attn.c_proj.weight": "model-00001-of-00002.safetensors",
50
- "transformer.h.12.ln_1.weight": "model-00001-of-00002.safetensors",
51
- "transformer.h.12.ln_2.weight": "model-00001-of-00002.safetensors",
52
- "transformer.h.12.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
53
- "transformer.h.12.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
54
- "transformer.h.12.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
55
- "transformer.h.12.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
56
- "transformer.h.13.attn.c_attn.bias": "model-00001-of-00002.safetensors",
57
- "transformer.h.13.attn.c_attn.weight": "model-00001-of-00002.safetensors",
58
- "transformer.h.13.attn.c_proj.bias": "model-00001-of-00002.safetensors",
59
- "transformer.h.13.attn.c_proj.weight": "model-00001-of-00002.safetensors",
60
- "transformer.h.13.ln_1.weight": "model-00001-of-00002.safetensors",
61
- "transformer.h.13.ln_2.weight": "model-00001-of-00002.safetensors",
62
- "transformer.h.13.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
63
- "transformer.h.13.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
64
- "transformer.h.13.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
65
- "transformer.h.13.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
66
- "transformer.h.14.attn.c_attn.bias": "model-00001-of-00002.safetensors",
67
- "transformer.h.14.attn.c_attn.weight": "model-00001-of-00002.safetensors",
68
- "transformer.h.14.attn.c_proj.bias": "model-00001-of-00002.safetensors",
69
- "transformer.h.14.attn.c_proj.weight": "model-00001-of-00002.safetensors",
70
- "transformer.h.14.ln_1.weight": "model-00001-of-00002.safetensors",
71
- "transformer.h.14.ln_2.weight": "model-00001-of-00002.safetensors",
72
- "transformer.h.14.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
73
- "transformer.h.14.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
74
- "transformer.h.14.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
75
- "transformer.h.14.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
76
- "transformer.h.15.attn.c_attn.bias": "model-00001-of-00002.safetensors",
77
- "transformer.h.15.attn.c_attn.weight": "model-00001-of-00002.safetensors",
78
- "transformer.h.15.attn.c_proj.bias": "model-00001-of-00002.safetensors",
79
- "transformer.h.15.attn.c_proj.weight": "model-00001-of-00002.safetensors",
80
- "transformer.h.15.ln_1.weight": "model-00001-of-00002.safetensors",
81
- "transformer.h.15.ln_2.weight": "model-00001-of-00002.safetensors",
82
- "transformer.h.15.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
83
- "transformer.h.15.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
84
- "transformer.h.15.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
85
- "transformer.h.15.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
86
- "transformer.h.16.attn.c_attn.bias": "model-00001-of-00002.safetensors",
87
- "transformer.h.16.attn.c_attn.weight": "model-00001-of-00002.safetensors",
88
- "transformer.h.16.attn.c_proj.bias": "model-00001-of-00002.safetensors",
89
- "transformer.h.16.attn.c_proj.weight": "model-00001-of-00002.safetensors",
90
- "transformer.h.16.ln_1.weight": "model-00001-of-00002.safetensors",
91
- "transformer.h.16.ln_2.weight": "model-00001-of-00002.safetensors",
92
- "transformer.h.16.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
93
- "transformer.h.16.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
94
- "transformer.h.16.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
95
- "transformer.h.16.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
96
- "transformer.h.17.attn.c_attn.bias": "model-00001-of-00002.safetensors",
97
- "transformer.h.17.attn.c_attn.weight": "model-00001-of-00002.safetensors",
98
- "transformer.h.17.attn.c_proj.bias": "model-00001-of-00002.safetensors",
99
- "transformer.h.17.attn.c_proj.weight": "model-00001-of-00002.safetensors",
100
- "transformer.h.17.ln_1.weight": "model-00001-of-00002.safetensors",
101
- "transformer.h.17.ln_2.weight": "model-00001-of-00002.safetensors",
102
- "transformer.h.17.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
103
- "transformer.h.17.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
104
- "transformer.h.17.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
105
- "transformer.h.17.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
106
- "transformer.h.18.attn.c_attn.bias": "model-00001-of-00002.safetensors",
107
- "transformer.h.18.attn.c_attn.weight": "model-00001-of-00002.safetensors",
108
- "transformer.h.18.attn.c_proj.bias": "model-00001-of-00002.safetensors",
109
- "transformer.h.18.attn.c_proj.weight": "model-00001-of-00002.safetensors",
110
- "transformer.h.18.ln_1.weight": "model-00001-of-00002.safetensors",
111
- "transformer.h.18.ln_2.weight": "model-00001-of-00002.safetensors",
112
- "transformer.h.18.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
113
- "transformer.h.18.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
114
- "transformer.h.18.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
115
- "transformer.h.18.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
116
- "transformer.h.19.attn.c_attn.bias": "model-00001-of-00002.safetensors",
117
- "transformer.h.19.attn.c_attn.weight": "model-00001-of-00002.safetensors",
118
- "transformer.h.19.attn.c_proj.bias": "model-00001-of-00002.safetensors",
119
- "transformer.h.19.attn.c_proj.weight": "model-00001-of-00002.safetensors",
120
- "transformer.h.19.ln_1.weight": "model-00001-of-00002.safetensors",
121
- "transformer.h.19.ln_2.weight": "model-00001-of-00002.safetensors",
122
- "transformer.h.19.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
123
- "transformer.h.19.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
124
- "transformer.h.19.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
125
- "transformer.h.19.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
126
- "transformer.h.2.attn.c_attn.bias": "model-00001-of-00002.safetensors",
127
- "transformer.h.2.attn.c_attn.weight": "model-00001-of-00002.safetensors",
128
- "transformer.h.2.attn.c_proj.bias": "model-00001-of-00002.safetensors",
129
- "transformer.h.2.attn.c_proj.weight": "model-00001-of-00002.safetensors",
130
- "transformer.h.2.ln_1.weight": "model-00001-of-00002.safetensors",
131
- "transformer.h.2.ln_2.weight": "model-00001-of-00002.safetensors",
132
- "transformer.h.2.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
133
- "transformer.h.2.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
134
- "transformer.h.2.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
135
- "transformer.h.2.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
136
- "transformer.h.20.attn.c_attn.bias": "model-00001-of-00002.safetensors",
137
- "transformer.h.20.attn.c_attn.weight": "model-00001-of-00002.safetensors",
138
- "transformer.h.20.attn.c_proj.bias": "model-00001-of-00002.safetensors",
139
- "transformer.h.20.attn.c_proj.weight": "model-00001-of-00002.safetensors",
140
- "transformer.h.20.ln_1.weight": "model-00001-of-00002.safetensors",
141
- "transformer.h.20.ln_2.weight": "model-00001-of-00002.safetensors",
142
- "transformer.h.20.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
143
- "transformer.h.20.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
144
- "transformer.h.20.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
145
- "transformer.h.20.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
146
- "transformer.h.21.attn.c_attn.bias": "model-00001-of-00002.safetensors",
147
- "transformer.h.21.attn.c_attn.weight": "model-00001-of-00002.safetensors",
148
- "transformer.h.21.attn.c_proj.bias": "model-00001-of-00002.safetensors",
149
- "transformer.h.21.attn.c_proj.weight": "model-00001-of-00002.safetensors",
150
- "transformer.h.21.ln_1.weight": "model-00001-of-00002.safetensors",
151
- "transformer.h.21.ln_2.weight": "model-00001-of-00002.safetensors",
152
- "transformer.h.21.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
153
- "transformer.h.21.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
154
- "transformer.h.21.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
155
- "transformer.h.21.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
156
- "transformer.h.22.attn.c_attn.bias": "model-00001-of-00002.safetensors",
157
- "transformer.h.22.attn.c_attn.weight": "model-00001-of-00002.safetensors",
158
- "transformer.h.22.attn.c_proj.bias": "model-00001-of-00002.safetensors",
159
- "transformer.h.22.attn.c_proj.weight": "model-00001-of-00002.safetensors",
160
- "transformer.h.22.ln_1.weight": "model-00001-of-00002.safetensors",
161
- "transformer.h.22.ln_2.weight": "model-00001-of-00002.safetensors",
162
- "transformer.h.22.mlp.c_fc.bias": "model-00002-of-00002.safetensors",
163
- "transformer.h.22.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
164
- "transformer.h.22.mlp.c_proj.bias": "model-00002-of-00002.safetensors",
165
- "transformer.h.22.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
166
- "transformer.h.23.attn.c_attn.bias": "model-00002-of-00002.safetensors",
167
- "transformer.h.23.attn.c_attn.weight": "model-00002-of-00002.safetensors",
168
- "transformer.h.23.attn.c_proj.bias": "model-00002-of-00002.safetensors",
169
- "transformer.h.23.attn.c_proj.weight": "model-00002-of-00002.safetensors",
170
- "transformer.h.23.ln_1.weight": "model-00002-of-00002.safetensors",
171
- "transformer.h.23.ln_2.weight": "model-00002-of-00002.safetensors",
172
- "transformer.h.23.mlp.c_fc.bias": "model-00002-of-00002.safetensors",
173
- "transformer.h.23.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
174
- "transformer.h.23.mlp.c_proj.bias": "model-00002-of-00002.safetensors",
175
- "transformer.h.23.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
176
- "transformer.h.24.attn.c_attn.bias": "model-00002-of-00002.safetensors",
177
- "transformer.h.24.attn.c_attn.weight": "model-00002-of-00002.safetensors",
178
- "transformer.h.24.attn.c_proj.bias": "model-00002-of-00002.safetensors",
179
- "transformer.h.24.attn.c_proj.weight": "model-00002-of-00002.safetensors",
180
- "transformer.h.24.ln_1.weight": "model-00002-of-00002.safetensors",
181
- "transformer.h.24.ln_2.weight": "model-00002-of-00002.safetensors",
182
- "transformer.h.24.mlp.c_fc.bias": "model-00002-of-00002.safetensors",
183
- "transformer.h.24.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
184
- "transformer.h.24.mlp.c_proj.bias": "model-00002-of-00002.safetensors",
185
- "transformer.h.24.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
186
- "transformer.h.25.attn.c_attn.bias": "model-00002-of-00002.safetensors",
187
- "transformer.h.25.attn.c_attn.weight": "model-00002-of-00002.safetensors",
188
- "transformer.h.25.attn.c_proj.bias": "model-00002-of-00002.safetensors",
189
- "transformer.h.25.attn.c_proj.weight": "model-00002-of-00002.safetensors",
190
- "transformer.h.25.ln_1.weight": "model-00002-of-00002.safetensors",
191
- "transformer.h.25.ln_2.weight": "model-00002-of-00002.safetensors",
192
- "transformer.h.25.mlp.c_fc.bias": "model-00002-of-00002.safetensors",
193
- "transformer.h.25.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
194
- "transformer.h.25.mlp.c_proj.bias": "model-00002-of-00002.safetensors",
195
- "transformer.h.25.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
196
- "transformer.h.26.attn.c_attn.bias": "model-00002-of-00002.safetensors",
197
- "transformer.h.26.attn.c_attn.weight": "model-00002-of-00002.safetensors",
198
- "transformer.h.26.attn.c_proj.bias": "model-00002-of-00002.safetensors",
199
- "transformer.h.26.attn.c_proj.weight": "model-00002-of-00002.safetensors",
200
- "transformer.h.26.ln_1.weight": "model-00002-of-00002.safetensors",
201
- "transformer.h.26.ln_2.weight": "model-00002-of-00002.safetensors",
202
- "transformer.h.26.mlp.c_fc.bias": "model-00002-of-00002.safetensors",
203
- "transformer.h.26.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
204
- "transformer.h.26.mlp.c_proj.bias": "model-00002-of-00002.safetensors",
205
- "transformer.h.26.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
206
- "transformer.h.27.attn.c_attn.bias": "model-00002-of-00002.safetensors",
207
- "transformer.h.27.attn.c_attn.weight": "model-00002-of-00002.safetensors",
208
- "transformer.h.27.attn.c_proj.bias": "model-00002-of-00002.safetensors",
209
- "transformer.h.27.attn.c_proj.weight": "model-00002-of-00002.safetensors",
210
- "transformer.h.27.ln_1.weight": "model-00002-of-00002.safetensors",
211
- "transformer.h.27.ln_2.weight": "model-00002-of-00002.safetensors",
212
- "transformer.h.27.mlp.c_fc.bias": "model-00002-of-00002.safetensors",
213
- "transformer.h.27.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
214
- "transformer.h.27.mlp.c_proj.bias": "model-00002-of-00002.safetensors",
215
- "transformer.h.27.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
216
- "transformer.h.28.attn.c_attn.bias": "model-00002-of-00002.safetensors",
217
- "transformer.h.28.attn.c_attn.weight": "model-00002-of-00002.safetensors",
218
- "transformer.h.28.attn.c_proj.bias": "model-00002-of-00002.safetensors",
219
- "transformer.h.28.attn.c_proj.weight": "model-00002-of-00002.safetensors",
220
- "transformer.h.28.ln_1.weight": "model-00002-of-00002.safetensors",
221
- "transformer.h.28.ln_2.weight": "model-00002-of-00002.safetensors",
222
- "transformer.h.28.mlp.c_fc.bias": "model-00002-of-00002.safetensors",
223
- "transformer.h.28.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
224
- "transformer.h.28.mlp.c_proj.bias": "model-00002-of-00002.safetensors",
225
- "transformer.h.28.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
226
- "transformer.h.29.attn.c_attn.bias": "model-00002-of-00002.safetensors",
227
- "transformer.h.29.attn.c_attn.weight": "model-00002-of-00002.safetensors",
228
- "transformer.h.29.attn.c_proj.bias": "model-00002-of-00002.safetensors",
229
- "transformer.h.29.attn.c_proj.weight": "model-00002-of-00002.safetensors",
230
- "transformer.h.29.ln_1.weight": "model-00002-of-00002.safetensors",
231
- "transformer.h.29.ln_2.weight": "model-00002-of-00002.safetensors",
232
- "transformer.h.29.mlp.c_fc.bias": "model-00002-of-00002.safetensors",
233
- "transformer.h.29.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
234
- "transformer.h.29.mlp.c_proj.bias": "model-00002-of-00002.safetensors",
235
- "transformer.h.29.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
236
- "transformer.h.3.attn.c_attn.bias": "model-00001-of-00002.safetensors",
237
- "transformer.h.3.attn.c_attn.weight": "model-00001-of-00002.safetensors",
238
- "transformer.h.3.attn.c_proj.bias": "model-00001-of-00002.safetensors",
239
- "transformer.h.3.attn.c_proj.weight": "model-00001-of-00002.safetensors",
240
- "transformer.h.3.ln_1.weight": "model-00001-of-00002.safetensors",
241
- "transformer.h.3.ln_2.weight": "model-00001-of-00002.safetensors",
242
- "transformer.h.3.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
243
- "transformer.h.3.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
244
- "transformer.h.3.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
245
- "transformer.h.3.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
246
- "transformer.h.30.attn.c_attn.bias": "model-00002-of-00002.safetensors",
247
- "transformer.h.30.attn.c_attn.weight": "model-00002-of-00002.safetensors",
248
- "transformer.h.30.attn.c_proj.bias": "model-00002-of-00002.safetensors",
249
- "transformer.h.30.attn.c_proj.weight": "model-00002-of-00002.safetensors",
250
- "transformer.h.30.ln_1.weight": "model-00002-of-00002.safetensors",
251
- "transformer.h.30.ln_2.weight": "model-00002-of-00002.safetensors",
252
- "transformer.h.30.mlp.c_fc.bias": "model-00002-of-00002.safetensors",
253
- "transformer.h.30.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
254
- "transformer.h.30.mlp.c_proj.bias": "model-00002-of-00002.safetensors",
255
- "transformer.h.30.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
256
- "transformer.h.31.attn.c_attn.bias": "model-00002-of-00002.safetensors",
257
- "transformer.h.31.attn.c_attn.weight": "model-00002-of-00002.safetensors",
258
- "transformer.h.31.attn.c_proj.bias": "model-00002-of-00002.safetensors",
259
- "transformer.h.31.attn.c_proj.weight": "model-00002-of-00002.safetensors",
260
- "transformer.h.31.ln_1.weight": "model-00002-of-00002.safetensors",
261
- "transformer.h.31.ln_2.weight": "model-00002-of-00002.safetensors",
262
- "transformer.h.31.mlp.c_fc.bias": "model-00002-of-00002.safetensors",
263
- "transformer.h.31.mlp.c_fc.weight": "model-00002-of-00002.safetensors",
264
- "transformer.h.31.mlp.c_proj.bias": "model-00002-of-00002.safetensors",
265
- "transformer.h.31.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
266
- "transformer.h.4.attn.c_attn.bias": "model-00001-of-00002.safetensors",
267
- "transformer.h.4.attn.c_attn.weight": "model-00001-of-00002.safetensors",
268
- "transformer.h.4.attn.c_proj.bias": "model-00001-of-00002.safetensors",
269
- "transformer.h.4.attn.c_proj.weight": "model-00001-of-00002.safetensors",
270
- "transformer.h.4.ln_1.weight": "model-00001-of-00002.safetensors",
271
- "transformer.h.4.ln_2.weight": "model-00001-of-00002.safetensors",
272
- "transformer.h.4.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
273
- "transformer.h.4.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
274
- "transformer.h.4.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
275
- "transformer.h.4.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
276
- "transformer.h.5.attn.c_attn.bias": "model-00001-of-00002.safetensors",
277
- "transformer.h.5.attn.c_attn.weight": "model-00001-of-00002.safetensors",
278
- "transformer.h.5.attn.c_proj.bias": "model-00001-of-00002.safetensors",
279
- "transformer.h.5.attn.c_proj.weight": "model-00001-of-00002.safetensors",
280
- "transformer.h.5.ln_1.weight": "model-00001-of-00002.safetensors",
281
- "transformer.h.5.ln_2.weight": "model-00001-of-00002.safetensors",
282
- "transformer.h.5.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
283
- "transformer.h.5.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
284
- "transformer.h.5.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
285
- "transformer.h.5.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
286
- "transformer.h.6.attn.c_attn.bias": "model-00001-of-00002.safetensors",
287
- "transformer.h.6.attn.c_attn.weight": "model-00001-of-00002.safetensors",
288
- "transformer.h.6.attn.c_proj.bias": "model-00001-of-00002.safetensors",
289
- "transformer.h.6.attn.c_proj.weight": "model-00001-of-00002.safetensors",
290
- "transformer.h.6.ln_1.weight": "model-00001-of-00002.safetensors",
291
- "transformer.h.6.ln_2.weight": "model-00001-of-00002.safetensors",
292
- "transformer.h.6.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
293
- "transformer.h.6.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
294
- "transformer.h.6.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
295
- "transformer.h.6.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
296
- "transformer.h.7.attn.c_attn.bias": "model-00001-of-00002.safetensors",
297
- "transformer.h.7.attn.c_attn.weight": "model-00001-of-00002.safetensors",
298
- "transformer.h.7.attn.c_proj.bias": "model-00001-of-00002.safetensors",
299
- "transformer.h.7.attn.c_proj.weight": "model-00001-of-00002.safetensors",
300
- "transformer.h.7.ln_1.weight": "model-00001-of-00002.safetensors",
301
- "transformer.h.7.ln_2.weight": "model-00001-of-00002.safetensors",
302
- "transformer.h.7.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
303
- "transformer.h.7.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
304
- "transformer.h.7.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
305
- "transformer.h.7.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
306
- "transformer.h.8.attn.c_attn.bias": "model-00001-of-00002.safetensors",
307
- "transformer.h.8.attn.c_attn.weight": "model-00001-of-00002.safetensors",
308
- "transformer.h.8.attn.c_proj.bias": "model-00001-of-00002.safetensors",
309
- "transformer.h.8.attn.c_proj.weight": "model-00001-of-00002.safetensors",
310
- "transformer.h.8.ln_1.weight": "model-00001-of-00002.safetensors",
311
- "transformer.h.8.ln_2.weight": "model-00001-of-00002.safetensors",
312
- "transformer.h.8.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
313
- "transformer.h.8.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
314
- "transformer.h.8.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
315
- "transformer.h.8.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
316
- "transformer.h.9.attn.c_attn.bias": "model-00001-of-00002.safetensors",
317
- "transformer.h.9.attn.c_attn.weight": "model-00001-of-00002.safetensors",
318
- "transformer.h.9.attn.c_proj.bias": "model-00001-of-00002.safetensors",
319
- "transformer.h.9.attn.c_proj.weight": "model-00001-of-00002.safetensors",
320
- "transformer.h.9.ln_1.weight": "model-00001-of-00002.safetensors",
321
- "transformer.h.9.ln_2.weight": "model-00001-of-00002.safetensors",
322
- "transformer.h.9.mlp.c_fc.bias": "model-00001-of-00002.safetensors",
323
- "transformer.h.9.mlp.c_fc.weight": "model-00001-of-00002.safetensors",
324
- "transformer.h.9.mlp.c_proj.bias": "model-00001-of-00002.safetensors",
325
- "transformer.h.9.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
326
- "transformer.ln_f.weight": "model-00002-of-00002.safetensors",
327
- "transformer.wte.weight": "model-00001-of-00002.safetensors"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  }
329
  }
 
3
  "total_size": 6965007360
4
  },
5
  "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
7
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
8
+ "model.layers.0.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
10
+ "model.layers.0.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
11
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
12
+ "model.layers.0.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
13
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
14
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
15
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
16
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
17
+ "model.layers.0.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
18
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
19
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
20
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
21
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
22
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
23
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
24
+ "model.layers.1.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
25
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
26
+ "model.layers.1.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
27
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
28
+ "model.layers.1.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
29
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
30
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
31
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
32
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
33
+ "model.layers.1.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
34
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
35
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
36
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
37
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
38
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
39
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
40
+ "model.layers.10.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
41
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
42
+ "model.layers.10.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
43
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
44
+ "model.layers.10.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
45
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
46
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
47
+ "model.layers.10.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
48
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
49
+ "model.layers.10.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
50
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
51
+ "model.layers.10.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
52
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
53
+ "model.layers.10.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
54
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
55
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
56
+ "model.layers.11.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
57
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
58
+ "model.layers.11.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
59
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
60
+ "model.layers.11.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
61
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
62
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
63
+ "model.layers.11.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
64
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
65
+ "model.layers.11.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
66
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
67
+ "model.layers.11.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
68
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
69
+ "model.layers.11.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
70
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
71
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
72
+ "model.layers.12.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
73
+ "model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
74
+ "model.layers.12.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
75
+ "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
76
+ "model.layers.12.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
77
+ "model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
78
+ "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
79
+ "model.layers.12.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
80
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
81
+ "model.layers.12.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
82
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
83
+ "model.layers.12.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
84
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
85
+ "model.layers.12.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
86
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
87
+ "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
88
+ "model.layers.13.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
89
+ "model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
90
+ "model.layers.13.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
91
+ "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
92
+ "model.layers.13.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
93
+ "model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
94
+ "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
95
+ "model.layers.13.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
96
+ "model.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
97
+ "model.layers.13.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
98
+ "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
99
+ "model.layers.13.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
100
+ "model.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
101
+ "model.layers.13.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
102
+ "model.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
103
+ "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
104
+ "model.layers.14.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
105
+ "model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
106
+ "model.layers.14.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
107
+ "model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
108
+ "model.layers.14.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
109
+ "model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
110
+ "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
111
+ "model.layers.14.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
112
+ "model.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
113
+ "model.layers.14.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
114
+ "model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
115
+ "model.layers.14.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
116
+ "model.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
117
+ "model.layers.14.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
118
+ "model.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
119
+ "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
120
+ "model.layers.15.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
121
+ "model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
122
+ "model.layers.15.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
123
+ "model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
124
+ "model.layers.15.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
125
+ "model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
126
+ "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
127
+ "model.layers.15.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
128
+ "model.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
129
+ "model.layers.15.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
130
+ "model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
131
+ "model.layers.15.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
132
+ "model.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
133
+ "model.layers.15.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
134
+ "model.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
135
+ "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
136
+ "model.layers.16.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
137
+ "model.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
138
+ "model.layers.16.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
139
+ "model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
140
+ "model.layers.16.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
141
+ "model.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
142
+ "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
143
+ "model.layers.16.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
144
+ "model.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
145
+ "model.layers.16.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
146
+ "model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
147
+ "model.layers.16.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
148
+ "model.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
149
+ "model.layers.16.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
150
+ "model.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
151
+ "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
152
+ "model.layers.17.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
153
+ "model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
154
+ "model.layers.17.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
155
+ "model.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
156
+ "model.layers.17.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
157
+ "model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
158
+ "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
159
+ "model.layers.17.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
160
+ "model.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
161
+ "model.layers.17.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
162
+ "model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
163
+ "model.layers.17.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
164
+ "model.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
165
+ "model.layers.17.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
166
+ "model.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
167
+ "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
168
+ "model.layers.18.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
169
+ "model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
170
+ "model.layers.18.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
171
+ "model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
172
+ "model.layers.18.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
173
+ "model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
174
+ "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
175
+ "model.layers.18.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
176
+ "model.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
177
+ "model.layers.18.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
178
+ "model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
179
+ "model.layers.18.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
180
+ "model.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
181
+ "model.layers.18.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
182
+ "model.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
183
+ "model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
184
+ "model.layers.19.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
185
+ "model.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
186
+ "model.layers.19.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
187
+ "model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
188
+ "model.layers.19.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
189
+ "model.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
190
+ "model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
191
+ "model.layers.19.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
192
+ "model.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
193
+ "model.layers.19.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
194
+ "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
195
+ "model.layers.19.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
196
+ "model.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
197
+ "model.layers.19.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
198
+ "model.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
199
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
200
+ "model.layers.2.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
201
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
202
+ "model.layers.2.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
203
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
204
+ "model.layers.2.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
205
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
206
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
207
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
208
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
209
+ "model.layers.2.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
210
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
211
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
212
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
213
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
214
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
215
+ "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
216
+ "model.layers.20.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
217
+ "model.layers.20.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
218
+ "model.layers.20.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
219
+ "model.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
220
+ "model.layers.20.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
221
+ "model.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
222
+ "model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
223
+ "model.layers.20.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
224
+ "model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
225
+ "model.layers.20.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
226
+ "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
227
+ "model.layers.20.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
228
+ "model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
229
+ "model.layers.20.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
230
+ "model.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
231
+ "model.layers.21.input_layernorm.weight": "model-00001-of-00002.safetensors",
232
+ "model.layers.21.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
233
+ "model.layers.21.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
234
+ "model.layers.21.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
235
+ "model.layers.21.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
236
+ "model.layers.21.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
237
+ "model.layers.21.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
238
+ "model.layers.21.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
239
+ "model.layers.21.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
240
+ "model.layers.21.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
241
+ "model.layers.21.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
242
+ "model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
243
+ "model.layers.21.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
244
+ "model.layers.21.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
245
+ "model.layers.21.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
246
+ "model.layers.21.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
247
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
248
+ "model.layers.22.mlp.down_proj.bias": "model-00002-of-00002.safetensors",
249
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
250
+ "model.layers.22.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
251
+ "model.layers.22.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
252
+ "model.layers.22.mlp.up_proj.bias": "model-00002-of-00002.safetensors",
253
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
254
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
255
+ "model.layers.22.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
256
+ "model.layers.22.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
257
+ "model.layers.22.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
258
+ "model.layers.22.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
259
+ "model.layers.22.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
260
+ "model.layers.22.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
261
+ "model.layers.22.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
262
+ "model.layers.22.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
263
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
264
+ "model.layers.23.mlp.down_proj.bias": "model-00002-of-00002.safetensors",
265
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
266
+ "model.layers.23.mlp.gate_proj.bias": "model-00002-of-00002.safetensors",
267
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
268
+ "model.layers.23.mlp.up_proj.bias": "model-00002-of-00002.safetensors",
269
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
270
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
271
+ "model.layers.23.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
272
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
273
+ "model.layers.23.self_attn.o_proj.bias": "model-00002-of-00002.safetensors",
274
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
275
+ "model.layers.23.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
276
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
277
+ "model.layers.23.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
278
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
279
+ "model.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors",
280
+ "model.layers.24.mlp.down_proj.bias": "model-00002-of-00002.safetensors",
281
+ "model.layers.24.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
282
+ "model.layers.24.mlp.gate_proj.bias": "model-00002-of-00002.safetensors",
283
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
284
+ "model.layers.24.mlp.up_proj.bias": "model-00002-of-00002.safetensors",
285
+ "model.layers.24.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
286
+ "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
287
+ "model.layers.24.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
288
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
289
+ "model.layers.24.self_attn.o_proj.bias": "model-00002-of-00002.safetensors",
290
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
291
+ "model.layers.24.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
292
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
293
+ "model.layers.24.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
294
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
295
+ "model.layers.25.input_layernorm.weight": "model-00002-of-00002.safetensors",
296
+ "model.layers.25.mlp.down_proj.bias": "model-00002-of-00002.safetensors",
297
+ "model.layers.25.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
298
+ "model.layers.25.mlp.gate_proj.bias": "model-00002-of-00002.safetensors",
299
+ "model.layers.25.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
300
+ "model.layers.25.mlp.up_proj.bias": "model-00002-of-00002.safetensors",
301
+ "model.layers.25.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
302
+ "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
303
+ "model.layers.25.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
304
+ "model.layers.25.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
305
+ "model.layers.25.self_attn.o_proj.bias": "model-00002-of-00002.safetensors",
306
+ "model.layers.25.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
307
+ "model.layers.25.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
308
+ "model.layers.25.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
309
+ "model.layers.25.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
310
+ "model.layers.25.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
311
+ "model.layers.26.input_layernorm.weight": "model-00002-of-00002.safetensors",
312
+ "model.layers.26.mlp.down_proj.bias": "model-00002-of-00002.safetensors",
313
+ "model.layers.26.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
314
+ "model.layers.26.mlp.gate_proj.bias": "model-00002-of-00002.safetensors",
315
+ "model.layers.26.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
316
+ "model.layers.26.mlp.up_proj.bias": "model-00002-of-00002.safetensors",
317
+ "model.layers.26.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
318
+ "model.layers.26.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
319
+ "model.layers.26.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
320
+ "model.layers.26.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
321
+ "model.layers.26.self_attn.o_proj.bias": "model-00002-of-00002.safetensors",
322
+ "model.layers.26.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
323
+ "model.layers.26.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
324
+ "model.layers.26.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
325
+ "model.layers.26.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
326
+ "model.layers.26.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
327
+ "model.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors",
328
+ "model.layers.27.mlp.down_proj.bias": "model-00002-of-00002.safetensors",
329
+ "model.layers.27.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
330
+ "model.layers.27.mlp.gate_proj.bias": "model-00002-of-00002.safetensors",
331
+ "model.layers.27.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
332
+ "model.layers.27.mlp.up_proj.bias": "model-00002-of-00002.safetensors",
333
+ "model.layers.27.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
334
+ "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
335
+ "model.layers.27.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
336
+ "model.layers.27.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
337
+ "model.layers.27.self_attn.o_proj.bias": "model-00002-of-00002.safetensors",
338
+ "model.layers.27.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
339
+ "model.layers.27.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
340
+ "model.layers.27.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
341
+ "model.layers.27.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
342
+ "model.layers.27.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
343
+ "model.layers.28.input_layernorm.weight": "model-00002-of-00002.safetensors",
344
+ "model.layers.28.mlp.down_proj.bias": "model-00002-of-00002.safetensors",
345
+ "model.layers.28.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
346
+ "model.layers.28.mlp.gate_proj.bias": "model-00002-of-00002.safetensors",
347
+ "model.layers.28.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
348
+ "model.layers.28.mlp.up_proj.bias": "model-00002-of-00002.safetensors",
349
+ "model.layers.28.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
350
+ "model.layers.28.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
351
+ "model.layers.28.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
352
+ "model.layers.28.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
353
+ "model.layers.28.self_attn.o_proj.bias": "model-00002-of-00002.safetensors",
354
+ "model.layers.28.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
355
+ "model.layers.28.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
356
+ "model.layers.28.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
357
+ "model.layers.28.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
358
+ "model.layers.28.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
359
+ "model.layers.29.input_layernorm.weight": "model-00002-of-00002.safetensors",
360
+ "model.layers.29.mlp.down_proj.bias": "model-00002-of-00002.safetensors",
361
+ "model.layers.29.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
362
+ "model.layers.29.mlp.gate_proj.bias": "model-00002-of-00002.safetensors",
363
+ "model.layers.29.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
364
+ "model.layers.29.mlp.up_proj.bias": "model-00002-of-00002.safetensors",
365
+ "model.layers.29.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
366
+ "model.layers.29.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
367
+ "model.layers.29.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
368
+ "model.layers.29.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
369
+ "model.layers.29.self_attn.o_proj.bias": "model-00002-of-00002.safetensors",
370
+ "model.layers.29.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
371
+ "model.layers.29.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
372
+ "model.layers.29.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
373
+ "model.layers.29.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
374
+ "model.layers.29.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
375
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
376
+ "model.layers.3.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
377
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
378
+ "model.layers.3.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
379
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
380
+ "model.layers.3.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
381
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
382
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
383
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
384
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
385
+ "model.layers.3.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
386
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
387
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
388
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
389
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
390
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
391
+ "model.layers.30.input_layernorm.weight": "model-00002-of-00002.safetensors",
392
+ "model.layers.30.mlp.down_proj.bias": "model-00002-of-00002.safetensors",
393
+ "model.layers.30.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
394
+ "model.layers.30.mlp.gate_proj.bias": "model-00002-of-00002.safetensors",
395
+ "model.layers.30.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
396
+ "model.layers.30.mlp.up_proj.bias": "model-00002-of-00002.safetensors",
397
+ "model.layers.30.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
398
+ "model.layers.30.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
399
+ "model.layers.30.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
400
+ "model.layers.30.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
401
+ "model.layers.30.self_attn.o_proj.bias": "model-00002-of-00002.safetensors",
402
+ "model.layers.30.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
403
+ "model.layers.30.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
404
+ "model.layers.30.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
405
+ "model.layers.30.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
406
+ "model.layers.30.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
407
+ "model.layers.31.input_layernorm.weight": "model-00002-of-00002.safetensors",
408
+ "model.layers.31.mlp.down_proj.bias": "model-00002-of-00002.safetensors",
409
+ "model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
410
+ "model.layers.31.mlp.gate_proj.bias": "model-00002-of-00002.safetensors",
411
+ "model.layers.31.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
412
+ "model.layers.31.mlp.up_proj.bias": "model-00002-of-00002.safetensors",
413
+ "model.layers.31.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
414
+ "model.layers.31.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
415
+ "model.layers.31.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
416
+ "model.layers.31.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
417
+ "model.layers.31.self_attn.o_proj.bias": "model-00002-of-00002.safetensors",
418
+ "model.layers.31.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
419
+ "model.layers.31.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
420
+ "model.layers.31.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
421
+ "model.layers.31.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
422
+ "model.layers.31.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
423
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
424
+ "model.layers.4.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
425
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
426
+ "model.layers.4.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
427
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
428
+ "model.layers.4.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
429
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
430
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
431
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
432
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
433
+ "model.layers.4.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
434
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
435
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
436
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
437
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
438
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
439
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
440
+ "model.layers.5.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
441
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
442
+ "model.layers.5.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
443
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
444
+ "model.layers.5.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
445
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
446
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
447
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
448
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
449
+ "model.layers.5.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
450
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
451
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
452
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
453
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
454
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
455
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
456
+ "model.layers.6.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
457
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
458
+ "model.layers.6.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
459
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
460
+ "model.layers.6.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
461
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
462
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
463
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
464
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
465
+ "model.layers.6.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
466
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
467
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
468
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
469
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
470
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
471
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
472
+ "model.layers.7.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
473
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
474
+ "model.layers.7.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
475
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
476
+ "model.layers.7.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
477
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
478
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
479
+ "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
480
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
481
+ "model.layers.7.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
482
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
483
+ "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
484
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
485
+ "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
486
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
487
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
488
+ "model.layers.8.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
489
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
490
+ "model.layers.8.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
491
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
492
+ "model.layers.8.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
493
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
494
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
495
+ "model.layers.8.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
496
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
497
+ "model.layers.8.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
498
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
499
+ "model.layers.8.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
500
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
501
+ "model.layers.8.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
502
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
503
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
504
+ "model.layers.9.mlp.down_proj.bias": "model-00001-of-00002.safetensors",
505
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
506
+ "model.layers.9.mlp.gate_proj.bias": "model-00001-of-00002.safetensors",
507
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
508
+ "model.layers.9.mlp.up_proj.bias": "model-00001-of-00002.safetensors",
509
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
510
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
511
+ "model.layers.9.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
512
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
513
+ "model.layers.9.self_attn.o_proj.bias": "model-00001-of-00002.safetensors",
514
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
515
+ "model.layers.9.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
516
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
517
+ "model.layers.9.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
518
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
519
+ "model.norm.weight": "model-00002-of-00002.safetensors"
520
  }
521
  }
modeling_granite.py DELETED
@@ -1,1376 +0,0 @@
1
- import math
2
- import numbers
3
- import warnings
4
- from enum import Enum
5
- from typing import Optional, Tuple, Union
6
-
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- from transformers import DynamicCache, PreTrainedModel
11
- from transformers.activations import get_activation as get_base_activation
12
- from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
13
- from transformers.utils import is_flash_attn_2_available
14
-
15
- from .configuration_granite import GraniteConfig
16
-
17
-
18
- class PositionEmbeddingType(Enum):
19
- learned_absolute = "learned_absolute"
20
- alibi = "alibi"
21
- rope = "rope"
22
-
23
-
24
- class AttentionHeadType(Enum):
25
- mha = "mha"
26
- mqa = "mqa"
27
- gqa = "gqa"
28
-
29
-
30
- if is_flash_attn_2_available():
31
- from flash_attn.bert_padding import IndexFirstAxis, pad_input, unpad_input
32
- from flash_attn.flash_attn_interface import flash_attn_varlen_func
33
-
34
-
35
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
36
- def get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
37
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
38
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
39
- max_seqlen_in_batch = seqlens_in_batch.max().item()
40
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
41
- return indices, cu_seqlens, max_seqlen_in_batch
42
-
43
-
44
- def repeat_key_value(x: torch.Tensor, num_heads: int, num_key_value_heads: int) -> torch.Tensor:
45
- num_groups = num_heads // num_key_value_heads
46
-
47
- # mha
48
- if num_groups == 1:
49
- return x
50
-
51
- # mqa
52
- if num_key_value_heads == 1:
53
- return x.expand(-1, num_heads, -1, -1)
54
-
55
- # gqa
56
- return x.repeat_interleave(num_groups, dim=1)
57
-
58
-
59
- ##################################################
60
- # activation functions
61
-
62
-
63
- _GLU_BASE_MAPPING = {
64
- "geglu": "gelu",
65
- "miglu": "mish",
66
- "mishglu": "mish",
67
- "swiglu": "swish",
68
- }
69
-
70
-
71
- class GLUActivation(nn.Module):
72
- def __init__(self, base_activation: nn.Module) -> None:
73
- super().__init__()
74
- self.base_activation = base_activation
75
-
76
- def forward(self, x: torch.Tensor) -> torch.Tensor:
77
- x = x.chunk(2, dim=-1)
78
- return x[0] * self.base_activation(x[1])
79
-
80
-
81
- def is_glu(name: str) -> bool:
82
- return name.endswith("glu")
83
-
84
-
85
- def get_activation_function(name: str) -> nn.Module:
86
- if is_glu(name):
87
- # for glu and sigmoid_glu, we directly return the pytorch's GLU
88
- if name in ["glu", "sigmoid_glu"]:
89
- activation_function = nn.modules.GLU()
90
- else:
91
- if name in _GLU_BASE_MAPPING:
92
- name = _GLU_BASE_MAPPING[name]
93
- elif name.endswith("_glu"):
94
- name = name.rstrip("_glu")
95
- else:
96
- raise ValueError("invalid activation function")
97
-
98
- base_activation = get_base_activation(name)
99
- activation_function = GLUActivation(base_activation)
100
- else:
101
- activation_function = get_base_activation(name)
102
-
103
- return activation_function
104
-
105
-
106
- ##################################################
107
- # normalization functions
108
-
109
-
110
- class RMSNorm(nn.Module):
111
- def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None:
112
- super().__init__()
113
-
114
- self.weight = nn.Parameter(torch.ones(normalized_shape))
115
- self.eps = eps
116
-
117
- if isinstance(normalized_shape, numbers.Integral):
118
- normalized_shape = (normalized_shape,)
119
- self.normalized_shape = normalized_shape
120
-
121
- def forward(self, input: torch.Tensor) -> torch.Tensor:
122
- input_dtype = input.dtype
123
-
124
- input = input.to(torch.float32)
125
- variance = input.pow(2).mean(-1, keepdim=True)
126
- input = input * torch.rsqrt(variance + self.eps)
127
-
128
- return self.weight * input.to(input_dtype)
129
-
130
- def extra_repr(self) -> str:
131
- return f"{self.normalized_shape}, eps={self.eps}"
132
-
133
- def reset_parameters(self) -> None:
134
- nn.init.ones_(self.weight)
135
-
136
-
137
- _NORMALIZATION_FUNCTIONS = {
138
- "layernorm": nn.LayerNorm,
139
- "rmsnorm": RMSNorm,
140
- }
141
-
142
-
143
- def get_normalization_function(name: str, normalized_shape: int, eps: float = 1e-5) -> nn.Module:
144
- if name in _NORMALIZATION_FUNCTIONS:
145
- return _NORMALIZATION_FUNCTIONS[name](normalized_shape, eps=eps)
146
-
147
- raise ValueError(f"unexpected `normalization_function` {name}")
148
-
149
-
150
- ##################################################
151
- # attention modules
152
-
153
-
154
- class GraniteAttention(nn.Module):
155
- def __init__(self, config: GraniteConfig, causal: bool, layer_idx: Optional[int] = None) -> None:
156
- super().__init__()
157
-
158
- self.causal = causal
159
- self.hidden_size = config.n_embd
160
- self.num_heads = config.n_head
161
- self.num_key_value_heads = config.num_key_value_heads
162
- self.add_bias = config.add_bias
163
-
164
- assert (
165
- self.hidden_size % self.num_heads == 0
166
- ), f"`hidden_size` ({self.hidden_size}) must be divisible by `num_heads` ({self.num_heads})"
167
-
168
- self.head_dim = self.hidden_size // self.num_heads
169
- self.attention_head_type = AttentionHeadType(config.attention_head_type)
170
-
171
- self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
172
- self.scale_attn_weights = config.scale_attn_weights
173
- self.attention_multiplier = config.attention_multiplier
174
-
175
- self.layer_idx = layer_idx
176
- self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
177
- self.scale_attention_softmax_in_fp32 = (
178
- config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
179
- )
180
-
181
- if self.attention_head_type == AttentionHeadType.mha:
182
- if self.num_key_value_heads is None:
183
- self.num_key_value_heads = self.num_heads
184
-
185
- assert (
186
- self.num_heads == self.num_key_value_heads
187
- ), f"{self.__class__.__name__} should have same number of heads for query, keys and values"
188
- elif self.attention_head_type == AttentionHeadType.gqa:
189
- assert (
190
- self.num_key_value_heads is not None
191
- ), "`num_key_value_heads` needs to be specified with GroupedQueryAttention"
192
-
193
- assert self.num_heads % self.num_key_value_heads == 0, (
194
- f"`num_heads` ({self.num_heads}) should be a multiple of `num_key_value_heads` "
195
- f"({self.num_key_value_heads})"
196
- )
197
- elif self.attention_head_type == AttentionHeadType.mqa:
198
- if self.num_key_value_heads is None:
199
- self.num_key_value_heads = 1
200
-
201
- assert self.num_key_value_heads == 1, f"{self.__class__.__name__} should have 1 head for keys and values"
202
- else:
203
- raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})")
204
-
205
- # note that the actual layout is different for the output and depends on whether we are using MHA, MQA or GQA
206
- # (self.hidden_size + 2 * self.num_key_value_heads * self.head_dim) is just the actual number output features
207
- self.c_attn = nn.Linear(
208
- self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=self.add_bias
209
- )
210
- self.c_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.add_bias)
211
-
212
- self.attn_pdrop = config.attn_pdrop
213
- self.resid_pdrop = config.resid_pdrop
214
-
215
- self.attn_dropout = nn.Identity() if self.attn_pdrop == 0 else nn.Dropout(self.attn_pdrop)
216
- self.resid_dropout = nn.Identity() if self.resid_pdrop == 0 else nn.Dropout(self.resid_pdrop)
217
-
218
- def _prepare_qkv_for_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
219
- # ==========================================================================================
220
- # hidden_states -> (batch_size, query_length, num_heads * head_dim)
221
- # ==========================================================================================
222
-
223
- # the output of following is a tuple if using MQA with tensor parallel
224
- hidden_states = self.c_attn(hidden_states)
225
-
226
- # ==========================================================================================
227
- # hidden_states -> (batch_size, query_length, [num_heads + num_key_value_heads * 2] * head_dim)
228
- # ==========================================================================================
229
-
230
- # for MHA, we can get away with doing just 1 transpose which is not true for GQA
231
- if self.attention_head_type == AttentionHeadType.mha:
232
- query, key, value = self._prepare_qkv_for_forward_mha(hidden_states)
233
- elif self.attention_head_type == AttentionHeadType.gqa:
234
- query, key, value = self._prepare_qkv_for_forward_gqa(hidden_states)
235
- elif self.attention_head_type == AttentionHeadType.mqa:
236
- query, key, value = self._prepare_qkv_for_forward_mqa(hidden_states)
237
- else:
238
- raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})")
239
-
240
- # ==========================================================================================
241
- # query -> (batch_size, num_heads, query_length, head_dim)
242
- # key -> (batch_size, num_key_value_heads, query_length, head_dim)
243
- # value -> (batch_size, num_key_value_heads, query_length, head_dim)
244
- # ==========================================================================================
245
-
246
- return query, key, value
247
-
248
- def _prepare_qkv_for_forward_mha(
249
- self, hidden_states: torch.Tensor
250
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
251
- batch_size, query_length = hidden_states.shape[:-1]
252
-
253
- hidden_states = hidden_states.view(batch_size, query_length, self.num_heads, -1)
254
- hidden_states = hidden_states.transpose(1, 2)
255
-
256
- query, key, value = hidden_states.chunk(3, dim=-1)
257
-
258
- return query, key, value
259
-
260
- def _prepare_qkv_for_forward_gqa(
261
- self, hidden_states: torch.Tensor
262
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
263
- batch_size, query_length = hidden_states.shape[:-1]
264
-
265
- hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1)
266
-
267
- query, key, value = hidden_states.split(
268
- ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1
269
- )
270
-
271
- # this needs to be a reshape instead of view sadly
272
- query = query.reshape(batch_size, query_length, -1, self.head_dim)
273
-
274
- query = query.transpose(1, 2)
275
- key = key.transpose(1, 2)
276
- value = value.transpose(1, 2)
277
-
278
- return query, key, value
279
-
280
- def _prepare_qkv_for_forward_mqa(
281
- self, hidden_states: torch.Tensor
282
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
283
- batch_size, query_length = hidden_states.shape[:-1]
284
-
285
- query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1)
286
-
287
- query = query.view(batch_size, query_length, self.num_heads, -1)
288
-
289
- query = query.transpose(1, 2)
290
- key = key.unsqueeze(1)
291
- value = value.unsqueeze(1)
292
-
293
- return query, key, value
294
-
295
- def forward(
296
- self,
297
- hidden_states: torch.Tensor,
298
- past_key_values: Optional[DynamicCache] = None,
299
- attention_mask: Optional[torch.Tensor] = None,
300
- rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
301
- ) -> torch.Tensor:
302
- # ==========================================================================================
303
- # hidden_states -> (batch_size, query_length, num_heads * head_dim)
304
- # ==========================================================================================
305
-
306
- query, key, value = self._prepare_qkv_for_forward(hidden_states)
307
-
308
- # ==========================================================================================
309
- # query -> (batch_size, num_heads, query_length, head_dim)
310
- # key -> (batch_size, num_key_value_heads, query_length, head_dim)
311
- # value -> (batch_size, num_key_value_heads, query_length, head_dim)
312
- # ==========================================================================================
313
-
314
- if self.position_embedding_type == PositionEmbeddingType.rope:
315
- query = apply_rotary_pos_emb(query, rope_cos_sin)
316
- key = apply_rotary_pos_emb(key, rope_cos_sin)
317
-
318
- if past_key_values is not None:
319
- key, value = past_key_values.update(key, value, self.layer_idx)
320
-
321
- # ==========================================================================================
322
- # query -> (batch_size, num_heads, query_length, head_dim)
323
- # key -> (batch_size, num_key_value_heads, key_length, head_dim)
324
- # value -> (batch_size, num_key_value_heads, key_length, head_dim)
325
- # ==========================================================================================
326
-
327
- key = key.transpose(-1, -2)
328
-
329
- dtype = query.dtype
330
- softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
331
-
332
- if self.scale_attn_weights:
333
- if self.attention_multiplier is None:
334
- scale_factor = 1 / self.head_dim**0.5
335
- else:
336
- scale_factor = self.attention_multiplier
337
- else:
338
- scale_factor = 1
339
-
340
- # ==========================================================================================
341
- # query -> (batch_size, num_heads, query_length, head_dim)
342
- # key -> (batch_size, num_key_value_heads, head_dim, key_length)
343
- # value -> (batch_size, num_key_value_heads, key_length, head_dim)
344
- # ==========================================================================================
345
-
346
- batch_size = query.shape[0]
347
- query_length = query.shape[2]
348
- key_length = key.shape[-1]
349
-
350
- key = repeat_key_value(key, self.num_heads, self.num_key_value_heads)
351
- value = repeat_key_value(value, self.num_heads, self.num_key_value_heads)
352
-
353
- # Always copies
354
- query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
355
- # No copy when layer_past is provided.
356
- key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)
357
-
358
- # ==========================================================================================
359
- # query -> (batch_size * num_heads, query_length, head_dim)
360
- # key -> (batch_size * num_heads, head_dim, key_length)
361
- # value -> (batch_size, num_heads, key_length, head_dim)
362
- # ==========================================================================================
363
-
364
- attn_weights = torch.empty(
365
- (batch_size * self.num_heads, query_length, key_length), device=query.device, dtype=query.dtype
366
- )
367
-
368
- attn_weights = torch.baddbmm(attn_weights, query, key, beta=0, alpha=scale_factor).view(
369
- batch_size, self.num_heads, query_length, key_length
370
- )
371
-
372
- # ==========================================================================================
373
- # attn_weights -> (batch_size, num_heads, query_length, key_length)
374
- # ==========================================================================================
375
-
376
- attn_weights = attn_weights.to(softmax_dtype)
377
-
378
- if attention_mask is not None:
379
- attn_weights = attn_weights + attention_mask
380
-
381
- attn_weights = F.softmax(attn_weights, dim=-1).to(dtype)
382
-
383
- attn_weights = self.attn_dropout(attn_weights)
384
-
385
- # ==========================================================================================
386
- # value -> (batch_size, num_heads, key_length, head_dim)
387
- # attn_weights -> (batch_size, num_heads, query_length, key_length)
388
- # ==========================================================================================
389
-
390
- attn_output = torch.matmul(attn_weights, value)
391
-
392
- # ==========================================================================================
393
- # attn_output -> (batch_size, num_heads, query_length, head_dim)
394
- # ==========================================================================================
395
-
396
- attn_output = attn_output.transpose(1, 2)
397
- attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
398
-
399
- # ==========================================================================================
400
- # attn_output -> (batch_size, query_length, num_heads * head_dim)
401
- # ==========================================================================================
402
-
403
- attn_output = self.c_proj(attn_output)
404
- attn_output = self.resid_dropout(attn_output)
405
-
406
- return attn_output
407
-
408
-
409
- class GraniteSDPA(GraniteAttention):
410
- def forward(
411
- self,
412
- hidden_states: torch.Tensor,
413
- past_key_values: Optional[DynamicCache] = None,
414
- attention_mask: Optional[torch.Tensor] = None,
415
- rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
416
- ) -> torch.Tensor:
417
- # ==========================================================================================
418
- # hidden_states -> (batch_size, query_length, num_heads * head_dim)
419
- # ==========================================================================================
420
-
421
- query, key, value = self._prepare_qkv_for_forward(hidden_states)
422
-
423
- # ==========================================================================================
424
- # query -> (batch_size, num_heads, query_length, head_dim)
425
- # key -> (batch_size, num_key_value_heads, query_length, head_dim)
426
- # value -> (batch_size, num_key_value_heads, query_length, head_dim)
427
- # ==========================================================================================
428
-
429
- if self.position_embedding_type == PositionEmbeddingType.rope:
430
- query = apply_rotary_pos_emb(query, rope_cos_sin)
431
- key = apply_rotary_pos_emb(key, rope_cos_sin)
432
-
433
- if past_key_values is not None:
434
- key, value = past_key_values.update(key, value, self.layer_idx)
435
-
436
- # ==========================================================================================
437
- # query -> (batch_size, num_heads, query_length, head_dim)
438
- # key -> (batch_size, num_key_value_heads, key_length, head_dim)
439
- # value -> (batch_size, num_key_value_heads, key_length, head_dim)
440
- # ==========================================================================================
441
-
442
- key = repeat_key_value(key, self.num_heads, self.num_key_value_heads)
443
- value = repeat_key_value(value, self.num_heads, self.num_key_value_heads)
444
-
445
- # ==========================================================================================
446
- # query -> (batch_size, num_heads, query_length, head_dim)
447
- # key -> (batch_size, num_heads, key_length, head_dim)
448
- # value -> (batch_size, num_heads, key_length, head_dim)
449
- # ==========================================================================================
450
-
451
- attn_output = F.scaled_dot_product_attention(
452
- query,
453
- key,
454
- value,
455
- attn_mask=attention_mask,
456
- dropout_p=self.attn_pdrop if self.training else 0,
457
- is_causal=self.causal if attention_mask is None else False,
458
- scale=self.attention_multiplier if self.scale_attn_weights else 1,
459
- )
460
-
461
- # ==========================================================================================
462
- # attn_output -> (batch_size, num_heads, query_length, head_dim)
463
- # ==========================================================================================
464
-
465
- batch_size = attn_output.shape[0]
466
- attn_output = attn_output.transpose(1, 2)
467
- attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
468
-
469
- # ==========================================================================================
470
- # attn_output -> (batch_size, query_length, num_heads * head_dim)
471
- # ==========================================================================================
472
-
473
- attn_output = self.c_proj(attn_output)
474
- attn_output = self.resid_dropout(attn_output)
475
-
476
- return attn_output
477
-
478
-
479
- class GraniteFlashAttention2(GraniteAttention):
480
- def forward(
481
- self,
482
- hidden_states: torch.Tensor,
483
- past_key_values: Optional[DynamicCache] = None,
484
- attention_mask: Optional[torch.Tensor] = None,
485
- rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
486
- ) -> torch.Tensor:
487
- # ==========================================================================================
488
- # hidden_states -> (batch_size, query_length, num_heads * head_dim)
489
- # ==========================================================================================
490
-
491
- query, key, value = self._prepare_qkv_for_forward(hidden_states)
492
-
493
- # ==========================================================================================
494
- # query -> (batch_size, num_heads, query_length, head_dim)
495
- # key -> (batch_size, num_key_value_heads, query_length, head_dim)
496
- # value -> (batch_size, num_key_value_heads, query_length, head_dim)
497
- # ==========================================================================================
498
-
499
- if self.position_embedding_type == PositionEmbeddingType.rope:
500
- query = apply_rotary_pos_emb(query, rope_cos_sin)
501
- key = apply_rotary_pos_emb(key, rope_cos_sin)
502
-
503
- if past_key_values is not None:
504
- key, value = past_key_values.update(key, value, self.layer_idx)
505
-
506
- # ==========================================================================================
507
- # query -> (batch_size, num_heads, query_length, head_dim)
508
- # key -> (batch_size, num_key_value_heads, key_length, head_dim)
509
- # value -> (batch_size, num_key_value_heads, key_length, head_dim)
510
- # ==========================================================================================
511
-
512
- # TODO avoid this extra transpose
513
- query = query.transpose(1, 2)
514
- if self.attention_head_type == AttentionHeadType.mqa:
515
- key = key.squeeze(1).unsqueeze(2)
516
- value = value.squeeze(1).unsqueeze(2)
517
- else:
518
- key = key.transpose(1, 2)
519
- value = value.transpose(1, 2)
520
-
521
- # ==========================================================================================
522
- # query -> (batch_size, query_length, num_heads, head_dim)
523
- # key -> (batch_size, key_length, num_heads, head_dim)
524
- # value -> (batch_size, key_length, num_heads, head_dim)
525
- # ==========================================================================================
526
-
527
- batch_size, query_length = query.shape[:2]
528
- key_length = key.shape[1]
529
- indices_k, cu_seqlens_k, max_seqlen_k = get_unpad_data(attention_mask)
530
-
531
- key = IndexFirstAxis.apply(
532
- key.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k
533
- )
534
- value = IndexFirstAxis.apply(
535
- value.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k
536
- )
537
-
538
- if query_length == key_length:
539
- query = IndexFirstAxis.apply(
540
- query.reshape(batch_size * key_length, self.num_heads, self.head_dim), indices_k
541
- )
542
- cu_seqlens_q = cu_seqlens_k
543
- max_seqlen_q = max_seqlen_k
544
- indices_q = indices_k
545
- elif query_length == 1:
546
- max_seqlen_q = 1
547
- cu_seqlens_q = torch.arange(
548
- batch_size + 1, dtype=torch.int32, device=query.device
549
- ) # There is a memcpy here, that is very bad.
550
- indices_q = cu_seqlens_q[:-1]
551
- query = query.squeeze(1)
552
- else:
553
- # The -q_len: slice assumes left padding.
554
- attention_mask = attention_mask[:, -query_length:]
555
- query, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query, attention_mask)
556
-
557
- # ==========================================================================================
558
- # query -> (total_q, num_heads, head_dim)
559
- # key -> (total_q, num_heads, head_dim)
560
- # value -> (total_q, num_heads, head_dim)
561
- # ==========================================================================================
562
-
563
- attn_output = flash_attn_varlen_func(
564
- query,
565
- key,
566
- value,
567
- cu_seqlens_q=cu_seqlens_q,
568
- cu_seqlens_k=cu_seqlens_k,
569
- max_seqlen_q=max_seqlen_q,
570
- max_seqlen_k=max_seqlen_k,
571
- dropout_p=self.attn_pdrop if self.training else 0,
572
- softmax_scale=self.attention_multiplier if self.scale_attn_weights else 1,
573
- causal=self.causal,
574
- )
575
-
576
- # ==========================================================================================
577
- # attn_output -> (total_q, num_heads, head_dim)
578
- # ==========================================================================================
579
-
580
- attn_output = pad_input(attn_output, indices_q, batch_size, query_length)
581
- attn_output = attn_output.view(batch_size, query_length, -1)
582
-
583
- # ==========================================================================================
584
- # attn_output -> (batch_size, query_length, num_heads * head_dim)
585
- # ==========================================================================================
586
-
587
- attn_output = self.c_proj(attn_output)
588
- attn_output = self.resid_dropout(attn_output)
589
-
590
- return attn_output
591
-
592
-
593
- _ATTENTION_MODULES = {
594
- "eager": GraniteAttention,
595
- "sdpa": GraniteSDPA,
596
- "flash_attention_2": GraniteFlashAttention2,
597
- }
598
-
599
-
600
- def get_attention_module(
601
- config: GraniteConfig, causal: bool, attention_implementation: str, layer_idx: int
602
- ) -> GraniteAttention:
603
- if attention_implementation in _ATTENTION_MODULES:
604
- return _ATTENTION_MODULES[attention_implementation](config, causal=causal, layer_idx=layer_idx)
605
- raise ValueError(f"unexpected `attention_implementation` {attention_implementation}")
606
-
607
-
608
- ##################################################
609
- # position embeddings
610
-
611
-
612
- class Alibi(nn.Module):
613
- def __init__(self, num_heads: int) -> None:
614
- super().__init__()
615
- self.num_heads = num_heads
616
-
617
- self.reset_parameters()
618
-
619
- def forward(
620
- self, attention_mask: torch.Tensor, batch_size: int, key_length: int, device: torch.device, dtype: torch.dtype
621
- ) -> torch.Tensor:
622
- """
623
- Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
624
- relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
625
- `softmax(l+a) = softmax(l)`. Based on
626
- https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
627
- TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
628
-
629
- Args:
630
- attention_mask (torch.Tensor): attention_mask tensor of shape (`batch_size`, `key_length`)
631
- num_heads (int): `num_heads` for the model
632
- batch_size (int): `batch_size`
633
- key_length (int): `key_length`
634
- device (torch.device): device for the tensors
635
- dtype (torch.dtype): dtype to use for the tensors
636
-
637
- Returns:
638
- torch.Tensor: alibi tensor of shape (`batch_size`, `num_heads`, `key_length`)
639
- """
640
-
641
- # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
642
- # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
643
- # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
644
- # => the query_length dimension will then be broadcasted correctly
645
- # This is more or less identical to T5's relative position bias:
646
- # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
647
- if attention_mask is None:
648
- arange_tensor = (
649
- torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, -1, -1)
650
- )
651
- else:
652
- arange_tensor = (attention_mask.cumsum(dim=-1) - 1).masked_fill_(attention_mask == 0, 0).unsqueeze(1)
653
-
654
- alibi = self.slopes.unsqueeze(1) * arange_tensor
655
- return alibi.to(dtype)
656
-
657
- def reset_parameters(self) -> None:
658
- closest_power_of_2 = 2 ** math.floor(math.log2(self.num_heads))
659
- base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32)
660
- powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
661
- slopes = torch.pow(base, powers)
662
-
663
- if closest_power_of_2 != self.num_heads:
664
- extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32)
665
- num_remaining_heads = min(closest_power_of_2, self.num_heads - closest_power_of_2)
666
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32)
667
- slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
668
-
669
- self.register_buffer("slopes", slopes, persistent=False)
670
-
671
-
672
- class RoPE(nn.Module):
673
- def __init__(
674
- self,
675
- head_dim: int,
676
- max_position_embeddings: int = 2048,
677
- base: int = 10000,
678
- ) -> None:
679
- super().__init__()
680
-
681
- self.head_dim = head_dim
682
- self.max_position_embeddings = max_position_embeddings
683
- self.base = base
684
- self.mscale = 1
685
-
686
- self.reset_parameters()
687
-
688
- def forward(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
689
- if seq_len > self.max_seq_len_cached:
690
- self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype)
691
-
692
- cos = self.cos_cached[:seq_len].to(dtype)
693
- sin = self.sin_cached[:seq_len].to(dtype)
694
-
695
- return cos, sin
696
-
697
- def reset_parameters(self) -> None:
698
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
699
- self.register_buffer("inv_freq", inv_freq, persistent=False)
700
-
701
- # Build here to make `torch.jit.trace` work.
702
- self._set_cos_sin_cache(
703
- seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
704
- )
705
-
706
- @torch.no_grad()
707
- def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
708
- self.max_seq_len_cached = seq_len
709
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
710
-
711
- freqs = torch.outer(t, self.inv_freq)
712
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
713
- emb = torch.cat((freqs, freqs), dim=-1)
714
-
715
- self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False)
716
- self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False)
717
-
718
-
719
- def apply_rotary_pos_emb(x: torch.Tensor, cos_sin: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
720
- cos, sin = cos_sin
721
- x = (x * cos) + (_rotate_half(x) * sin)
722
- return x
723
-
724
-
725
- def _rotate_half(x: torch.Tensor) -> torch.Tensor:
726
- x1, x2 = torch.chunk(x, 2, dim=-1)
727
- return torch.cat((-x2, x1), dim=-1)
728
-
729
-
730
- ##################################################
731
- # MLP
732
-
733
-
734
- class GraniteMLP(nn.Module):
735
- def __init__(self, config: GraniteConfig) -> None:
736
- super().__init__()
737
-
738
- hidden_size = config.n_embd
739
- intermediate_size = config.n_inner
740
- activation_function = config.activation_function
741
- add_bias = config.add_bias
742
- residual_dropout = config.resid_pdrop
743
-
744
- self.c_fc = nn.Linear(
745
- hidden_size,
746
- 2 * intermediate_size if is_glu(activation_function) else intermediate_size,
747
- bias=add_bias,
748
- )
749
- self.act = get_activation_function(activation_function)
750
- self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=add_bias)
751
- self.dropout = nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout)
752
-
753
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
754
- hidden_states = self.c_fc(hidden_states)
755
- hidden_states = self.act(hidden_states)
756
- hidden_states = self.c_proj(hidden_states)
757
- hidden_states = self.dropout(hidden_states)
758
- return hidden_states
759
-
760
-
761
- ##################################################
762
- # transformer layer
763
-
764
-
765
- class GraniteBlock(nn.Module):
766
- def __init__(
767
- self,
768
- config: GraniteConfig,
769
- attention_implementation: str,
770
- layer_idx: Optional[int] = None,
771
- ) -> None:
772
- super().__init__()
773
-
774
- hidden_size = config.hidden_size
775
- self.inner_dim = config.n_inner
776
- self.layer_idx = layer_idx
777
-
778
- self.ln_1 = get_normalization_function(
779
- config.normalization_function,
780
- hidden_size,
781
- eps=config.layer_norm_epsilon,
782
- )
783
- self.attn = get_attention_module(config, True, attention_implementation, layer_idx)
784
- self.ln_2 = get_normalization_function(
785
- config.normalization_function,
786
- hidden_size,
787
- eps=config.layer_norm_epsilon,
788
- )
789
- self.mlp = GraniteMLP(config)
790
-
791
- def forward(
792
- self,
793
- hidden_states: torch.Tensor,
794
- past_key_values: Optional[DynamicCache] = None,
795
- attention_mask: Optional[torch.Tensor] = None,
796
- rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
797
- ) -> torch.Tensor:
798
- residual = hidden_states
799
- hidden_states = self.ln_1(hidden_states)
800
-
801
- attn_output = self.attn(
802
- hidden_states,
803
- past_key_values=past_key_values,
804
- attention_mask=attention_mask,
805
- rope_cos_sin=rope_cos_sin,
806
- )
807
-
808
- # residual connection
809
- hidden_states = attn_output + residual
810
-
811
- residual = hidden_states
812
- hidden_states = self.ln_2(hidden_states)
813
-
814
- feed_forward_hidden_states = self.mlp(hidden_states)
815
-
816
- # residual connection
817
- hidden_states = residual + feed_forward_hidden_states
818
-
819
- return hidden_states
820
-
821
-
822
- ##################################################
823
- # model classes
824
-
825
-
826
- class GranitePreTrainedModel(PreTrainedModel):
827
- """
828
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
829
- models.
830
- """
831
-
832
- config_class = GraniteConfig
833
- base_model_prefix = "transformer"
834
- causal = True
835
- _no_split_modules = ["GraniteBlock"]
836
- _skip_keys_device_placement = "past_key_values"
837
- _supports_sdpa = True
838
- _supports_flash_attn_2 = True
839
-
840
- def __init__(self, config: GraniteConfig, *inputs, **kwargs):
841
- super().__init__(config, *inputs, **kwargs)
842
-
843
- self.attention_implementation = self.config._attn_implementation
844
- self._use_eager_attention = self.attention_implementation == "eager"
845
- self._use_sdpa = self.attention_implementation == "sdpa"
846
- self._use_flash_attention_2 = self.attention_implementation == "flash_attention_2"
847
-
848
- self.initializer_range = config.initializer_range
849
-
850
- def _init_weights(self, module: nn.Module) -> None:
851
- if isinstance(module, (nn.LayerNorm, RMSNorm, Alibi, RoPE)):
852
- module.reset_parameters()
853
- elif isinstance(module, nn.Linear):
854
- nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
855
- if module.bias is not None:
856
- module.bias.zero_()
857
- elif isinstance(module, nn.Embedding):
858
- nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
859
- if module.padding_idx is not None:
860
- module.weight[module.padding_idx].zero_()
861
-
862
-
863
- class GraniteModel(GranitePreTrainedModel):
864
- _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
865
- mask_value = None
866
-
867
- def __init__(self, config: GraniteConfig, **kwargs) -> None:
868
- super().__init__(config, **kwargs)
869
-
870
- self.attention_head_type = AttentionHeadType(config.attention_head_type)
871
- self.embed_dim = config.hidden_size
872
- self.num_heads = config.num_attention_heads
873
- self.num_key_value_heads = config.num_key_value_heads
874
-
875
- assert (
876
- self.embed_dim % self.num_heads == 0
877
- ), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})"
878
-
879
- self.head_dim = self.embed_dim // self.num_heads
880
-
881
- self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
882
-
883
- self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop)
884
- self.h = nn.ModuleList(
885
- [GraniteBlock(config, self.attention_implementation, layer_idx=i) for i in range(config.num_hidden_layers)]
886
- )
887
- self.ln_f = get_normalization_function(
888
- config.normalization_function,
889
- self.embed_dim,
890
- eps=config.layer_norm_epsilon,
891
- )
892
-
893
- self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
894
-
895
- if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
896
- self.wpe = nn.Embedding(config.n_positions, self.embed_dim)
897
- elif self.position_embedding_type == PositionEmbeddingType.alibi:
898
- assert not self._use_flash_attention_2, "alibi is not implemented with FlashAttention"
899
-
900
- self.alibi = Alibi(self.num_heads)
901
- elif self.position_embedding_type == PositionEmbeddingType.rope:
902
- self.rope = RoPE(self.head_dim, max_position_embeddings=config.n_positions, base=config.rope_theta)
903
- else:
904
- raise NotImplementedError()
905
-
906
- # Initialize weights and apply final processing
907
- self.post_init()
908
-
909
- def get_input_embeddings(self) -> nn.Embedding:
910
- return self.wte
911
-
912
- def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
913
- self.wte = new_embeddings
914
-
915
- def forward(
916
- self,
917
- input_ids: Optional[torch.Tensor] = None,
918
- past_key_values: Optional[DynamicCache] = None,
919
- attention_mask: Optional[torch.Tensor] = None,
920
- token_type_ids: Optional[torch.Tensor] = None,
921
- position_ids: Optional[torch.Tensor] = None,
922
- inputs_embeds: Optional[torch.Tensor] = None,
923
- use_cache: Optional[bool] = None,
924
- output_hidden_states: Optional[bool] = None,
925
- return_dict: Optional[bool] = None,
926
- ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
927
- (
928
- output_hidden_states,
929
- use_cache,
930
- return_dict,
931
- input_shape,
932
- hidden_states,
933
- attention_mask,
934
- position_ids,
935
- rope_cos_sin,
936
- past_key_values,
937
- ) = self._prepare_a_bunch_of_stuff(
938
- input_ids=input_ids,
939
- past_key_values=past_key_values,
940
- attention_mask=attention_mask,
941
- token_type_ids=token_type_ids,
942
- position_ids=position_ids,
943
- inputs_embeds=inputs_embeds,
944
- use_cache=use_cache,
945
- output_hidden_states=output_hidden_states,
946
- return_dict=return_dict,
947
- )
948
-
949
- # ==========================================================================================
950
- # flash:
951
- # attention_mask -> (batch_size, key_length)
952
- # else:
953
- # attention_mask -> (batch_size, 1, query_length, key_length)
954
- # ==========================================================================================
955
-
956
- output_shape = input_shape + (hidden_states.size(-1),)
957
-
958
- past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values
959
- all_hidden_states = () if output_hidden_states else None
960
- for block in self.h:
961
- if output_hidden_states:
962
- all_hidden_states += (hidden_states,)
963
-
964
- hidden_states = block(
965
- hidden_states,
966
- past_key_values=past_key_values,
967
- attention_mask=attention_mask,
968
- rope_cos_sin=rope_cos_sin,
969
- )
970
-
971
- hidden_states = self.ln_f(hidden_states)
972
-
973
- hidden_states = hidden_states.view(output_shape)
974
- # Add last hidden state
975
- if output_hidden_states:
976
- all_hidden_states += (hidden_states,)
977
-
978
- if not return_dict:
979
- return tuple(v for v in [hidden_states, past_key_values, all_hidden_states] if v is not None)
980
-
981
- return BaseModelOutputWithPastAndCrossAttentions(
982
- last_hidden_state=hidden_states,
983
- past_key_values=past_key_values,
984
- hidden_states=all_hidden_states,
985
- )
986
-
987
- def _get_position_ids(
988
- self, attention_mask: torch.Tensor, past_length: int, query_length: int, key_length: int, device: torch.device
989
- ) -> torch.Tensor:
990
- if attention_mask is not None and len(attention_mask.shape) == 2:
991
- # create position_ids on the fly for batch generation
992
- position_ids = attention_mask.long().cumsum(-1) - 1
993
- position_ids.masked_fill_(attention_mask == 0, 0)
994
- if past_length > 0:
995
- position_ids = position_ids[:, past_length:key_length:]
996
- else:
997
- position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device)
998
- position_ids = position_ids.unsqueeze(0).view(-1, query_length)
999
-
1000
- return position_ids
1001
-
1002
- def _get_alibi_bias(
1003
- self,
1004
- attention_mask: torch.Tensor,
1005
- batch_size: int,
1006
- query_length: int,
1007
- key_length: int,
1008
- device: torch.device,
1009
- dtype: torch.dtype,
1010
- ) -> torch.Tensor:
1011
- if self.position_embedding_type != PositionEmbeddingType.alibi:
1012
- return None
1013
-
1014
- alibi_bias = self.alibi(attention_mask, batch_size, key_length, device, dtype)
1015
-
1016
- # ==========================================================================================
1017
- # alibi_bias -> (batch_size, num_heads, key_length)
1018
- # ==========================================================================================
1019
-
1020
- alibi_bias = alibi_bias.unsqueeze(2)
1021
- if query_length != 1:
1022
- alibi_bias = alibi_bias.expand(-1, -1, query_length, -1)
1023
-
1024
- # ==========================================================================================
1025
- # alibi_bias -> (batch_size, num_heads, query_length, key_length)
1026
- # ==========================================================================================
1027
-
1028
- return alibi_bias
1029
-
1030
- def _get_rope_cos_sin(
1031
- self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype, device: torch.device
1032
- ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
1033
- if self.position_embedding_type == PositionEmbeddingType.rope:
1034
- cos, sin = self.rope(key_length, dtype=dtype, device=device)
1035
- cos = cos[position_ids].unsqueeze(1)
1036
- sin = sin[position_ids].unsqueeze(1)
1037
- return cos, sin
1038
-
1039
- def _prepare_causal_attention_mask(
1040
- self, attention_mask: torch.Tensor, batch_size: int, query_length: int, key_length: int, device: torch.device
1041
- ) -> torch.Tensor:
1042
- past_length = key_length - query_length
1043
-
1044
- # ==========================================================================================
1045
- # attention_mask -> (batch_size, key_length)
1046
- # ==========================================================================================
1047
-
1048
- if query_length > 1:
1049
- # (query_length, key_length)
1050
- causal_mask = torch.empty((query_length, key_length), dtype=torch.bool, device=device)
1051
- causal_mask[:, past_length:] = torch.tril(
1052
- torch.ones(query_length, query_length, dtype=torch.bool, device=device)
1053
- )
1054
-
1055
- if past_length > 0:
1056
- causal_mask[:, :past_length] = True
1057
-
1058
- # (query_length, key_length) -> (1, query_length, key_length)
1059
- causal_mask = causal_mask.unsqueeze(0)
1060
-
1061
- if attention_mask is None:
1062
- # (1, query_length, key_length) -> (batch_size, query_length, key_length)
1063
- causal_mask = causal_mask.expand(batch_size, -1, -1)
1064
- else:
1065
- # (1, query_length, key_length) & (batch_size, 1, key_length) -> (batch_size, query_length, key_length)
1066
- causal_mask = causal_mask & attention_mask.unsqueeze(1).to(torch.bool)
1067
- else:
1068
- if attention_mask is None:
1069
- # (batch_size, query_length, key_length)
1070
- causal_mask = torch.ones(batch_size, query_length, key_length, dtype=torch.bool, device=device)
1071
- else:
1072
- # (batch_size, query_length, key_length)
1073
- causal_mask = attention_mask.unsqueeze(1).to(dtype=torch.bool, device=device)
1074
-
1075
- # ==========================================================================================
1076
- # attention_mask -> (batch_size, query_length, key_length)
1077
- # ==========================================================================================
1078
-
1079
- causal_mask = causal_mask.unsqueeze(1)
1080
-
1081
- # ==========================================================================================
1082
- # attention_mask -> (batch_size, 1, query_length, key_length)
1083
- # ==========================================================================================
1084
-
1085
- return causal_mask
1086
-
1087
- def _get_initial_hidden_state(
1088
- self,
1089
- input_ids: torch.Tensor,
1090
- inputs_embeds: torch.Tensor,
1091
- position_ids: torch.Tensor,
1092
- token_type_ids: torch.Tensor,
1093
- ) -> torch.Tensor:
1094
- if inputs_embeds is None:
1095
- inputs_embeds = self.wte(input_ids)
1096
-
1097
- if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
1098
- inputs_embeds = inputs_embeds + self.wpe(position_ids)
1099
-
1100
- if token_type_ids is not None:
1101
- inputs_embeds = inputs_embeds + self.wte(token_type_ids)
1102
-
1103
- inputs_embeds = self.drop(inputs_embeds)
1104
-
1105
- return inputs_embeds
1106
-
1107
- def _prepare_a_bunch_of_stuff(
1108
- self,
1109
- input_ids: torch.Tensor,
1110
- past_key_values: DynamicCache,
1111
- attention_mask: torch.Tensor,
1112
- token_type_ids: torch.Tensor,
1113
- position_ids: torch.Tensor,
1114
- inputs_embeds: torch.Tensor,
1115
- use_cache: bool,
1116
- output_hidden_states: bool,
1117
- return_dict: bool,
1118
- ) -> Tuple[
1119
- bool,
1120
- bool,
1121
- bool,
1122
- torch.Size,
1123
- torch.Tensor,
1124
- torch.Tensor,
1125
- torch.Tensor,
1126
- Optional[Tuple[torch.Tensor, torch.Tensor]],
1127
- DynamicCache,
1128
- ]:
1129
- output_hidden_states = (
1130
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1131
- )
1132
-
1133
- use_cache = self.config.use_cache if use_cache is None else use_cache
1134
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1135
-
1136
- if input_ids is not None and inputs_embeds is not None:
1137
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1138
- elif input_ids is not None:
1139
- input_shape = input_ids.size()
1140
- elif inputs_embeds is not None:
1141
- # TODO special handling for padding free transformer needed here if we support inputs_embeds argument
1142
- input_shape = inputs_embeds.size()[:-1]
1143
- else:
1144
- raise ValueError("You have to specify either input_ids or inputs_embeds")
1145
-
1146
- batch_size = input_shape[0]
1147
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1148
-
1149
- if self.position_embedding_type == PositionEmbeddingType.alibi:
1150
- if position_ids is not None:
1151
- warnings.warn("`position_ids` have no functionality with Alibi.", FutureWarning)
1152
-
1153
- if token_type_ids is not None:
1154
- token_type_ids = token_type_ids.view(-1, input_shape[-1])
1155
-
1156
- # ==========================================================================================
1157
- # input_ids -> (batch_size, query_length)
1158
- # attention_mask -> None or (batch_size, key_length)
1159
- # position_ids -> None or (batch_size, key_length)
1160
- # ==========================================================================================
1161
-
1162
- past_length = 0 if past_key_values is None else past_key_values.get_seq_length()
1163
- query_length = input_shape[-1]
1164
- key_length = past_length + query_length
1165
-
1166
- if position_ids is None:
1167
- position_ids = self._get_position_ids(attention_mask, past_length, query_length, key_length, device)
1168
-
1169
- # ==========================================================================================
1170
- # input_ids -> (batch_size, query_length)
1171
- # attention_mask -> None or (batch_size, key_length)
1172
- # position_ids -> (batch_size, query_length)
1173
- # ==========================================================================================
1174
-
1175
- hidden_states = self._get_initial_hidden_state(input_ids, inputs_embeds, position_ids, token_type_ids)
1176
-
1177
- # ==========================================================================================
1178
- # hidden_states -> (batch_size, query_length, num_heads * head_dim)
1179
- # ==========================================================================================
1180
-
1181
- alibi_bias = self._get_alibi_bias(
1182
- attention_mask, batch_size, query_length, key_length, device, hidden_states.dtype
1183
- )
1184
-
1185
- # ==========================================================================================
1186
- # alibi_bias -> (batch_size, num_heads, query_length, key_length)
1187
- # ==========================================================================================
1188
-
1189
- rope_cos_sin = self._get_rope_cos_sin(
1190
- key_length, position_ids, dtype=hidden_states.dtype, device=hidden_states.device
1191
- )
1192
-
1193
- # ==========================================================================================
1194
- # rope_cos_sin -> 2 * (key_length, head_dim)
1195
- # ==========================================================================================
1196
-
1197
- # prepare causal mask only if not using flash attention
1198
- if self._use_flash_attention_2:
1199
- if attention_mask is None:
1200
- attention_mask = torch.ones_like(input_ids)
1201
- elif self._use_sdpa:
1202
- # we use the causal/non-causal argument of SDPA for attention in this case
1203
- if attention_mask is not None:
1204
- attention_mask = self._prepare_causal_attention_mask(
1205
- attention_mask, batch_size, query_length, key_length, device
1206
- )
1207
-
1208
- attention_mask = torch.where(
1209
- attention_mask,
1210
- ~attention_mask if alibi_bias is None else alibi_bias,
1211
- self._get_mask_value(attention_mask.device, hidden_states.dtype),
1212
- )
1213
- else:
1214
- attention_mask = self._prepare_causal_attention_mask(
1215
- attention_mask, batch_size, query_length, key_length, device
1216
- )
1217
-
1218
- attention_mask = torch.where(
1219
- attention_mask,
1220
- ~attention_mask if alibi_bias is None else alibi_bias,
1221
- self._get_mask_value(attention_mask.device, hidden_states.dtype),
1222
- )
1223
-
1224
- return (
1225
- output_hidden_states,
1226
- use_cache,
1227
- return_dict,
1228
- input_shape,
1229
- hidden_states,
1230
- attention_mask,
1231
- position_ids,
1232
- rope_cos_sin,
1233
- past_key_values,
1234
- )
1235
-
1236
- def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
1237
- # torch.where expects a tensor. We use a cache to avoid recreating it every time.
1238
- if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
1239
- self.mask_value = torch.full([], torch.finfo(torch.float16).min, dtype=dtype, device=device)
1240
- return self.mask_value
1241
-
1242
-
1243
- class GraniteForCausalLM(GranitePreTrainedModel):
1244
- _keys_to_ignore_on_load_missing = ["lm_head.weight"]
1245
-
1246
- def __init__(self, config: GraniteConfig, **kwargs) -> None:
1247
- super().__init__(config, **kwargs)
1248
- self.transformer = GraniteModel(config, **kwargs)
1249
- self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1250
-
1251
- # Initialize weights and apply final processing
1252
- self.post_init()
1253
-
1254
- def get_input_embeddings(self) -> nn.Embedding:
1255
- return self.transformer.wte
1256
-
1257
- def set_input_embeddings(self, value: nn.Embedding) -> None:
1258
- self.transformer.wte = value
1259
-
1260
- def get_output_embeddings(self) -> nn.Linear:
1261
- return self.lm_head
1262
-
1263
- def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
1264
- self.lm_head = new_embeddings
1265
-
1266
- # FIXME typing
1267
- def prepare_inputs_for_generation(
1268
- self,
1269
- input_ids: torch.Tensor,
1270
- past_key_values: Optional[DynamicCache] = None,
1271
- inputs_embeds: Optional[torch.Tensor] = None,
1272
- **kwargs,
1273
- ) -> dict:
1274
- token_type_ids = kwargs.get("token_type_ids", None)
1275
- # Omit tokens covered by past_key_values
1276
- if past_key_values:
1277
- past_length = past_key_values.get_seq_length()
1278
-
1279
- # Some generation methods already pass only the last input ID
1280
- if input_ids.shape[1] > past_length:
1281
- remove_prefix_length = past_length
1282
- else:
1283
- # Default to old behavior: keep only final ID
1284
- remove_prefix_length = input_ids.shape[1] - 1
1285
-
1286
- input_ids = input_ids[:, remove_prefix_length:]
1287
- if token_type_ids is not None:
1288
- token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
1289
-
1290
- attention_mask = kwargs.get("attention_mask", None)
1291
- position_ids = kwargs.get("position_ids", None)
1292
-
1293
- if attention_mask is not None and position_ids is None:
1294
- # create position_ids on the fly for batch generation
1295
- position_ids = attention_mask.long().cumsum(-1) - 1
1296
- position_ids.masked_fill_(attention_mask == 0, 0)
1297
- if past_key_values:
1298
- position_ids = position_ids[:, -input_ids.shape[1] :]
1299
- else:
1300
- position_ids = None
1301
-
1302
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1303
- if inputs_embeds is not None and past_key_values is None:
1304
- model_inputs = {"inputs_embeds": inputs_embeds}
1305
- else:
1306
- model_inputs = {"input_ids": input_ids}
1307
-
1308
- model_inputs.update(
1309
- {
1310
- "past_key_values": past_key_values,
1311
- "use_cache": kwargs.get("use_cache"),
1312
- "position_ids": position_ids,
1313
- "attention_mask": attention_mask,
1314
- "token_type_ids": token_type_ids,
1315
- }
1316
- )
1317
- return model_inputs
1318
-
1319
- def forward(
1320
- self,
1321
- input_ids: Optional[Union[torch.Tensor]] = None,
1322
- past_key_values: Optional[DynamicCache] = None,
1323
- attention_mask: Optional[torch.Tensor] = None,
1324
- token_type_ids: Optional[Union[torch.Tensor]] = None,
1325
- position_ids: Optional[Union[torch.Tensor]] = None,
1326
- inputs_embeds: Optional[Union[torch.Tensor]] = None,
1327
- labels: Optional[Union[torch.Tensor]] = None,
1328
- use_cache: Optional[bool] = None,
1329
- output_attentions: Optional[bool] = None,
1330
- output_hidden_states: Optional[bool] = None,
1331
- return_dict: Optional[bool] = None,
1332
- ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1333
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1334
-
1335
- # ==========================================================================================
1336
- # input_ids -> (batch_size, query_length)
1337
- # attention_mask -> None or (batch_size, key_length)
1338
- # position_ids -> None or (batch_size, key_length)
1339
- # ==========================================================================================
1340
-
1341
- transformer_outputs = self.transformer(
1342
- input_ids,
1343
- past_key_values=past_key_values,
1344
- attention_mask=attention_mask,
1345
- token_type_ids=token_type_ids,
1346
- position_ids=position_ids,
1347
- inputs_embeds=inputs_embeds,
1348
- use_cache=use_cache,
1349
- output_hidden_states=output_hidden_states,
1350
- return_dict=return_dict,
1351
- )
1352
- hidden_states = transformer_outputs[0]
1353
-
1354
- lm_logits = self.lm_head(hidden_states)
1355
-
1356
- loss = None
1357
- # Shift so that tokens < n predict n
1358
- if labels is not None:
1359
- shift_logits = lm_logits[..., :-1, :].contiguous()
1360
- shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
1361
-
1362
- # Flatten the tokens
1363
- loss_fct = nn.CrossEntropyLoss()
1364
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1365
-
1366
- if not return_dict:
1367
- output = (lm_logits,) + transformer_outputs[1:]
1368
- return ((loss,) + output) if loss is not None else output
1369
-
1370
- return CausalLMOutputWithCrossAttentions(
1371
- loss=loss,
1372
- logits=lm_logits,
1373
- past_key_values=transformer_outputs.past_key_values,
1374
- hidden_states=transformer_outputs.hidden_states,
1375
- attentions=transformer_outputs.attentions,
1376
- )