mayank-mishra commited on
Commit
448e236
1 Parent(s): 842533b

upload model

Browse files
config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_function": "swiglu",
3
+ "add_bias": true,
4
+ "apply_residual_connection_post_layernorm": false,
5
+ "architectures": [
6
+ "GraniteForCausalLM"
7
+ ],
8
+ "attention_head_type": "mha",
9
+ "attention_multiplier": null,
10
+ "attention_softmax_in_fp32": true,
11
+ "attn_pdrop": 0.1,
12
+ "auto_map": {
13
+ "AutoConfig": "configuration_granite.GraniteConfig",
14
+ "AutoModel": "modeling_granite.GraniteModel",
15
+ "AutoModelForCausalLM": "modeling_granite.GraniteForCausalLM"
16
+ },
17
+ "bos_token_id": 0,
18
+ "embd_pdrop": 0.1,
19
+ "eos_token_id": 0,
20
+ "initializer_range": 0.02,
21
+ "layer_norm_epsilon": 1e-05,
22
+ "model_type": "granite",
23
+ "n_embd": 2560,
24
+ "n_head": 32,
25
+ "n_inner": 10240,
26
+ "n_layer": 32,
27
+ "n_positions": 2048,
28
+ "normalization_function": "rmsnorm",
29
+ "num_key_value_heads": 32,
30
+ "pad_token_id": 0,
31
+ "position_embedding_type": "rope",
32
+ "resid_pdrop": 0.1,
33
+ "rope_theta": 10000,
34
+ "scale_attention_softmax_in_fp32": true,
35
+ "scale_attn_weights": true,
36
+ "torch_dtype": "float32",
37
+ "transformers_version": "4.38.1",
38
+ "use_cache": true,
39
+ "vocab_size": 49152
40
+ }
configuration_granite.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class GraniteConfig(PretrainedConfig):
5
+ model_type = "granite"
6
+
7
+ keys_to_ignore_at_inference = ["past_key_values"]
8
+ attribute_map = {
9
+ "hidden_size": "n_embd",
10
+ "max_position_embeddings": "n_positions",
11
+ "num_attention_heads": "n_head",
12
+ "num_hidden_layers": "n_layer",
13
+ }
14
+
15
+ def __init__(
16
+ self,
17
+ vocab_size: int = 50257,
18
+ n_positions: int = 1024,
19
+ n_embd: int = 768,
20
+ n_layer: int = 12,
21
+ n_head: int = 12,
22
+ num_key_value_heads: int = None,
23
+ n_inner: int = None,
24
+ activation_function: str = "gelu_pytorch_tanh",
25
+ attention_head_type: str = "mqa",
26
+ resid_pdrop: float = 0.1,
27
+ embd_pdrop: float = 0.1,
28
+ attn_pdrop: float = 0.1,
29
+ normalization_function: str = "layernorm",
30
+ layer_norm_epsilon: float = 1e-5,
31
+ initializer_range: float = 0.02,
32
+ scale_attn_weights: bool = True,
33
+ attention_multiplier: float = None,
34
+ use_cache: bool = True,
35
+ bos_token_id: int = 50256,
36
+ eos_token_id: int = 50256,
37
+ pad_token_id: int = 50256,
38
+ attention_softmax_in_fp32: bool = True,
39
+ scale_attention_softmax_in_fp32: bool = True,
40
+ add_bias: bool = True,
41
+ position_embedding_type: str = "learned_absolute",
42
+ rope_theta: int = 10000,
43
+ **kwargs,
44
+ ) -> None:
45
+ self.vocab_size = vocab_size
46
+ self.n_positions = n_positions
47
+ self.n_embd = n_embd
48
+ self.n_layer = n_layer
49
+ self.n_head = n_head
50
+ self.num_key_value_heads = num_key_value_heads
51
+ self.n_inner = 4 * n_embd if n_inner is None else n_inner
52
+ self.activation_function = activation_function
53
+ self.attention_head_type = attention_head_type
54
+ self.resid_pdrop = resid_pdrop
55
+ self.embd_pdrop = embd_pdrop
56
+ self.attn_pdrop = attn_pdrop
57
+ self.normalization_function = normalization_function
58
+ self.layer_norm_epsilon = layer_norm_epsilon
59
+ self.initializer_range = initializer_range
60
+ self.scale_attn_weights = scale_attn_weights
61
+ self.attention_multiplier = attention_multiplier
62
+ self.use_cache = use_cache
63
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
64
+ self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
65
+ self.position_embedding_type = position_embedding_type
66
+ self.add_bias = add_bias
67
+ self.rope_theta = rope_theta
68
+
69
+ if self.attention_multiplier is not None:
70
+ assert self.scale_attn_weights
71
+
72
+ # for compatibility with some features
73
+ self.multi_query = attention_head_type == "mqa"
74
+
75
+ if attention_head_type == "mha":
76
+ if self.num_key_value_heads is None:
77
+ self.num_key_value_heads = self.n_head
78
+
79
+ assert (
80
+ self.n_head == self.num_key_value_heads
81
+ ), "MultiHeadAttention should have same number of heads for query, keys and values"
82
+ elif attention_head_type == "mqa":
83
+ if self.num_key_value_heads is None:
84
+ self.num_key_value_heads = 1
85
+
86
+ assert self.num_key_value_heads == 1, "MultiQueryAttention should have 1 head for keys and values"
87
+ elif attention_head_type == "gqa":
88
+ assert (
89
+ self.num_key_value_heads is not None
90
+ ), "`num_key_value_heads` needs to be specified with GroupedQueryAttention"
91
+
92
+ assert (
93
+ self.n_head % self.num_key_value_heads == 0
94
+ ), "GroupedQueryAttention should have more than 1 head for keys and values"
95
+ else:
96
+ raise ValueError(f"unexpected attention_head_type ({attention_head_type})")
97
+
98
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 0,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.38.1"
7
+ }
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:612678f5629dbc658d29e393ef01b0ddc26c0ddcc7eb15e98bc1145c2f66c20b
3
+ size 4804086856
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:085bbe4bfa511e3b8cd345c2f65ca44075894cb967e2748bf2a5780b195b10f9
3
+ size 4930111520
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c05dc3ccf273010235c01cadae113ec29a366a467cd78a5f2c46e713e2bedf3
3
+ size 4195850696
model.safetensors.index.json ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 13930014720
4
+ },
5
+ "weight_map": {
6
+ "transformer.h.0.attn.c_attn.bias": "model-00001-of-00003.safetensors",
7
+ "transformer.h.0.attn.c_attn.weight": "model-00001-of-00003.safetensors",
8
+ "transformer.h.0.attn.c_proj.bias": "model-00001-of-00003.safetensors",
9
+ "transformer.h.0.attn.c_proj.weight": "model-00001-of-00003.safetensors",
10
+ "transformer.h.0.ln_1.weight": "model-00001-of-00003.safetensors",
11
+ "transformer.h.0.ln_2.weight": "model-00001-of-00003.safetensors",
12
+ "transformer.h.0.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
13
+ "transformer.h.0.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
14
+ "transformer.h.0.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
15
+ "transformer.h.0.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
16
+ "transformer.h.1.attn.c_attn.bias": "model-00001-of-00003.safetensors",
17
+ "transformer.h.1.attn.c_attn.weight": "model-00001-of-00003.safetensors",
18
+ "transformer.h.1.attn.c_proj.bias": "model-00001-of-00003.safetensors",
19
+ "transformer.h.1.attn.c_proj.weight": "model-00001-of-00003.safetensors",
20
+ "transformer.h.1.ln_1.weight": "model-00001-of-00003.safetensors",
21
+ "transformer.h.1.ln_2.weight": "model-00001-of-00003.safetensors",
22
+ "transformer.h.1.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
23
+ "transformer.h.1.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
24
+ "transformer.h.1.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
25
+ "transformer.h.1.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
26
+ "transformer.h.10.attn.c_attn.bias": "model-00001-of-00003.safetensors",
27
+ "transformer.h.10.attn.c_attn.weight": "model-00001-of-00003.safetensors",
28
+ "transformer.h.10.attn.c_proj.bias": "model-00001-of-00003.safetensors",
29
+ "transformer.h.10.attn.c_proj.weight": "model-00001-of-00003.safetensors",
30
+ "transformer.h.10.ln_1.weight": "model-00001-of-00003.safetensors",
31
+ "transformer.h.10.ln_2.weight": "model-00001-of-00003.safetensors",
32
+ "transformer.h.10.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
33
+ "transformer.h.10.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
34
+ "transformer.h.10.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
35
+ "transformer.h.10.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
36
+ "transformer.h.11.attn.c_attn.bias": "model-00002-of-00003.safetensors",
37
+ "transformer.h.11.attn.c_attn.weight": "model-00002-of-00003.safetensors",
38
+ "transformer.h.11.attn.c_proj.bias": "model-00002-of-00003.safetensors",
39
+ "transformer.h.11.attn.c_proj.weight": "model-00002-of-00003.safetensors",
40
+ "transformer.h.11.ln_1.weight": "model-00002-of-00003.safetensors",
41
+ "transformer.h.11.ln_2.weight": "model-00002-of-00003.safetensors",
42
+ "transformer.h.11.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
43
+ "transformer.h.11.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
44
+ "transformer.h.11.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
45
+ "transformer.h.11.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
46
+ "transformer.h.12.attn.c_attn.bias": "model-00002-of-00003.safetensors",
47
+ "transformer.h.12.attn.c_attn.weight": "model-00002-of-00003.safetensors",
48
+ "transformer.h.12.attn.c_proj.bias": "model-00002-of-00003.safetensors",
49
+ "transformer.h.12.attn.c_proj.weight": "model-00002-of-00003.safetensors",
50
+ "transformer.h.12.ln_1.weight": "model-00002-of-00003.safetensors",
51
+ "transformer.h.12.ln_2.weight": "model-00002-of-00003.safetensors",
52
+ "transformer.h.12.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
53
+ "transformer.h.12.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
54
+ "transformer.h.12.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
55
+ "transformer.h.12.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
56
+ "transformer.h.13.attn.c_attn.bias": "model-00002-of-00003.safetensors",
57
+ "transformer.h.13.attn.c_attn.weight": "model-00002-of-00003.safetensors",
58
+ "transformer.h.13.attn.c_proj.bias": "model-00002-of-00003.safetensors",
59
+ "transformer.h.13.attn.c_proj.weight": "model-00002-of-00003.safetensors",
60
+ "transformer.h.13.ln_1.weight": "model-00002-of-00003.safetensors",
61
+ "transformer.h.13.ln_2.weight": "model-00002-of-00003.safetensors",
62
+ "transformer.h.13.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
63
+ "transformer.h.13.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
64
+ "transformer.h.13.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
65
+ "transformer.h.13.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
66
+ "transformer.h.14.attn.c_attn.bias": "model-00002-of-00003.safetensors",
67
+ "transformer.h.14.attn.c_attn.weight": "model-00002-of-00003.safetensors",
68
+ "transformer.h.14.attn.c_proj.bias": "model-00002-of-00003.safetensors",
69
+ "transformer.h.14.attn.c_proj.weight": "model-00002-of-00003.safetensors",
70
+ "transformer.h.14.ln_1.weight": "model-00002-of-00003.safetensors",
71
+ "transformer.h.14.ln_2.weight": "model-00002-of-00003.safetensors",
72
+ "transformer.h.14.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
73
+ "transformer.h.14.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
74
+ "transformer.h.14.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
75
+ "transformer.h.14.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
76
+ "transformer.h.15.attn.c_attn.bias": "model-00002-of-00003.safetensors",
77
+ "transformer.h.15.attn.c_attn.weight": "model-00002-of-00003.safetensors",
78
+ "transformer.h.15.attn.c_proj.bias": "model-00002-of-00003.safetensors",
79
+ "transformer.h.15.attn.c_proj.weight": "model-00002-of-00003.safetensors",
80
+ "transformer.h.15.ln_1.weight": "model-00002-of-00003.safetensors",
81
+ "transformer.h.15.ln_2.weight": "model-00002-of-00003.safetensors",
82
+ "transformer.h.15.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
83
+ "transformer.h.15.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
84
+ "transformer.h.15.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
85
+ "transformer.h.15.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
86
+ "transformer.h.16.attn.c_attn.bias": "model-00002-of-00003.safetensors",
87
+ "transformer.h.16.attn.c_attn.weight": "model-00002-of-00003.safetensors",
88
+ "transformer.h.16.attn.c_proj.bias": "model-00002-of-00003.safetensors",
89
+ "transformer.h.16.attn.c_proj.weight": "model-00002-of-00003.safetensors",
90
+ "transformer.h.16.ln_1.weight": "model-00002-of-00003.safetensors",
91
+ "transformer.h.16.ln_2.weight": "model-00002-of-00003.safetensors",
92
+ "transformer.h.16.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
93
+ "transformer.h.16.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
94
+ "transformer.h.16.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
95
+ "transformer.h.16.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
96
+ "transformer.h.17.attn.c_attn.bias": "model-00002-of-00003.safetensors",
97
+ "transformer.h.17.attn.c_attn.weight": "model-00002-of-00003.safetensors",
98
+ "transformer.h.17.attn.c_proj.bias": "model-00002-of-00003.safetensors",
99
+ "transformer.h.17.attn.c_proj.weight": "model-00002-of-00003.safetensors",
100
+ "transformer.h.17.ln_1.weight": "model-00002-of-00003.safetensors",
101
+ "transformer.h.17.ln_2.weight": "model-00002-of-00003.safetensors",
102
+ "transformer.h.17.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
103
+ "transformer.h.17.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
104
+ "transformer.h.17.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
105
+ "transformer.h.17.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
106
+ "transformer.h.18.attn.c_attn.bias": "model-00002-of-00003.safetensors",
107
+ "transformer.h.18.attn.c_attn.weight": "model-00002-of-00003.safetensors",
108
+ "transformer.h.18.attn.c_proj.bias": "model-00002-of-00003.safetensors",
109
+ "transformer.h.18.attn.c_proj.weight": "model-00002-of-00003.safetensors",
110
+ "transformer.h.18.ln_1.weight": "model-00002-of-00003.safetensors",
111
+ "transformer.h.18.ln_2.weight": "model-00002-of-00003.safetensors",
112
+ "transformer.h.18.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
113
+ "transformer.h.18.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
114
+ "transformer.h.18.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
115
+ "transformer.h.18.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
116
+ "transformer.h.19.attn.c_attn.bias": "model-00002-of-00003.safetensors",
117
+ "transformer.h.19.attn.c_attn.weight": "model-00002-of-00003.safetensors",
118
+ "transformer.h.19.attn.c_proj.bias": "model-00002-of-00003.safetensors",
119
+ "transformer.h.19.attn.c_proj.weight": "model-00002-of-00003.safetensors",
120
+ "transformer.h.19.ln_1.weight": "model-00002-of-00003.safetensors",
121
+ "transformer.h.19.ln_2.weight": "model-00002-of-00003.safetensors",
122
+ "transformer.h.19.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
123
+ "transformer.h.19.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
124
+ "transformer.h.19.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
125
+ "transformer.h.19.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
126
+ "transformer.h.2.attn.c_attn.bias": "model-00001-of-00003.safetensors",
127
+ "transformer.h.2.attn.c_attn.weight": "model-00001-of-00003.safetensors",
128
+ "transformer.h.2.attn.c_proj.bias": "model-00001-of-00003.safetensors",
129
+ "transformer.h.2.attn.c_proj.weight": "model-00001-of-00003.safetensors",
130
+ "transformer.h.2.ln_1.weight": "model-00001-of-00003.safetensors",
131
+ "transformer.h.2.ln_2.weight": "model-00001-of-00003.safetensors",
132
+ "transformer.h.2.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
133
+ "transformer.h.2.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
134
+ "transformer.h.2.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
135
+ "transformer.h.2.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
136
+ "transformer.h.20.attn.c_attn.bias": "model-00002-of-00003.safetensors",
137
+ "transformer.h.20.attn.c_attn.weight": "model-00002-of-00003.safetensors",
138
+ "transformer.h.20.attn.c_proj.bias": "model-00002-of-00003.safetensors",
139
+ "transformer.h.20.attn.c_proj.weight": "model-00002-of-00003.safetensors",
140
+ "transformer.h.20.ln_1.weight": "model-00002-of-00003.safetensors",
141
+ "transformer.h.20.ln_2.weight": "model-00002-of-00003.safetensors",
142
+ "transformer.h.20.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
143
+ "transformer.h.20.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
144
+ "transformer.h.20.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
145
+ "transformer.h.20.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
146
+ "transformer.h.21.attn.c_attn.bias": "model-00002-of-00003.safetensors",
147
+ "transformer.h.21.attn.c_attn.weight": "model-00002-of-00003.safetensors",
148
+ "transformer.h.21.attn.c_proj.bias": "model-00002-of-00003.safetensors",
149
+ "transformer.h.21.attn.c_proj.weight": "model-00002-of-00003.safetensors",
150
+ "transformer.h.21.ln_1.weight": "model-00002-of-00003.safetensors",
151
+ "transformer.h.21.ln_2.weight": "model-00002-of-00003.safetensors",
152
+ "transformer.h.21.mlp.c_fc.bias": "model-00002-of-00003.safetensors",
153
+ "transformer.h.21.mlp.c_fc.weight": "model-00002-of-00003.safetensors",
154
+ "transformer.h.21.mlp.c_proj.bias": "model-00002-of-00003.safetensors",
155
+ "transformer.h.21.mlp.c_proj.weight": "model-00002-of-00003.safetensors",
156
+ "transformer.h.22.attn.c_attn.bias": "model-00003-of-00003.safetensors",
157
+ "transformer.h.22.attn.c_attn.weight": "model-00003-of-00003.safetensors",
158
+ "transformer.h.22.attn.c_proj.bias": "model-00003-of-00003.safetensors",
159
+ "transformer.h.22.attn.c_proj.weight": "model-00003-of-00003.safetensors",
160
+ "transformer.h.22.ln_1.weight": "model-00002-of-00003.safetensors",
161
+ "transformer.h.22.ln_2.weight": "model-00003-of-00003.safetensors",
162
+ "transformer.h.22.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
163
+ "transformer.h.22.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
164
+ "transformer.h.22.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
165
+ "transformer.h.22.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
166
+ "transformer.h.23.attn.c_attn.bias": "model-00003-of-00003.safetensors",
167
+ "transformer.h.23.attn.c_attn.weight": "model-00003-of-00003.safetensors",
168
+ "transformer.h.23.attn.c_proj.bias": "model-00003-of-00003.safetensors",
169
+ "transformer.h.23.attn.c_proj.weight": "model-00003-of-00003.safetensors",
170
+ "transformer.h.23.ln_1.weight": "model-00003-of-00003.safetensors",
171
+ "transformer.h.23.ln_2.weight": "model-00003-of-00003.safetensors",
172
+ "transformer.h.23.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
173
+ "transformer.h.23.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
174
+ "transformer.h.23.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
175
+ "transformer.h.23.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
176
+ "transformer.h.24.attn.c_attn.bias": "model-00003-of-00003.safetensors",
177
+ "transformer.h.24.attn.c_attn.weight": "model-00003-of-00003.safetensors",
178
+ "transformer.h.24.attn.c_proj.bias": "model-00003-of-00003.safetensors",
179
+ "transformer.h.24.attn.c_proj.weight": "model-00003-of-00003.safetensors",
180
+ "transformer.h.24.ln_1.weight": "model-00003-of-00003.safetensors",
181
+ "transformer.h.24.ln_2.weight": "model-00003-of-00003.safetensors",
182
+ "transformer.h.24.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
183
+ "transformer.h.24.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
184
+ "transformer.h.24.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
185
+ "transformer.h.24.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
186
+ "transformer.h.25.attn.c_attn.bias": "model-00003-of-00003.safetensors",
187
+ "transformer.h.25.attn.c_attn.weight": "model-00003-of-00003.safetensors",
188
+ "transformer.h.25.attn.c_proj.bias": "model-00003-of-00003.safetensors",
189
+ "transformer.h.25.attn.c_proj.weight": "model-00003-of-00003.safetensors",
190
+ "transformer.h.25.ln_1.weight": "model-00003-of-00003.safetensors",
191
+ "transformer.h.25.ln_2.weight": "model-00003-of-00003.safetensors",
192
+ "transformer.h.25.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
193
+ "transformer.h.25.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
194
+ "transformer.h.25.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
195
+ "transformer.h.25.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
196
+ "transformer.h.26.attn.c_attn.bias": "model-00003-of-00003.safetensors",
197
+ "transformer.h.26.attn.c_attn.weight": "model-00003-of-00003.safetensors",
198
+ "transformer.h.26.attn.c_proj.bias": "model-00003-of-00003.safetensors",
199
+ "transformer.h.26.attn.c_proj.weight": "model-00003-of-00003.safetensors",
200
+ "transformer.h.26.ln_1.weight": "model-00003-of-00003.safetensors",
201
+ "transformer.h.26.ln_2.weight": "model-00003-of-00003.safetensors",
202
+ "transformer.h.26.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
203
+ "transformer.h.26.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
204
+ "transformer.h.26.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
205
+ "transformer.h.26.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
206
+ "transformer.h.27.attn.c_attn.bias": "model-00003-of-00003.safetensors",
207
+ "transformer.h.27.attn.c_attn.weight": "model-00003-of-00003.safetensors",
208
+ "transformer.h.27.attn.c_proj.bias": "model-00003-of-00003.safetensors",
209
+ "transformer.h.27.attn.c_proj.weight": "model-00003-of-00003.safetensors",
210
+ "transformer.h.27.ln_1.weight": "model-00003-of-00003.safetensors",
211
+ "transformer.h.27.ln_2.weight": "model-00003-of-00003.safetensors",
212
+ "transformer.h.27.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
213
+ "transformer.h.27.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
214
+ "transformer.h.27.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
215
+ "transformer.h.27.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
216
+ "transformer.h.28.attn.c_attn.bias": "model-00003-of-00003.safetensors",
217
+ "transformer.h.28.attn.c_attn.weight": "model-00003-of-00003.safetensors",
218
+ "transformer.h.28.attn.c_proj.bias": "model-00003-of-00003.safetensors",
219
+ "transformer.h.28.attn.c_proj.weight": "model-00003-of-00003.safetensors",
220
+ "transformer.h.28.ln_1.weight": "model-00003-of-00003.safetensors",
221
+ "transformer.h.28.ln_2.weight": "model-00003-of-00003.safetensors",
222
+ "transformer.h.28.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
223
+ "transformer.h.28.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
224
+ "transformer.h.28.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
225
+ "transformer.h.28.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
226
+ "transformer.h.29.attn.c_attn.bias": "model-00003-of-00003.safetensors",
227
+ "transformer.h.29.attn.c_attn.weight": "model-00003-of-00003.safetensors",
228
+ "transformer.h.29.attn.c_proj.bias": "model-00003-of-00003.safetensors",
229
+ "transformer.h.29.attn.c_proj.weight": "model-00003-of-00003.safetensors",
230
+ "transformer.h.29.ln_1.weight": "model-00003-of-00003.safetensors",
231
+ "transformer.h.29.ln_2.weight": "model-00003-of-00003.safetensors",
232
+ "transformer.h.29.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
233
+ "transformer.h.29.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
234
+ "transformer.h.29.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
235
+ "transformer.h.29.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
236
+ "transformer.h.3.attn.c_attn.bias": "model-00001-of-00003.safetensors",
237
+ "transformer.h.3.attn.c_attn.weight": "model-00001-of-00003.safetensors",
238
+ "transformer.h.3.attn.c_proj.bias": "model-00001-of-00003.safetensors",
239
+ "transformer.h.3.attn.c_proj.weight": "model-00001-of-00003.safetensors",
240
+ "transformer.h.3.ln_1.weight": "model-00001-of-00003.safetensors",
241
+ "transformer.h.3.ln_2.weight": "model-00001-of-00003.safetensors",
242
+ "transformer.h.3.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
243
+ "transformer.h.3.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
244
+ "transformer.h.3.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
245
+ "transformer.h.3.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
246
+ "transformer.h.30.attn.c_attn.bias": "model-00003-of-00003.safetensors",
247
+ "transformer.h.30.attn.c_attn.weight": "model-00003-of-00003.safetensors",
248
+ "transformer.h.30.attn.c_proj.bias": "model-00003-of-00003.safetensors",
249
+ "transformer.h.30.attn.c_proj.weight": "model-00003-of-00003.safetensors",
250
+ "transformer.h.30.ln_1.weight": "model-00003-of-00003.safetensors",
251
+ "transformer.h.30.ln_2.weight": "model-00003-of-00003.safetensors",
252
+ "transformer.h.30.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
253
+ "transformer.h.30.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
254
+ "transformer.h.30.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
255
+ "transformer.h.30.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
256
+ "transformer.h.31.attn.c_attn.bias": "model-00003-of-00003.safetensors",
257
+ "transformer.h.31.attn.c_attn.weight": "model-00003-of-00003.safetensors",
258
+ "transformer.h.31.attn.c_proj.bias": "model-00003-of-00003.safetensors",
259
+ "transformer.h.31.attn.c_proj.weight": "model-00003-of-00003.safetensors",
260
+ "transformer.h.31.ln_1.weight": "model-00003-of-00003.safetensors",
261
+ "transformer.h.31.ln_2.weight": "model-00003-of-00003.safetensors",
262
+ "transformer.h.31.mlp.c_fc.bias": "model-00003-of-00003.safetensors",
263
+ "transformer.h.31.mlp.c_fc.weight": "model-00003-of-00003.safetensors",
264
+ "transformer.h.31.mlp.c_proj.bias": "model-00003-of-00003.safetensors",
265
+ "transformer.h.31.mlp.c_proj.weight": "model-00003-of-00003.safetensors",
266
+ "transformer.h.4.attn.c_attn.bias": "model-00001-of-00003.safetensors",
267
+ "transformer.h.4.attn.c_attn.weight": "model-00001-of-00003.safetensors",
268
+ "transformer.h.4.attn.c_proj.bias": "model-00001-of-00003.safetensors",
269
+ "transformer.h.4.attn.c_proj.weight": "model-00001-of-00003.safetensors",
270
+ "transformer.h.4.ln_1.weight": "model-00001-of-00003.safetensors",
271
+ "transformer.h.4.ln_2.weight": "model-00001-of-00003.safetensors",
272
+ "transformer.h.4.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
273
+ "transformer.h.4.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
274
+ "transformer.h.4.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
275
+ "transformer.h.4.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
276
+ "transformer.h.5.attn.c_attn.bias": "model-00001-of-00003.safetensors",
277
+ "transformer.h.5.attn.c_attn.weight": "model-00001-of-00003.safetensors",
278
+ "transformer.h.5.attn.c_proj.bias": "model-00001-of-00003.safetensors",
279
+ "transformer.h.5.attn.c_proj.weight": "model-00001-of-00003.safetensors",
280
+ "transformer.h.5.ln_1.weight": "model-00001-of-00003.safetensors",
281
+ "transformer.h.5.ln_2.weight": "model-00001-of-00003.safetensors",
282
+ "transformer.h.5.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
283
+ "transformer.h.5.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
284
+ "transformer.h.5.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
285
+ "transformer.h.5.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
286
+ "transformer.h.6.attn.c_attn.bias": "model-00001-of-00003.safetensors",
287
+ "transformer.h.6.attn.c_attn.weight": "model-00001-of-00003.safetensors",
288
+ "transformer.h.6.attn.c_proj.bias": "model-00001-of-00003.safetensors",
289
+ "transformer.h.6.attn.c_proj.weight": "model-00001-of-00003.safetensors",
290
+ "transformer.h.6.ln_1.weight": "model-00001-of-00003.safetensors",
291
+ "transformer.h.6.ln_2.weight": "model-00001-of-00003.safetensors",
292
+ "transformer.h.6.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
293
+ "transformer.h.6.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
294
+ "transformer.h.6.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
295
+ "transformer.h.6.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
296
+ "transformer.h.7.attn.c_attn.bias": "model-00001-of-00003.safetensors",
297
+ "transformer.h.7.attn.c_attn.weight": "model-00001-of-00003.safetensors",
298
+ "transformer.h.7.attn.c_proj.bias": "model-00001-of-00003.safetensors",
299
+ "transformer.h.7.attn.c_proj.weight": "model-00001-of-00003.safetensors",
300
+ "transformer.h.7.ln_1.weight": "model-00001-of-00003.safetensors",
301
+ "transformer.h.7.ln_2.weight": "model-00001-of-00003.safetensors",
302
+ "transformer.h.7.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
303
+ "transformer.h.7.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
304
+ "transformer.h.7.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
305
+ "transformer.h.7.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
306
+ "transformer.h.8.attn.c_attn.bias": "model-00001-of-00003.safetensors",
307
+ "transformer.h.8.attn.c_attn.weight": "model-00001-of-00003.safetensors",
308
+ "transformer.h.8.attn.c_proj.bias": "model-00001-of-00003.safetensors",
309
+ "transformer.h.8.attn.c_proj.weight": "model-00001-of-00003.safetensors",
310
+ "transformer.h.8.ln_1.weight": "model-00001-of-00003.safetensors",
311
+ "transformer.h.8.ln_2.weight": "model-00001-of-00003.safetensors",
312
+ "transformer.h.8.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
313
+ "transformer.h.8.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
314
+ "transformer.h.8.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
315
+ "transformer.h.8.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
316
+ "transformer.h.9.attn.c_attn.bias": "model-00001-of-00003.safetensors",
317
+ "transformer.h.9.attn.c_attn.weight": "model-00001-of-00003.safetensors",
318
+ "transformer.h.9.attn.c_proj.bias": "model-00001-of-00003.safetensors",
319
+ "transformer.h.9.attn.c_proj.weight": "model-00001-of-00003.safetensors",
320
+ "transformer.h.9.ln_1.weight": "model-00001-of-00003.safetensors",
321
+ "transformer.h.9.ln_2.weight": "model-00001-of-00003.safetensors",
322
+ "transformer.h.9.mlp.c_fc.bias": "model-00001-of-00003.safetensors",
323
+ "transformer.h.9.mlp.c_fc.weight": "model-00001-of-00003.safetensors",
324
+ "transformer.h.9.mlp.c_proj.bias": "model-00001-of-00003.safetensors",
325
+ "transformer.h.9.mlp.c_proj.weight": "model-00001-of-00003.safetensors",
326
+ "transformer.ln_f.weight": "model-00003-of-00003.safetensors",
327
+ "transformer.wte.weight": "model-00001-of-00003.safetensors"
328
+ }
329
+ }
modeling_granite.py ADDED
@@ -0,0 +1,1374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
+ from enum import Enum
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from transformers import DynamicCache, PreTrainedModel
9
+ from transformers.activations import get_activation as get_base_activation
10
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
11
+ from transformers.utils import is_flash_attn_2_available
12
+
13
+ from .configuration_granite import GraniteConfig
14
+
15
+
16
+ class PositionEmbeddingType(Enum):
17
+ learned_absolute = "learned_absolute"
18
+ alibi = "alibi"
19
+ rope = "rope"
20
+
21
+
22
+ class AttentionHeadType(Enum):
23
+ mha = "mha"
24
+ mqa = "mqa"
25
+ gqa = "gqa"
26
+
27
+
28
+ if is_flash_attn_2_available():
29
+ from flash_attn.bert_padding import IndexFirstAxis, pad_input, unpad_input
30
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
31
+
32
+
33
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
34
+ def get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
35
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
36
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
37
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
38
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
39
+ return indices, cu_seqlens, max_seqlen_in_batch
40
+
41
+
42
+ def repeat_key_value(x: torch.Tensor, num_heads: int, num_key_value_heads: int) -> torch.Tensor:
43
+ num_groups = num_heads // num_key_value_heads
44
+
45
+ # mha
46
+ if num_groups == 1:
47
+ return x
48
+
49
+ # mqa
50
+ if num_key_value_heads == 1:
51
+ return x.expand(-1, num_heads, -1, -1)
52
+
53
+ # gqa
54
+ return x.repeat_interleave(num_groups, dim=1)
55
+
56
+
57
+ ##################################################
58
+ # activation functions
59
+
60
+
61
+ _GLU_BASE_MAPPING = {
62
+ "geglu": "gelu",
63
+ "miglu": "mish",
64
+ "mishglu": "mish",
65
+ "swiglu": "swish",
66
+ }
67
+
68
+
69
+ class GLUActivation(nn.Module):
70
+ def __init__(self, base_activation: nn.Module) -> None:
71
+ super().__init__()
72
+ self.base_activation = base_activation
73
+
74
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
75
+ x = x.chunk(2, dim=-1)
76
+ return x[0] * self.base_activation(x[1])
77
+
78
+
79
+ def is_glu(name: str) -> bool:
80
+ return name.endswith("glu")
81
+
82
+
83
+ def get_activation_function(name: str) -> nn.Module:
84
+ if is_glu(name):
85
+ # for glu and sigmoid_glu, we directly return the pytorch's GLU
86
+ if name in ["glu", "sigmoid_glu"]:
87
+ activation_function = nn.modules.GLU()
88
+ else:
89
+ if name in _GLU_BASE_MAPPING:
90
+ name = _GLU_BASE_MAPPING[name]
91
+ elif name.endswith("_glu"):
92
+ name = name.rstrip("_glu")
93
+ else:
94
+ raise ValueError("invalid activation function")
95
+
96
+ base_activation = get_base_activation(name)
97
+ activation_function = GLUActivation(base_activation)
98
+ else:
99
+ activation_function = get_base_activation(name)
100
+
101
+ return activation_function
102
+
103
+
104
+ ##################################################
105
+ # normalization functions
106
+
107
+
108
+ class RMSNorm(nn.Module):
109
+ def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None:
110
+ super().__init__()
111
+
112
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
113
+ self.eps = eps
114
+
115
+ if isinstance(normalized_shape, numbers.Integral):
116
+ normalized_shape = (normalized_shape,)
117
+ self.normalized_shape = normalized_shape
118
+
119
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
120
+ input_dtype = input.dtype
121
+
122
+ input = input.to(torch.float32)
123
+ variance = input.pow(2).mean(-1, keepdim=True)
124
+ input = input * torch.rsqrt(variance + self.eps)
125
+
126
+ return self.weight * input.to(input_dtype)
127
+
128
+ def extra_repr(self) -> str:
129
+ return f"{self.normalized_shape}, eps={self.eps}"
130
+
131
+ def reset_parameters(self) -> None:
132
+ nn.init.ones_(self.weight)
133
+
134
+
135
+ _NORMALIZATION_FUNCTIONS = {
136
+ "layernorm": nn.LayerNorm,
137
+ "rmsnorm": RMSNorm,
138
+ }
139
+
140
+
141
+ def get_normalization_function(name: str, normalized_shape: int, eps: float = 1e-5) -> nn.Module:
142
+ if name in _NORMALIZATION_FUNCTIONS:
143
+ return _NORMALIZATION_FUNCTIONS[name](normalized_shape, eps=eps)
144
+
145
+ raise ValueError(f"unexpected `normalization_function` {name}")
146
+
147
+
148
+ ##################################################
149
+ # attention modules
150
+
151
+
152
+ class GraniteAttention(nn.Module):
153
+ def __init__(self, config: GraniteConfig, causal: bool, layer_idx: Optional[int] = None) -> None:
154
+ super().__init__()
155
+
156
+ self.causal = causal
157
+ self.hidden_size = config.n_embd
158
+ self.num_heads = config.n_head
159
+ self.num_key_value_heads = config.num_key_value_heads
160
+ self.add_bias = config.add_bias
161
+
162
+ assert (
163
+ self.hidden_size % self.num_heads == 0
164
+ ), f"`hidden_size` ({self.hidden_size}) must be divisible by `num_heads` ({self.num_heads})"
165
+
166
+ self.head_dim = self.hidden_size // self.num_heads
167
+ self.attention_head_type = AttentionHeadType(config.attention_head_type)
168
+
169
+ self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
170
+ self.scale_attn_weights = config.scale_attn_weights
171
+ self.attention_multiplier = config.attention_multiplier
172
+
173
+ self.layer_idx = layer_idx
174
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
175
+ self.scale_attention_softmax_in_fp32 = (
176
+ config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
177
+ )
178
+
179
+ if self.attention_head_type == AttentionHeadType.mha:
180
+ if self.num_key_value_heads is None:
181
+ self.num_key_value_heads = self.num_heads
182
+
183
+ assert (
184
+ self.num_heads == self.num_key_value_heads
185
+ ), f"{self.__class__.__name__} should have same number of heads for query, keys and values"
186
+ elif self.attention_head_type == AttentionHeadType.gqa:
187
+ assert (
188
+ self.num_key_value_heads is not None
189
+ ), "`num_key_value_heads` needs to be specified with GroupedQueryAttention"
190
+
191
+ assert self.num_heads % self.num_key_value_heads == 0, (
192
+ f"`num_heads` ({self.num_heads}) should be a multiple of `num_key_value_heads` "
193
+ f"({self.num_key_value_heads})"
194
+ )
195
+ elif self.attention_head_type == AttentionHeadType.mqa:
196
+ if self.num_key_value_heads is None:
197
+ self.num_key_value_heads = 1
198
+
199
+ assert self.num_key_value_heads == 1, f"{self.__class__.__name__} should have 1 head for keys and values"
200
+ else:
201
+ raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})")
202
+
203
+ # note that the actual layout is different for the output and depends on whether we are using MHA, MQA or GQA
204
+ # (self.hidden_size + 2 * self.num_key_value_heads * self.head_dim) is just the actual number output features
205
+ self.c_attn = nn.Linear(
206
+ self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=self.add_bias
207
+ )
208
+ self.c_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.add_bias)
209
+
210
+ self.attn_pdrop = config.attn_pdrop
211
+ self.resid_pdrop = config.resid_pdrop
212
+
213
+ self.attn_dropout = nn.Identity() if self.attn_pdrop == 0 else nn.Dropout(self.attn_pdrop)
214
+ self.resid_dropout = nn.Identity() if self.resid_pdrop == 0 else nn.Dropout(self.resid_pdrop)
215
+
216
+ def _prepare_qkv_for_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
217
+ # ==========================================================================================
218
+ # hidden_states -> (batch_size, query_length, num_heads * head_dim)
219
+ # ==========================================================================================
220
+
221
+ # the output of following is a tuple if using MQA with tensor parallel
222
+ hidden_states = self.c_attn(hidden_states)
223
+
224
+ # ==========================================================================================
225
+ # hidden_states -> (batch_size, query_length, [num_heads + num_key_value_heads * 2] * head_dim)
226
+ # ==========================================================================================
227
+
228
+ # for MHA, we can get away with doing just 1 transpose which is not true for GQA
229
+ if self.attention_head_type == AttentionHeadType.mha:
230
+ query, key, value = self._prepare_qkv_for_forward_mha(hidden_states)
231
+ elif self.attention_head_type == AttentionHeadType.gqa:
232
+ query, key, value = self._prepare_qkv_for_forward_gqa(hidden_states)
233
+ elif self.attention_head_type == AttentionHeadType.mqa:
234
+ query, key, value = self._prepare_qkv_for_forward_mqa(hidden_states)
235
+ else:
236
+ raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})")
237
+
238
+ # ==========================================================================================
239
+ # query -> (batch_size, num_heads, query_length, head_dim)
240
+ # key -> (batch_size, num_key_value_heads, query_length, head_dim)
241
+ # value -> (batch_size, num_key_value_heads, query_length, head_dim)
242
+ # ==========================================================================================
243
+
244
+ return query, key, value
245
+
246
+ def _prepare_qkv_for_forward_mha(
247
+ self, hidden_states: torch.Tensor
248
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
249
+ batch_size, query_length = hidden_states.shape[:-1]
250
+
251
+ hidden_states = hidden_states.view(batch_size, query_length, self.num_heads, -1)
252
+ hidden_states = hidden_states.transpose(1, 2)
253
+
254
+ query, key, value = hidden_states.chunk(3, dim=-1)
255
+
256
+ return query, key, value
257
+
258
+ def _prepare_qkv_for_forward_gqa(
259
+ self, hidden_states: torch.Tensor
260
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
261
+ batch_size, query_length = hidden_states.shape[:-1]
262
+
263
+ hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1)
264
+
265
+ query, key, value = hidden_states.split(
266
+ ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1
267
+ )
268
+
269
+ # this needs to be a reshape instead of view sadly
270
+ query = query.reshape(batch_size, query_length, -1, self.head_dim)
271
+
272
+ query = query.transpose(1, 2)
273
+ key = key.transpose(1, 2)
274
+ value = value.transpose(1, 2)
275
+
276
+ return query, key, value
277
+
278
+ def _prepare_qkv_for_forward_mqa(
279
+ self, hidden_states: torch.Tensor
280
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
281
+ batch_size, query_length = hidden_states.shape[:-1]
282
+
283
+ query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1)
284
+
285
+ query = query.view(batch_size, query_length, self.num_heads, -1)
286
+
287
+ query = query.transpose(1, 2)
288
+ key = key.unsqueeze(1)
289
+ value = value.unsqueeze(1)
290
+
291
+ return query, key, value
292
+
293
+ def forward(
294
+ self,
295
+ hidden_states: torch.Tensor,
296
+ past_key_values: Optional[DynamicCache] = None,
297
+ attention_mask: Optional[torch.Tensor] = None,
298
+ rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
299
+ ) -> torch.Tensor:
300
+ # ==========================================================================================
301
+ # hidden_states -> (batch_size, query_length, num_heads * head_dim)
302
+ # ==========================================================================================
303
+
304
+ query, key, value = self._prepare_qkv_for_forward(hidden_states)
305
+
306
+ # ==========================================================================================
307
+ # query -> (batch_size, num_heads, query_length, head_dim)
308
+ # key -> (batch_size, num_key_value_heads, query_length, head_dim)
309
+ # value -> (batch_size, num_key_value_heads, query_length, head_dim)
310
+ # ==========================================================================================
311
+
312
+ if self.position_embedding_type == PositionEmbeddingType.rope:
313
+ query = apply_rotary_pos_emb(query, rope_cos_sin)
314
+ key = apply_rotary_pos_emb(key, rope_cos_sin)
315
+
316
+ if past_key_values is not None:
317
+ key, value = past_key_values.update(key, value, self.layer_idx)
318
+
319
+ # ==========================================================================================
320
+ # query -> (batch_size, num_heads, query_length, head_dim)
321
+ # key -> (batch_size, num_key_value_heads, key_length, head_dim)
322
+ # value -> (batch_size, num_key_value_heads, key_length, head_dim)
323
+ # ==========================================================================================
324
+
325
+ key = key.transpose(-1, -2)
326
+
327
+ dtype = query.dtype
328
+ softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
329
+
330
+ if self.scale_attn_weights:
331
+ if self.attention_multiplier is None:
332
+ scale_factor = 1 / self.head_dim**0.5
333
+ else:
334
+ scale_factor = self.attention_multiplier
335
+ else:
336
+ scale_factor = 1
337
+
338
+ # ==========================================================================================
339
+ # query -> (batch_size, num_heads, query_length, head_dim)
340
+ # key -> (batch_size, num_key_value_heads, head_dim, key_length)
341
+ # value -> (batch_size, num_key_value_heads, key_length, head_dim)
342
+ # ==========================================================================================
343
+
344
+ batch_size = query.shape[0]
345
+ query_length = query.shape[2]
346
+ key_length = key.shape[-1]
347
+
348
+ key = repeat_key_value(key, self.num_heads, self.num_key_value_heads)
349
+ value = repeat_key_value(value, self.num_heads, self.num_key_value_heads)
350
+
351
+ # Always copies
352
+ query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
353
+ # No copy when layer_past is provided.
354
+ key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)
355
+
356
+ # ==========================================================================================
357
+ # query -> (batch_size * num_heads, query_length, head_dim)
358
+ # key -> (batch_size * num_heads, head_dim, key_length)
359
+ # value -> (batch_size, num_heads, key_length, head_dim)
360
+ # ==========================================================================================
361
+
362
+ attn_weights = torch.empty(
363
+ (batch_size * self.num_heads, query_length, key_length), device=query.device, dtype=query.dtype
364
+ )
365
+
366
+ attn_weights = torch.baddbmm(attn_weights, query, key, beta=0, alpha=scale_factor).view(
367
+ batch_size, self.num_heads, query_length, key_length
368
+ )
369
+
370
+ # ==========================================================================================
371
+ # attn_weights -> (batch_size, num_heads, query_length, key_length)
372
+ # ==========================================================================================
373
+
374
+ attn_weights = attn_weights.to(softmax_dtype)
375
+
376
+ if attention_mask is not None:
377
+ attn_weights = attn_weights + attention_mask
378
+
379
+ attn_weights = F.softmax(attn_weights, dim=-1).to(dtype)
380
+
381
+ attn_weights = self.attn_dropout(attn_weights)
382
+
383
+ # ==========================================================================================
384
+ # value -> (batch_size, num_heads, key_length, head_dim)
385
+ # attn_weights -> (batch_size, num_heads, query_length, key_length)
386
+ # ==========================================================================================
387
+
388
+ attn_output = torch.matmul(attn_weights, value)
389
+
390
+ # ==========================================================================================
391
+ # attn_output -> (batch_size, num_heads, query_length, head_dim)
392
+ # ==========================================================================================
393
+
394
+ attn_output = attn_output.transpose(1, 2)
395
+ attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
396
+
397
+ # ==========================================================================================
398
+ # attn_output -> (batch_size, query_length, num_heads * head_dim)
399
+ # ==========================================================================================
400
+
401
+ attn_output = self.c_proj(attn_output)
402
+ attn_output = self.resid_dropout(attn_output)
403
+
404
+ return attn_output
405
+
406
+
407
+ class GraniteSDPA(GraniteAttention):
408
+ def forward(
409
+ self,
410
+ hidden_states: torch.Tensor,
411
+ past_key_values: Optional[DynamicCache] = None,
412
+ attention_mask: Optional[torch.Tensor] = None,
413
+ rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
414
+ ) -> torch.Tensor:
415
+ # ==========================================================================================
416
+ # hidden_states -> (batch_size, query_length, num_heads * head_dim)
417
+ # ==========================================================================================
418
+
419
+ query, key, value = self._prepare_qkv_for_forward(hidden_states)
420
+
421
+ # ==========================================================================================
422
+ # query -> (batch_size, num_heads, query_length, head_dim)
423
+ # key -> (batch_size, num_key_value_heads, query_length, head_dim)
424
+ # value -> (batch_size, num_key_value_heads, query_length, head_dim)
425
+ # ==========================================================================================
426
+
427
+ if self.position_embedding_type == PositionEmbeddingType.rope:
428
+ query = apply_rotary_pos_emb(query, rope_cos_sin)
429
+ key = apply_rotary_pos_emb(key, rope_cos_sin)
430
+
431
+ if past_key_values is not None:
432
+ key, value = past_key_values.update(key, value, self.layer_idx)
433
+
434
+ # ==========================================================================================
435
+ # query -> (batch_size, num_heads, query_length, head_dim)
436
+ # key -> (batch_size, num_key_value_heads, key_length, head_dim)
437
+ # value -> (batch_size, num_key_value_heads, key_length, head_dim)
438
+ # ==========================================================================================
439
+
440
+ key = repeat_key_value(key, self.num_heads, self.num_key_value_heads)
441
+ value = repeat_key_value(value, self.num_heads, self.num_key_value_heads)
442
+
443
+ # ==========================================================================================
444
+ # query -> (batch_size, num_heads, query_length, head_dim)
445
+ # key -> (batch_size, num_heads, key_length, head_dim)
446
+ # value -> (batch_size, num_heads, key_length, head_dim)
447
+ # ==========================================================================================
448
+
449
+ attn_output = F.scaled_dot_product_attention(
450
+ query,
451
+ key,
452
+ value,
453
+ attn_mask=attention_mask,
454
+ dropout_p=self.attn_pdrop if self.training else 0,
455
+ is_causal=self.causal if attention_mask is None else False,
456
+ scale=self.attention_multiplier if self.scale_attn_weights else 1,
457
+ )
458
+
459
+ # ==========================================================================================
460
+ # attn_output -> (batch_size, num_heads, query_length, head_dim)
461
+ # ==========================================================================================
462
+
463
+ batch_size = attn_output.shape[0]
464
+ attn_output = attn_output.transpose(1, 2)
465
+ attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
466
+
467
+ # ==========================================================================================
468
+ # attn_output -> (batch_size, query_length, num_heads * head_dim)
469
+ # ==========================================================================================
470
+
471
+ attn_output = self.c_proj(attn_output)
472
+ attn_output = self.resid_dropout(attn_output)
473
+
474
+ return attn_output
475
+
476
+
477
+ class GraniteFlashAttention2(GraniteAttention):
478
+ def forward(
479
+ self,
480
+ hidden_states: torch.Tensor,
481
+ past_key_values: Optional[DynamicCache] = None,
482
+ attention_mask: Optional[torch.Tensor] = None,
483
+ rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
484
+ ) -> torch.Tensor:
485
+ # ==========================================================================================
486
+ # hidden_states -> (batch_size, query_length, num_heads * head_dim)
487
+ # ==========================================================================================
488
+
489
+ query, key, value = self._prepare_qkv_for_forward(hidden_states)
490
+
491
+ # ==========================================================================================
492
+ # query -> (batch_size, num_heads, query_length, head_dim)
493
+ # key -> (batch_size, num_key_value_heads, query_length, head_dim)
494
+ # value -> (batch_size, num_key_value_heads, query_length, head_dim)
495
+ # ==========================================================================================
496
+
497
+ if self.position_embedding_type == PositionEmbeddingType.rope:
498
+ query = apply_rotary_pos_emb(query, rope_cos_sin)
499
+ key = apply_rotary_pos_emb(key, rope_cos_sin)
500
+
501
+ if past_key_values is not None:
502
+ key, value = past_key_values.update(key, value, self.layer_idx)
503
+
504
+ # ==========================================================================================
505
+ # query -> (batch_size, num_heads, query_length, head_dim)
506
+ # key -> (batch_size, num_key_value_heads, key_length, head_dim)
507
+ # value -> (batch_size, num_key_value_heads, key_length, head_dim)
508
+ # ==========================================================================================
509
+
510
+ # TODO avoid this extra transpose
511
+ query = query.transpose(1, 2)
512
+ if self.attention_head_type == AttentionHeadType.mqa:
513
+ key = key.squeeze(1).unsqueeze(2)
514
+ value = value.squeeze(1).unsqueeze(2)
515
+ else:
516
+ key = key.transpose(1, 2)
517
+ value = value.transpose(1, 2)
518
+
519
+ # ==========================================================================================
520
+ # query -> (batch_size, query_length, num_heads, head_dim)
521
+ # key -> (batch_size, key_length, num_heads, head_dim)
522
+ # value -> (batch_size, key_length, num_heads, head_dim)
523
+ # ==========================================================================================
524
+
525
+ batch_size, query_length = query.shape[:2]
526
+ key_length = key.shape[1]
527
+ indices_k, cu_seqlens_k, max_seqlen_k = get_unpad_data(attention_mask)
528
+
529
+ key = IndexFirstAxis.apply(
530
+ key.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k
531
+ )
532
+ value = IndexFirstAxis.apply(
533
+ value.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k
534
+ )
535
+
536
+ if query_length == key_length:
537
+ query = IndexFirstAxis.apply(
538
+ query.reshape(batch_size * key_length, self.num_heads, self.head_dim), indices_k
539
+ )
540
+ cu_seqlens_q = cu_seqlens_k
541
+ max_seqlen_q = max_seqlen_k
542
+ indices_q = indices_k
543
+ elif query_length == 1:
544
+ max_seqlen_q = 1
545
+ cu_seqlens_q = torch.arange(
546
+ batch_size + 1, dtype=torch.int32, device=query.device
547
+ ) # There is a memcpy here, that is very bad.
548
+ indices_q = cu_seqlens_q[:-1]
549
+ query = query.squeeze(1)
550
+ else:
551
+ # The -q_len: slice assumes left padding.
552
+ attention_mask = attention_mask[:, -query_length:]
553
+ query, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query, attention_mask)
554
+
555
+ # ==========================================================================================
556
+ # query -> (total_q, num_heads, head_dim)
557
+ # key -> (total_q, num_heads, head_dim)
558
+ # value -> (total_q, num_heads, head_dim)
559
+ # ==========================================================================================
560
+
561
+ attn_output = flash_attn_varlen_func(
562
+ query,
563
+ key,
564
+ value,
565
+ cu_seqlens_q=cu_seqlens_q,
566
+ cu_seqlens_k=cu_seqlens_k,
567
+ max_seqlen_q=max_seqlen_q,
568
+ max_seqlen_k=max_seqlen_k,
569
+ dropout_p=self.attn_pdrop if self.training else 0,
570
+ softmax_scale=self.attention_multiplier if self.scale_attn_weights else 1,
571
+ causal=self.causal,
572
+ )
573
+
574
+ # ==========================================================================================
575
+ # attn_output -> (total_q, num_heads, head_dim)
576
+ # ==========================================================================================
577
+
578
+ attn_output = pad_input(attn_output, indices_q, batch_size, query_length)
579
+ attn_output = attn_output.view(batch_size, query_length, -1)
580
+
581
+ # ==========================================================================================
582
+ # attn_output -> (batch_size, query_length, num_heads * head_dim)
583
+ # ==========================================================================================
584
+
585
+ attn_output = self.c_proj(attn_output)
586
+ attn_output = self.resid_dropout(attn_output)
587
+
588
+ return attn_output
589
+
590
+
591
+ _ATTENTION_MODULES = {
592
+ "eager": GraniteAttention,
593
+ "sdpa": GraniteSDPA,
594
+ "flash_attention_2": GraniteFlashAttention2,
595
+ }
596
+
597
+
598
+ def get_attention_module(
599
+ config: GraniteConfig, causal: bool, attention_implementation: str, layer_idx: int
600
+ ) -> GraniteAttention:
601
+ if attention_implementation in _ATTENTION_MODULES:
602
+ return _ATTENTION_MODULES[attention_implementation](config, causal=causal, layer_idx=layer_idx)
603
+ raise ValueError(f"unexpected `attention_implementation` {attention_implementation}")
604
+
605
+
606
+ ##################################################
607
+ # position embeddings
608
+
609
+
610
+ class Alibi(nn.Module):
611
+ def __init__(self, num_heads: int) -> None:
612
+ super().__init__()
613
+ self.num_heads = num_heads
614
+
615
+ self.reset_parameters()
616
+
617
+ def forward(
618
+ self, attention_mask: torch.Tensor, batch_size: int, key_length: int, device: torch.device, dtype: torch.dtype
619
+ ) -> torch.Tensor:
620
+ """
621
+ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
622
+ relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
623
+ `softmax(l+a) = softmax(l)`. Based on
624
+ https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
625
+ TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
626
+
627
+ Args:
628
+ attention_mask (torch.Tensor): attention_mask tensor of shape (`batch_size`, `key_length`)
629
+ num_heads (int): `num_heads` for the model
630
+ batch_size (int): `batch_size`
631
+ key_length (int): `key_length`
632
+ device (torch.device): device for the tensors
633
+ dtype (torch.dtype): dtype to use for the tensors
634
+
635
+ Returns:
636
+ torch.Tensor: alibi tensor of shape (`batch_size`, `num_heads`, `key_length`)
637
+ """
638
+
639
+ # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
640
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
641
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
642
+ # => the query_length dimension will then be broadcasted correctly
643
+ # This is more or less identical to T5's relative position bias:
644
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
645
+ if attention_mask is None:
646
+ arange_tensor = (
647
+ torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, -1, -1)
648
+ )
649
+ else:
650
+ arange_tensor = (attention_mask.cumsum(dim=-1) - 1).masked_fill_(attention_mask == 0, 0).unsqueeze(1)
651
+
652
+ alibi = self.slopes.unsqueeze(1) * arange_tensor
653
+ return alibi.to(dtype)
654
+
655
+ def reset_parameters(self) -> None:
656
+ closest_power_of_2 = 2 ** math.floor(math.log2(self.num_heads))
657
+ base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32)
658
+ powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
659
+ slopes = torch.pow(base, powers)
660
+
661
+ if closest_power_of_2 != self.num_heads:
662
+ extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32)
663
+ num_remaining_heads = min(closest_power_of_2, self.num_heads - closest_power_of_2)
664
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32)
665
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
666
+
667
+ self.register_buffer("slopes", slopes, persistent=False)
668
+
669
+
670
+ class RoPE(nn.Module):
671
+ def __init__(
672
+ self,
673
+ head_dim: int,
674
+ max_position_embeddings: int = 2048,
675
+ base: int = 10000,
676
+ ) -> None:
677
+ super().__init__()
678
+
679
+ self.head_dim = head_dim
680
+ self.max_position_embeddings = max_position_embeddings
681
+ self.base = base
682
+ self.mscale = 1
683
+
684
+ self.reset_parameters()
685
+
686
+ def forward(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
687
+ if seq_len > self.max_seq_len_cached:
688
+ self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype)
689
+
690
+ cos = self.cos_cached[:seq_len].to(dtype)
691
+ sin = self.sin_cached[:seq_len].to(dtype)
692
+
693
+ return cos, sin
694
+
695
+ def reset_parameters(self) -> None:
696
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
697
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
698
+
699
+ # Build here to make `torch.jit.trace` work.
700
+ self._set_cos_sin_cache(
701
+ seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
702
+ )
703
+
704
+ @torch.no_grad()
705
+ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
706
+ self.max_seq_len_cached = seq_len
707
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
708
+
709
+ freqs = torch.outer(t, self.inv_freq)
710
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
711
+ emb = torch.cat((freqs, freqs), dim=-1)
712
+
713
+ self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False)
714
+ self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False)
715
+
716
+
717
+ def apply_rotary_pos_emb(x: torch.Tensor, cos_sin: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
718
+ cos, sin = cos_sin
719
+ x = (x * cos) + (_rotate_half(x) * sin)
720
+ return x
721
+
722
+
723
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
724
+ x1, x2 = torch.chunk(x, 2, dim=-1)
725
+ return torch.cat((-x2, x1), dim=-1)
726
+
727
+
728
+ ##################################################
729
+ # MLP
730
+
731
+
732
+ class GraniteMLP(nn.Module):
733
+ def __init__(self, config: GraniteConfig) -> None:
734
+ super().__init__()
735
+
736
+ hidden_size = config.n_embd
737
+ intermediate_size = config.n_inner
738
+ activation_function = config.activation_function
739
+ add_bias = config.add_bias
740
+ residual_dropout = config.resid_pdrop
741
+
742
+ self.c_fc = nn.Linear(
743
+ hidden_size,
744
+ 2 * intermediate_size if is_glu(activation_function) else intermediate_size,
745
+ bias=add_bias,
746
+ )
747
+ self.act = get_activation_function(activation_function)
748
+ self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=add_bias)
749
+ self.dropout = nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout)
750
+
751
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
752
+ hidden_states = self.c_fc(hidden_states)
753
+ hidden_states = self.act(hidden_states)
754
+ hidden_states = self.c_proj(hidden_states)
755
+ hidden_states = self.dropout(hidden_states)
756
+ return hidden_states
757
+
758
+
759
+ ##################################################
760
+ # transformer layer
761
+
762
+
763
+ class GraniteBlock(nn.Module):
764
+ def __init__(
765
+ self,
766
+ config: GraniteConfig,
767
+ attention_implementation: str,
768
+ layer_idx: Optional[int] = None,
769
+ ) -> None:
770
+ super().__init__()
771
+
772
+ hidden_size = config.hidden_size
773
+ self.inner_dim = config.n_inner
774
+ self.layer_idx = layer_idx
775
+
776
+ self.ln_1 = get_normalization_function(
777
+ config.normalization_function,
778
+ hidden_size,
779
+ eps=config.layer_norm_epsilon,
780
+ )
781
+ self.attn = get_attention_module(config, True, attention_implementation, layer_idx)
782
+ self.ln_2 = get_normalization_function(
783
+ config.normalization_function,
784
+ hidden_size,
785
+ eps=config.layer_norm_epsilon,
786
+ )
787
+ self.mlp = GraniteMLP(config)
788
+
789
+ def forward(
790
+ self,
791
+ hidden_states: torch.Tensor,
792
+ past_key_values: Optional[DynamicCache] = None,
793
+ attention_mask: Optional[torch.Tensor] = None,
794
+ rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
795
+ ) -> torch.Tensor:
796
+ residual = hidden_states
797
+ hidden_states = self.ln_1(hidden_states)
798
+
799
+ attn_output = self.attn(
800
+ hidden_states,
801
+ past_key_values=past_key_values,
802
+ attention_mask=attention_mask,
803
+ rope_cos_sin=rope_cos_sin,
804
+ )
805
+
806
+ # residual connection
807
+ hidden_states = attn_output + residual
808
+
809
+ residual = hidden_states
810
+ hidden_states = self.ln_2(hidden_states)
811
+
812
+ feed_forward_hidden_states = self.mlp(hidden_states)
813
+
814
+ # residual connection
815
+ hidden_states = residual + feed_forward_hidden_states
816
+
817
+ return hidden_states
818
+
819
+
820
+ ##################################################
821
+ # model classes
822
+
823
+
824
+ class GranitePreTrainedModel(PreTrainedModel):
825
+ """
826
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
827
+ models.
828
+ """
829
+
830
+ config_class = GraniteConfig
831
+ base_model_prefix = "transformer"
832
+ causal = True
833
+ _no_split_modules = ["GraniteBlock"]
834
+ _skip_keys_device_placement = "past_key_values"
835
+ _supports_sdpa = True
836
+ _supports_flash_attn_2 = True
837
+
838
+ def __init__(self, config: GraniteConfig, *inputs, **kwargs):
839
+ super().__init__(config, *inputs, **kwargs)
840
+
841
+ self.attention_implementation = self.config._attn_implementation
842
+ self._use_eager_attention = self.attention_implementation == "eager"
843
+ self._use_sdpa = self.attention_implementation == "sdpa"
844
+ self._use_flash_attention_2 = self.attention_implementation == "flash_attention_2"
845
+
846
+ self.initializer_range = config.initializer_range
847
+
848
+ def _init_weights(self, module: nn.Module) -> None:
849
+ if isinstance(module, (nn.LayerNorm, RMSNorm, RoPE)):
850
+ module.reset_parameters()
851
+ elif isinstance(module, nn.Linear):
852
+ nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
853
+ if module.bias is not None:
854
+ module.bias.zero_()
855
+ elif isinstance(module, nn.Embedding):
856
+ nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
857
+ if module.padding_idx is not None:
858
+ module.weight[module.padding_idx].zero_()
859
+
860
+
861
+ class GraniteModel(GranitePreTrainedModel):
862
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
863
+ mask_value = None
864
+
865
+ def __init__(self, config: GraniteConfig, **kwargs) -> None:
866
+ super().__init__(config, **kwargs)
867
+
868
+ self.attention_head_type = AttentionHeadType(config.attention_head_type)
869
+ self.embed_dim = config.hidden_size
870
+ self.num_heads = config.num_attention_heads
871
+ self.num_key_value_heads = config.num_key_value_heads
872
+
873
+ assert (
874
+ self.embed_dim % self.num_heads == 0
875
+ ), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})"
876
+
877
+ self.head_dim = self.embed_dim // self.num_heads
878
+
879
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
880
+
881
+ self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop)
882
+ self.h = nn.ModuleList(
883
+ [GraniteBlock(config, self.attention_implementation, layer_idx=i) for i in range(config.num_hidden_layers)]
884
+ )
885
+ self.ln_f = get_normalization_function(
886
+ config.normalization_function,
887
+ self.embed_dim,
888
+ eps=config.layer_norm_epsilon,
889
+ )
890
+
891
+ self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
892
+
893
+ if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
894
+ self.wpe = nn.Embedding(config.n_positions, self.embed_dim)
895
+ elif self.position_embedding_type == PositionEmbeddingType.alibi:
896
+ assert not self._use_flash_attention_2, "alibi is not implemented with FlashAttention"
897
+
898
+ self.alibi = Alibi(self.num_heads)
899
+ elif self.position_embedding_type == PositionEmbeddingType.rope:
900
+ self.rope = RoPE(self.head_dim, max_position_embeddings=config.n_positions, base=config.rope_theta)
901
+ else:
902
+ raise NotImplementedError()
903
+
904
+ # Initialize weights and apply final processing
905
+ self.post_init()
906
+
907
+ def get_input_embeddings(self) -> nn.Embedding:
908
+ return self.wte
909
+
910
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
911
+ self.wte = new_embeddings
912
+
913
+ def forward(
914
+ self,
915
+ input_ids: Optional[torch.Tensor] = None,
916
+ past_key_values: Optional[DynamicCache] = None,
917
+ attention_mask: Optional[torch.Tensor] = None,
918
+ token_type_ids: Optional[torch.Tensor] = None,
919
+ position_ids: Optional[torch.Tensor] = None,
920
+ inputs_embeds: Optional[torch.Tensor] = None,
921
+ use_cache: Optional[bool] = None,
922
+ output_hidden_states: Optional[bool] = None,
923
+ return_dict: Optional[bool] = None,
924
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
925
+ (
926
+ output_hidden_states,
927
+ use_cache,
928
+ return_dict,
929
+ input_shape,
930
+ hidden_states,
931
+ attention_mask,
932
+ position_ids,
933
+ rope_cos_sin,
934
+ past_key_values,
935
+ ) = self._prepare_a_bunch_of_stuff(
936
+ input_ids=input_ids,
937
+ past_key_values=past_key_values,
938
+ attention_mask=attention_mask,
939
+ token_type_ids=token_type_ids,
940
+ position_ids=position_ids,
941
+ inputs_embeds=inputs_embeds,
942
+ use_cache=use_cache,
943
+ output_hidden_states=output_hidden_states,
944
+ return_dict=return_dict,
945
+ )
946
+
947
+ # ==========================================================================================
948
+ # flash:
949
+ # attention_mask -> (batch_size, key_length)
950
+ # else:
951
+ # attention_mask -> (batch_size, 1, query_length, key_length)
952
+ # ==========================================================================================
953
+
954
+ output_shape = input_shape + (hidden_states.size(-1),)
955
+
956
+ past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values
957
+ all_hidden_states = () if output_hidden_states else None
958
+ for block in self.h:
959
+ if output_hidden_states:
960
+ all_hidden_states += (hidden_states,)
961
+
962
+ hidden_states = block(
963
+ hidden_states,
964
+ past_key_values=past_key_values,
965
+ attention_mask=attention_mask,
966
+ rope_cos_sin=rope_cos_sin,
967
+ )
968
+
969
+ hidden_states = self.ln_f(hidden_states)
970
+
971
+ hidden_states = hidden_states.view(output_shape)
972
+ # Add last hidden state
973
+ if output_hidden_states:
974
+ all_hidden_states += (hidden_states,)
975
+
976
+ if not return_dict:
977
+ return tuple(v for v in [hidden_states, past_key_values, all_hidden_states] if v is not None)
978
+
979
+ return BaseModelOutputWithPastAndCrossAttentions(
980
+ last_hidden_state=hidden_states,
981
+ past_key_values=past_key_values,
982
+ hidden_states=all_hidden_states,
983
+ )
984
+
985
+ def _get_position_ids(
986
+ self, attention_mask: torch.Tensor, past_length: int, query_length: int, key_length: int, device: torch.device
987
+ ) -> torch.Tensor:
988
+ if attention_mask is not None and len(attention_mask.shape) == 2:
989
+ # create position_ids on the fly for batch generation
990
+ position_ids = attention_mask.long().cumsum(-1) - 1
991
+ position_ids.masked_fill_(attention_mask == 0, 0)
992
+ if past_length > 0:
993
+ position_ids = position_ids[:, past_length:key_length:]
994
+ else:
995
+ position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device)
996
+ position_ids = position_ids.unsqueeze(0).view(-1, query_length)
997
+
998
+ return position_ids
999
+
1000
+ def _get_alibi_bias(
1001
+ self,
1002
+ attention_mask: torch.Tensor,
1003
+ batch_size: int,
1004
+ query_length: int,
1005
+ key_length: int,
1006
+ device: torch.device,
1007
+ dtype: torch.dtype,
1008
+ ) -> torch.Tensor:
1009
+ if self.position_embedding_type != PositionEmbeddingType.alibi:
1010
+ return None
1011
+
1012
+ alibi_bias = self.alibi(attention_mask, batch_size, key_length, device, dtype)
1013
+
1014
+ # ==========================================================================================
1015
+ # alibi_bias -> (batch_size, num_heads, key_length)
1016
+ # ==========================================================================================
1017
+
1018
+ alibi_bias = alibi_bias.unsqueeze(2)
1019
+ if query_length != 1:
1020
+ alibi_bias = alibi_bias.expand(-1, -1, query_length, -1)
1021
+
1022
+ # ==========================================================================================
1023
+ # alibi_bias -> (batch_size, num_heads, query_length, key_length)
1024
+ # ==========================================================================================
1025
+
1026
+ return alibi_bias
1027
+
1028
+ def _get_rope_cos_sin(
1029
+ self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype, device: torch.device
1030
+ ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
1031
+ if self.position_embedding_type == PositionEmbeddingType.rope:
1032
+ cos, sin = self.rope(key_length, dtype=dtype, device=device)
1033
+ cos = cos[position_ids].unsqueeze(1)
1034
+ sin = sin[position_ids].unsqueeze(1)
1035
+ return cos, sin
1036
+
1037
+ def _prepare_causal_attention_mask(
1038
+ self, attention_mask: torch.Tensor, batch_size: int, query_length: int, key_length: int, device: torch.device
1039
+ ) -> torch.Tensor:
1040
+ past_length = key_length - query_length
1041
+
1042
+ # ==========================================================================================
1043
+ # attention_mask -> (batch_size, key_length)
1044
+ # ==========================================================================================
1045
+
1046
+ if query_length > 1:
1047
+ # (query_length, key_length)
1048
+ causal_mask = torch.empty((query_length, key_length), dtype=torch.bool, device=device)
1049
+ causal_mask[:, past_length:] = torch.tril(
1050
+ torch.ones(query_length, query_length, dtype=torch.bool, device=device)
1051
+ )
1052
+
1053
+ if past_length > 0:
1054
+ causal_mask[:, :past_length] = True
1055
+
1056
+ # (query_length, key_length) -> (1, query_length, key_length)
1057
+ causal_mask = causal_mask.unsqueeze(0)
1058
+
1059
+ if attention_mask is None:
1060
+ # (1, query_length, key_length) -> (batch_size, query_length, key_length)
1061
+ causal_mask = causal_mask.expand(batch_size, -1, -1)
1062
+ else:
1063
+ # (1, query_length, key_length) & (batch_size, 1, key_length) -> (batch_size, query_length, key_length)
1064
+ causal_mask = causal_mask & attention_mask.unsqueeze(1).to(torch.bool)
1065
+ else:
1066
+ if attention_mask is None:
1067
+ # (batch_size, query_length, key_length)
1068
+ causal_mask = torch.ones(batch_size, query_length, key_length, dtype=torch.bool, device=device)
1069
+ else:
1070
+ # (batch_size, query_length, key_length)
1071
+ causal_mask = attention_mask.unsqueeze(1).to(dtype=torch.bool, device=device)
1072
+
1073
+ # ==========================================================================================
1074
+ # attention_mask -> (batch_size, query_length, key_length)
1075
+ # ==========================================================================================
1076
+
1077
+ causal_mask = causal_mask.unsqueeze(1)
1078
+
1079
+ # ==========================================================================================
1080
+ # attention_mask -> (batch_size, 1, query_length, key_length)
1081
+ # ==========================================================================================
1082
+
1083
+ return causal_mask
1084
+
1085
+ def _get_initial_hidden_state(
1086
+ self,
1087
+ input_ids: torch.Tensor,
1088
+ inputs_embeds: torch.Tensor,
1089
+ position_ids: torch.Tensor,
1090
+ token_type_ids: torch.Tensor,
1091
+ ) -> torch.Tensor:
1092
+ if inputs_embeds is None:
1093
+ inputs_embeds = self.wte(input_ids)
1094
+
1095
+ if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
1096
+ inputs_embeds = inputs_embeds + self.wpe(position_ids)
1097
+
1098
+ if token_type_ids is not None:
1099
+ inputs_embeds = inputs_embeds + self.wte(token_type_ids)
1100
+
1101
+ inputs_embeds = self.drop(inputs_embeds)
1102
+
1103
+ return inputs_embeds
1104
+
1105
+ def _prepare_a_bunch_of_stuff(
1106
+ self,
1107
+ input_ids: torch.Tensor = None,
1108
+ past_key_values: DynamicCache = None,
1109
+ attention_mask: torch.Tensor = None,
1110
+ token_type_ids: torch.Tensor = None,
1111
+ position_ids: torch.Tensor = None,
1112
+ inputs_embeds: torch.Tensor = None,
1113
+ use_cache: bool = None,
1114
+ output_hidden_states: bool = None,
1115
+ return_dict: bool = None,
1116
+ ) -> Tuple[
1117
+ bool,
1118
+ bool,
1119
+ bool,
1120
+ torch.Size,
1121
+ torch.Tensor,
1122
+ torch.Tensor,
1123
+ torch.Tensor,
1124
+ Optional[Tuple[torch.Tensor, torch.Tensor]],
1125
+ DynamicCache,
1126
+ ]:
1127
+ output_hidden_states = (
1128
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1129
+ )
1130
+
1131
+ use_cache = self.config.use_cache if use_cache is None else use_cache
1132
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1133
+
1134
+ if input_ids is not None and inputs_embeds is not None:
1135
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1136
+ elif input_ids is not None:
1137
+ input_shape = input_ids.size()
1138
+ elif inputs_embeds is not None:
1139
+ # TODO special handling for padding free transformer needed here if we support inputs_embeds argument
1140
+ input_shape = inputs_embeds.size()[:-1]
1141
+ else:
1142
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1143
+
1144
+ batch_size = input_shape[0]
1145
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1146
+
1147
+ if self.position_embedding_type == PositionEmbeddingType.alibi:
1148
+ if position_ids is not None:
1149
+ warnings.warn("`position_ids` have no functionality with Alibi.", FutureWarning)
1150
+
1151
+ if token_type_ids is not None:
1152
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
1153
+
1154
+ # ==========================================================================================
1155
+ # input_ids -> (batch_size, query_length)
1156
+ # attention_mask -> None or (batch_size, key_length)
1157
+ # position_ids -> None or (batch_size, key_length)
1158
+ # ==========================================================================================
1159
+
1160
+ past_length = 0 if past_key_values is None else past_key_values.get_seq_length()
1161
+ query_length = input_shape[-1]
1162
+ key_length = past_length + query_length
1163
+
1164
+ if position_ids is None:
1165
+ position_ids = self._get_position_ids(attention_mask, past_length, query_length, key_length, device)
1166
+
1167
+ # ==========================================================================================
1168
+ # input_ids -> (batch_size, query_length)
1169
+ # attention_mask -> None or (batch_size, key_length)
1170
+ # position_ids -> (batch_size, query_length)
1171
+ # ==========================================================================================
1172
+
1173
+ hidden_states = self._get_initial_hidden_state(input_ids, inputs_embeds, position_ids, token_type_ids)
1174
+
1175
+ # ==========================================================================================
1176
+ # hidden_states -> (batch_size, query_length, num_heads * head_dim)
1177
+ # ==========================================================================================
1178
+
1179
+ alibi_bias = self._get_alibi_bias(
1180
+ attention_mask, batch_size, query_length, key_length, device, hidden_states.dtype
1181
+ )
1182
+
1183
+ # ==========================================================================================
1184
+ # alibi_bias -> (batch_size, num_heads, query_length, key_length)
1185
+ # ==========================================================================================
1186
+
1187
+ rope_cos_sin = self._get_rope_cos_sin(
1188
+ key_length, position_ids, dtype=hidden_states.dtype, device=hidden_states.device
1189
+ )
1190
+
1191
+ # ==========================================================================================
1192
+ # rope_cos_sin -> 2 * (key_length, head_dim)
1193
+ # ==========================================================================================
1194
+
1195
+ # prepare causal mask only if not using flash attention
1196
+ if self._use_flash_attention_2:
1197
+ if attention_mask is None:
1198
+ attention_mask = torch.ones_like(input_ids)
1199
+ elif self._use_sdpa:
1200
+ # we use the causal/non-causal argument of SDPA for attention in this case
1201
+ if attention_mask is not None:
1202
+ attention_mask = self._prepare_causal_attention_mask(
1203
+ attention_mask, batch_size, query_length, key_length, device
1204
+ )
1205
+
1206
+ attention_mask = torch.where(
1207
+ attention_mask,
1208
+ ~attention_mask if alibi_bias is None else alibi_bias,
1209
+ self._get_mask_value(attention_mask.device, hidden_states.dtype),
1210
+ )
1211
+ else:
1212
+ attention_mask = self._prepare_causal_attention_mask(
1213
+ attention_mask, batch_size, query_length, key_length, device
1214
+ )
1215
+
1216
+ attention_mask = torch.where(
1217
+ attention_mask,
1218
+ ~attention_mask if alibi_bias is None else alibi_bias,
1219
+ self._get_mask_value(attention_mask.device, hidden_states.dtype),
1220
+ )
1221
+
1222
+ return (
1223
+ output_hidden_states,
1224
+ use_cache,
1225
+ return_dict,
1226
+ input_shape,
1227
+ hidden_states,
1228
+ attention_mask,
1229
+ position_ids,
1230
+ rope_cos_sin,
1231
+ past_key_values,
1232
+ )
1233
+
1234
+ def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
1235
+ # torch.where expects a tensor. We use a cache to avoid recreating it every time.
1236
+ if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
1237
+ self.mask_value = torch.full([], torch.finfo(torch.float16).min, dtype=dtype, device=device)
1238
+ return self.mask_value
1239
+
1240
+
1241
+ class GraniteForCausalLM(GranitePreTrainedModel):
1242
+ _keys_to_ignore_on_load_missing = ["lm_head.weight"]
1243
+
1244
+ def __init__(self, config: GraniteConfig, **kwargs) -> None:
1245
+ super().__init__(config, **kwargs)
1246
+ self.transformer = GraniteModel(config, **kwargs)
1247
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1248
+
1249
+ # Initialize weights and apply final processing
1250
+ self.post_init()
1251
+
1252
+ def get_input_embeddings(self) -> nn.Embedding:
1253
+ return self.transformer.wte
1254
+
1255
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
1256
+ self.transformer.wte = value
1257
+
1258
+ def get_output_embeddings(self) -> nn.Linear:
1259
+ return self.lm_head
1260
+
1261
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
1262
+ self.lm_head = new_embeddings
1263
+
1264
+ # FIXME typing
1265
+ def prepare_inputs_for_generation(
1266
+ self,
1267
+ input_ids: torch.Tensor,
1268
+ past_key_values: Optional[DynamicCache] = None,
1269
+ inputs_embeds: Optional[torch.Tensor] = None,
1270
+ **kwargs,
1271
+ ) -> dict:
1272
+ token_type_ids = kwargs.get("token_type_ids", None)
1273
+ # Omit tokens covered by past_key_values
1274
+ if past_key_values:
1275
+ past_length = past_key_values.get_seq_length()
1276
+
1277
+ # Some generation methods already pass only the last input ID
1278
+ if input_ids.shape[1] > past_length:
1279
+ remove_prefix_length = past_length
1280
+ else:
1281
+ # Default to old behavior: keep only final ID
1282
+ remove_prefix_length = input_ids.shape[1] - 1
1283
+
1284
+ input_ids = input_ids[:, remove_prefix_length:]
1285
+ if token_type_ids is not None:
1286
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
1287
+
1288
+ attention_mask = kwargs.get("attention_mask", None)
1289
+ position_ids = kwargs.get("position_ids", None)
1290
+
1291
+ if attention_mask is not None and position_ids is None:
1292
+ # create position_ids on the fly for batch generation
1293
+ position_ids = attention_mask.long().cumsum(-1) - 1
1294
+ position_ids.masked_fill_(attention_mask == 0, 0)
1295
+ if past_key_values:
1296
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1297
+ else:
1298
+ position_ids = None
1299
+
1300
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1301
+ if inputs_embeds is not None and past_key_values is None:
1302
+ model_inputs = {"inputs_embeds": inputs_embeds}
1303
+ else:
1304
+ model_inputs = {"input_ids": input_ids}
1305
+
1306
+ model_inputs.update(
1307
+ {
1308
+ "past_key_values": past_key_values,
1309
+ "use_cache": kwargs.get("use_cache"),
1310
+ "position_ids": position_ids,
1311
+ "attention_mask": attention_mask,
1312
+ "token_type_ids": token_type_ids,
1313
+ }
1314
+ )
1315
+ return model_inputs
1316
+
1317
+ def forward(
1318
+ self,
1319
+ input_ids: Optional[Union[torch.Tensor]] = None,
1320
+ past_key_values: Optional[DynamicCache] = None,
1321
+ attention_mask: Optional[torch.Tensor] = None,
1322
+ token_type_ids: Optional[Union[torch.Tensor]] = None,
1323
+ position_ids: Optional[Union[torch.Tensor]] = None,
1324
+ inputs_embeds: Optional[Union[torch.Tensor]] = None,
1325
+ labels: Optional[Union[torch.Tensor]] = None,
1326
+ use_cache: Optional[bool] = None,
1327
+ output_attentions: Optional[bool] = None,
1328
+ output_hidden_states: Optional[bool] = None,
1329
+ return_dict: Optional[bool] = None,
1330
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1331
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1332
+
1333
+ # ==========================================================================================
1334
+ # input_ids -> (batch_size, query_length)
1335
+ # attention_mask -> None or (batch_size, key_length)
1336
+ # position_ids -> None or (batch_size, key_length)
1337
+ # ==========================================================================================
1338
+
1339
+ transformer_outputs = self.transformer(
1340
+ input_ids,
1341
+ past_key_values=past_key_values,
1342
+ attention_mask=attention_mask,
1343
+ token_type_ids=token_type_ids,
1344
+ position_ids=position_ids,
1345
+ inputs_embeds=inputs_embeds,
1346
+ use_cache=use_cache,
1347
+ output_hidden_states=output_hidden_states,
1348
+ return_dict=return_dict,
1349
+ )
1350
+ hidden_states = transformer_outputs[0]
1351
+
1352
+ lm_logits = self.lm_head(hidden_states)
1353
+
1354
+ loss = None
1355
+ # Shift so that tokens < n predict n
1356
+ if labels is not None:
1357
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1358
+ shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
1359
+
1360
+ # Flatten the tokens
1361
+ loss_fct = nn.CrossEntropyLoss()
1362
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1363
+
1364
+ if not return_dict:
1365
+ output = (lm_logits,) + transformer_outputs[1:]
1366
+ return ((loss,) + output) if loss is not None else output
1367
+
1368
+ return CausalLMOutputWithCrossAttentions(
1369
+ loss=loss,
1370
+ logits=lm_logits,
1371
+ past_key_values=transformer_outputs.past_key_values,
1372
+ hidden_states=transformer_outputs.hidden_states,
1373
+ attentions=transformer_outputs.attentions,
1374
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|endoftext|>",
4
+ "<fim_prefix>",
5
+ "<fim_middle>",
6
+ "<fim_suffix>",
7
+ "<fim_pad>",
8
+ "<filename>",
9
+ "<gh_stars>",
10
+ "<issue_start>",
11
+ "<issue_comment>",
12
+ "<issue_closed>",
13
+ "<jupyter_start>",
14
+ "<jupyter_text>",
15
+ "<jupyter_code>",
16
+ "<jupyter_output>",
17
+ "<empty_output>",
18
+ "<commit_before>",
19
+ "<commit_msg>",
20
+ "<commit_after>",
21
+ "<reponame>"
22
+ ],
23
+ "bos_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "eos_token": {
31
+ "content": "<|endoftext|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "pad_token": {
38
+ "content": "<|endoftext|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<|endoftext|>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "<fim_prefix>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<fim_middle>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "<fim_suffix>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "4": {
37
+ "content": "<fim_pad>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "5": {
45
+ "content": "<filename>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "6": {
53
+ "content": "<gh_stars>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "7": {
61
+ "content": "<issue_start>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "8": {
69
+ "content": "<issue_comment>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "9": {
77
+ "content": "<issue_closed>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "10": {
85
+ "content": "<jupyter_start>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "11": {
93
+ "content": "<jupyter_text>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "12": {
101
+ "content": "<jupyter_code>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "13": {
109
+ "content": "<jupyter_output>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "14": {
117
+ "content": "<empty_output>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "15": {
125
+ "content": "<commit_before>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "16": {
133
+ "content": "<commit_msg>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "17": {
141
+ "content": "<commit_after>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "18": {
149
+ "content": "<reponame>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ }
156
+ },
157
+ "additional_special_tokens": [
158
+ "<|endoftext|>",
159
+ "<fim_prefix>",
160
+ "<fim_middle>",
161
+ "<fim_suffix>",
162
+ "<fim_pad>",
163
+ "<filename>",
164
+ "<gh_stars>",
165
+ "<issue_start>",
166
+ "<issue_comment>",
167
+ "<issue_closed>",
168
+ "<jupyter_start>",
169
+ "<jupyter_text>",
170
+ "<jupyter_code>",
171
+ "<jupyter_output>",
172
+ "<empty_output>",
173
+ "<commit_before>",
174
+ "<commit_msg>",
175
+ "<commit_after>",
176
+ "<reponame>"
177
+ ],
178
+ "bos_token": "<|endoftext|>",
179
+ "clean_up_tokenization_spaces": true,
180
+ "eos_token": "<|endoftext|>",
181
+ "model_max_length": 9223372036854775807,
182
+ "pad_token": "<|endoftext|>",
183
+ "padding_side": "left",
184
+ "tokenizer_class": "GPT2Tokenizer",
185
+ "unk_token": "<|endoftext|>",
186
+ "vocab_size": 49152
187
+ }