ruixie commited on
Commit
a749dfd
·
1 Parent(s): 75110a3

Upload folder using huggingface_hub

Browse files
config.json CHANGED
@@ -1,36 +1,37 @@
1
  {
2
- "_name_or_path": "/nvme/share/shellm/hf_configs/7b_gq_rope",
3
  "activation_function": "gelu_pytorch_tanh",
4
  "architectures": [
5
- "KCLGPTForCausalLM"
6
  ],
7
  "attention_softmax_in_fp32": true,
8
  "attn_pdrop": 0.1,
9
  "auto_map": {
10
- "AutoConfig": "configuration_kclgpt.KCLGPTConfig",
11
- "AutoModelForCausalLM": "modeling_kclgpt.KCLGPTForCausalLM"
12
  },
 
 
 
13
  "bos_token_id": 70000,
14
- "embd_pdrop": 0.1,
15
  "eos_token_id": 70000,
16
- "group_query_attention": true,
 
17
  "inference_runner": 0,
18
  "initializer_range": 0.02,
19
  "layer_norm_epsilon": 1e-05,
20
  "max_batch_size": null,
21
  "max_sequence_length": null,
22
  "model_type": "kclgpt",
 
23
  "n_embd": 4096,
24
- "n_head": 32,
25
  "n_inner": 16384,
26
- "n_layer": 42,
27
  "n_positions": 8192,
28
- "num_query_groups": 8,
29
  "pad_key_length": true,
30
- "position_embedding_type": "rope",
31
- "pre_allocate_kv_cache": false,
32
  "resid_pdrop": 0.1,
33
  "rope_scaling": null,
 
34
  "scale_attention_softmax_in_fp32": true,
35
  "scale_attn_weights": true,
36
  "summary_activation": null,
@@ -39,8 +40,7 @@
39
  "summary_type": "cls_index",
40
  "summary_use_proj": true,
41
  "torch_dtype": "bfloat16",
42
- "transformers_version": "4.29.2",
43
  "use_cache": true,
44
- "validate_runner_input": true,
45
- "vocab_size": 70144
46
  }
 
1
  {
2
+ "_name_or_path": "WisdomShell/CodeShell",
3
  "activation_function": "gelu_pytorch_tanh",
4
  "architectures": [
5
+ "CodeShellForCausalLM"
6
  ],
7
  "attention_softmax_in_fp32": true,
8
  "attn_pdrop": 0.1,
9
  "auto_map": {
10
+ "AutoConfig": "configuration_codeshell.CodeShellConfig",
11
+ "AutoModelForCausalLM": "modeling_codeshell.CodeShellForCausalLM"
12
  },
13
+ "group_query_attention": true,
14
+ "num_query_groups": 8,
15
+ "position_embedding_type": "rope",
16
  "bos_token_id": 70000,
 
17
  "eos_token_id": 70000,
18
+ "vocab_size": 70144,
19
+ "embd_pdrop": 0.1,
20
  "inference_runner": 0,
21
  "initializer_range": 0.02,
22
  "layer_norm_epsilon": 1e-05,
23
  "max_batch_size": null,
24
  "max_sequence_length": null,
25
  "model_type": "kclgpt",
26
+ "n_layer": 42,
27
  "n_embd": 4096,
 
28
  "n_inner": 16384,
29
+ "n_head": 32,
30
  "n_positions": 8192,
 
31
  "pad_key_length": true,
 
 
32
  "resid_pdrop": 0.1,
33
  "rope_scaling": null,
34
+ "pre_allocate_kv_cache": false,
35
  "scale_attention_softmax_in_fp32": true,
36
  "scale_attn_weights": true,
37
  "summary_activation": null,
 
40
  "summary_type": "cls_index",
41
  "summary_use_proj": true,
42
  "torch_dtype": "bfloat16",
43
+ "transformers_version": "4.31.0",
44
  "use_cache": true,
45
+ "validate_runner_input": true
 
46
  }
configuration_codeshell.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # This code is based on Bigcode's GPTBigCode configuration. It has been modified from
17
+ # its original forms to accommodate minor architectural differences compared to
18
+ # GPTBigCode Configuration that trained the model.
19
+
20
+ # coding=utf-8
21
+ # Copyright 2023 The BigCode team and HuggingFace Inc. team.
22
+ #
23
+ # Licensed under the Apache License, Version 2.0 (the "License");
24
+ # you may not use this file except in compliance with the License.
25
+ # You may obtain a copy of the License at
26
+ #
27
+ # http://www.apache.org/licenses/LICENSE-2.0
28
+ #
29
+ # Unless required by applicable law or agreed to in writing, software
30
+ # distributed under the License is distributed on an "AS IS" BASIS,
31
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32
+ # See the License for the specific language governing permissions and
33
+ # limitations under the License.
34
+ """ CodeShell configuration"""
35
+
36
+ from transformers.configuration_utils import PretrainedConfig
37
+ from transformers.utils import logging
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+
43
+ class CodeShellConfig(PretrainedConfig):
44
+ """
45
+ This is the configuration class to store the configuration of a [`CodeShellModel`]. It is used to instantiate a
46
+ CodeShell model according to the specified arguments, defining the model architecture.
47
+
48
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
49
+ documentation from [`PretrainedConfig`] for more information.
50
+
51
+ Args:
52
+ vocab_size (`int`, *optional*, defaults to 50257):
53
+ Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
54
+ `inputs_ids` passed when calling [`CodeShellModel`].
55
+ n_positions (`int`, *optional*, defaults to 1024):
56
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
57
+ just in case (e.g., 512 or 1024 or 2048).
58
+ n_embd (`int`, *optional*, defaults to 768):
59
+ Dimensionality of the embeddings and hidden states.
60
+ n_layer (`int`, *optional*, defaults to 12):
61
+ Number of hidden layers in the Transformer encoder.
62
+ n_head (`int`, *optional*, defaults to 12):
63
+ Number of attention heads for each attention layer in the Transformer encoder.
64
+ n_inner (`int`, *optional*, defaults to None):
65
+ Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
66
+ activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
67
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new",
68
+ "gelu_pytorch_tanh"]`.
69
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
70
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
71
+ embd_pdrop (`float`, *optional*, defaults to 0.1):
72
+ The dropout ratio for the embeddings.
73
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
74
+ The dropout ratio for the attention.
75
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
76
+ The epsilon to use in the layer normalization layers.
77
+ initializer_range (`float`, *optional*, defaults to 0.02):
78
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
79
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
80
+ Scale attention weights by dividing by sqrt(hidden_size)..
81
+ use_cache (`bool`, *optional*, defaults to `True`):
82
+ Whether or not the model should return the last key/values attentions (not used by all models).
83
+ attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
84
+ Whether to call the fused softmax in float32.
85
+ scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
86
+ Whether to scale the attention softmax in float32.
87
+ attention_type (`bool`, *optional*, defaults to `True`):
88
+ Whether to use Multi-Query Attion (`True`) or Multi-Head Attention (`False`).
89
+ """
90
+
91
+ model_type = "codeshell"
92
+ keys_to_ignore_at_inference = ["past_key_values"]
93
+ attribute_map = {
94
+ "hidden_size": "n_embd",
95
+ "max_position_embeddings": "n_positions",
96
+ "num_attention_heads": "n_head",
97
+ "num_hidden_layers": "n_layer",
98
+ }
99
+
100
+ def __init__(
101
+ self,
102
+ vocab_size=50257,
103
+ n_positions=1024,
104
+ n_embd=768,
105
+ n_layer=12,
106
+ n_head=12,
107
+ n_inner=None,
108
+ activation_function="gelu_pytorch_tanh",
109
+ resid_pdrop=0.1,
110
+ embd_pdrop=0.1,
111
+ attn_pdrop=0.1,
112
+ layer_norm_epsilon=1e-5,
113
+ initializer_range=0.02,
114
+ scale_attn_weights=True,
115
+ use_cache=True,
116
+ bos_token_id=50256,
117
+ eos_token_id=50256,
118
+ attention_softmax_in_fp32=True,
119
+ scale_attention_softmax_in_fp32=True,
120
+ group_query_attention=True,
121
+ num_query_groups=1,
122
+ position_embedding_type="learned_absolute",
123
+ rope_scaling=None,
124
+ **kwargs,
125
+ ):
126
+ self.vocab_size = vocab_size
127
+ self.n_positions = n_positions
128
+ self.n_embd = n_embd
129
+ self.n_layer = n_layer
130
+ self.n_head = n_head
131
+ self.n_inner = n_inner
132
+ self.activation_function = activation_function
133
+ self.resid_pdrop = resid_pdrop
134
+ self.embd_pdrop = embd_pdrop
135
+ self.attn_pdrop = attn_pdrop
136
+ self.layer_norm_epsilon = layer_norm_epsilon
137
+ self.initializer_range = initializer_range
138
+ self.scale_attn_weights = scale_attn_weights
139
+ self.use_cache = use_cache
140
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
141
+ self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
142
+ self.group_query_attention = group_query_attention
143
+ self.num_query_groups = num_query_groups
144
+ self.position_embedding_type = position_embedding_type
145
+ self.rope_scaling = rope_scaling
146
+ assert self.position_embedding_type in [
147
+ "learned_absolute", "rope"
148
+ ], "position_embedding_type must be one of ['learned_absolute', 'rope']"
149
+
150
+ self.bos_token_id = bos_token_id
151
+ self.eos_token_id = eos_token_id
152
+
153
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
generation_config.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "_from_model_config": true,
3
- "bos_token_id": 70000,
4
- "eos_token_id": 70000,
5
- "transformers_version": "4.29.2"
6
  }
 
1
  {
2
  "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 0,
5
+ "transformers_version": "4.31.0"
6
  }
modeling_codeshell.py ADDED
@@ -0,0 +1,967 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # This code is based on Bigcode's GPTBigCode model. It has been modified from
17
+ # its original forms to accommodate minor architectural differences compared to
18
+ # GPTBigCode model that trained the model.
19
+
20
+ # Copyright 2023 The Bigcode team and HuggingFace Inc. team.
21
+ # Licensed under the Apache License, Version 2.0 (the "License");
22
+ # you may not use this file except in compliance with the License.
23
+ # You may obtain a copy of the License at
24
+ #
25
+ # http://www.apache.org/licenses/LICENSE-2.0
26
+ #
27
+ # Unless required by applicable law or agreed to in writing, software
28
+ # distributed under the License is distributed on an "AS IS" BASIS,
29
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30
+ # See the License for the specific language governing permissions and
31
+ # limitations under the License.
32
+
33
+ """PyTorch CodeShellGPT model."""
34
+ import math
35
+ from typing import List, Optional, Tuple, Union
36
+
37
+ import torch
38
+ import torch.utils.checkpoint
39
+ from torch import nn
40
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
41
+
42
+ from transformers.activations import ACT2FN
43
+ from transformers.modeling_outputs import (
44
+ BaseModelOutputWithPastAndCrossAttentions,
45
+ CausalLMOutputWithCrossAttentions,
46
+ )
47
+ from transformers.modeling_utils import PreTrainedModel
48
+ from transformers.utils import (
49
+ add_start_docstrings,
50
+ add_start_docstrings_to_model_forward,
51
+ logging,
52
+ )
53
+ from .configuration_codeshell import CodeShellConfig
54
+
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ # Fused kernels
59
+ # Use separate functions for each case because conditionals prevent kernel fusion.
60
+ # TODO: Could have better fused kernels depending on scaling, dropout and head mask.
61
+ # Is it doable without writing 32 functions?
62
+ @torch.jit.script
63
+ def upcast_masked_softmax(
64
+ x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
65
+ ):
66
+ input_dtype = x.dtype
67
+ x = x.to(softmax_dtype) * scale
68
+ x = torch.where(mask, x, mask_value)
69
+ x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
70
+ return x
71
+
72
+
73
+ @torch.jit.script
74
+ def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
75
+ input_dtype = x.dtype
76
+ x = x.to(softmax_dtype) * scale
77
+ x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
78
+ return x
79
+
80
+
81
+ @torch.jit.script
82
+ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
83
+ x = torch.where(mask, x, mask_value)
84
+ x = torch.nn.functional.softmax(x, dim=-1)
85
+ return x
86
+
87
+
88
+ class LlamaRotaryEmbedding(torch.nn.Module):
89
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
90
+ super().__init__()
91
+
92
+ self.dim = dim
93
+ self.max_position_embeddings = max_position_embeddings
94
+ self.base = base
95
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
96
+ self.register_buffer("inv_freq", inv_freq)
97
+
98
+ # Build here to make `torch.jit.trace` work.
99
+ self._set_cos_sin_cache(
100
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
101
+ )
102
+
103
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
104
+ self.max_seq_len_cached = seq_len
105
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
106
+
107
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
108
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
109
+ emb = torch.cat((freqs, freqs), dim=-1)
110
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
111
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
112
+
113
+ def forward(self, x, seq_len=None):
114
+ # x: [bs, num_attention_heads, seq_len, head_size]
115
+ if seq_len > self.max_seq_len_cached:
116
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
117
+
118
+ return (
119
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
120
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
121
+ )
122
+
123
+
124
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
125
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
126
+
127
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
128
+ self.scaling_factor = scaling_factor
129
+ super().__init__(dim, max_position_embeddings, base, device)
130
+
131
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
132
+ self.max_seq_len_cached = seq_len
133
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
134
+ t = t / self.scaling_factor
135
+
136
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
137
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
138
+ emb = torch.cat((freqs, freqs), dim=-1)
139
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
140
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
141
+
142
+
143
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
144
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
145
+
146
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
147
+ self.scaling_factor = scaling_factor
148
+ super().__init__(dim, max_position_embeddings, base, device)
149
+
150
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
151
+ self.max_seq_len_cached = seq_len
152
+
153
+ if seq_len > self.max_position_embeddings:
154
+ base = self.base * (
155
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
156
+ ) ** (self.dim / (self.dim - 2))
157
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
158
+ self.register_buffer("inv_freq", inv_freq)
159
+
160
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
161
+
162
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
163
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
164
+ emb = torch.cat((freqs, freqs), dim=-1)
165
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
166
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
167
+
168
+
169
+ def rotate_half(x):
170
+ """Rotates half the hidden dims of the input."""
171
+ x1 = x[..., : x.shape[-1] // 2]
172
+ x2 = x[..., x.shape[-1] // 2 :]
173
+ return torch.cat((-x2, x1), dim=-1)
174
+
175
+
176
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
177
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
178
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
179
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
180
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
181
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
182
+ q_embed = (q * cos) + (rotate_half(q) * sin)
183
+ k_embed = (k * cos) + (rotate_half(k) * sin)
184
+ return q_embed, k_embed
185
+
186
+
187
+ class CodeShellAttention(nn.Module):
188
+ def __init__(self, config, layer_idx=None):
189
+ super().__init__()
190
+ self.mask_value = None
191
+
192
+ self.position_embedding_type = config.position_embedding_type
193
+ self.rope_scaling = config.rope_scaling
194
+ self.max_position_embeddings = config.max_position_embeddings
195
+
196
+ self.group_query_attention = config.group_query_attention
197
+ self.num_query_groups = config.num_query_groups
198
+
199
+ self.embed_dim = config.hidden_size
200
+ self.num_heads = config.num_attention_heads
201
+ self.head_dim = self.embed_dim // self.num_heads
202
+ self.kv_heads = config.num_query_groups if self.group_query_attention else self.num_heads
203
+ self.kv_dim = self.kv_heads * self.head_dim
204
+ self.split_size = self.embed_dim
205
+ if self.head_dim * self.num_heads != self.embed_dim:
206
+ raise ValueError(
207
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
208
+ f" {self.num_heads})."
209
+ )
210
+
211
+ self.scale_attn_weights = config.scale_attn_weights
212
+
213
+ self.layer_idx = layer_idx
214
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
215
+ self.scale_attention_softmax_in_fp32 = (
216
+ config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
217
+ )
218
+
219
+ self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
220
+
221
+ self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
222
+
223
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
224
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
225
+
226
+ if self.position_embedding_type == "rope":
227
+ self._init_rope()
228
+
229
+ def _init_rope(self):
230
+ if self.rope_scaling is None:
231
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
232
+ else:
233
+ scaling_type = self.rope_scaling["type"]
234
+ scaling_factor = self.rope_scaling["factor"]
235
+ if scaling_type == "linear":
236
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
237
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
238
+ )
239
+ elif scaling_type == "dynamic":
240
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
241
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
242
+ )
243
+ else:
244
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
245
+
246
+
247
+ def _get_mask_value(self, device, dtype):
248
+ # torch.where expects a tensor. We use a cache to avoid recreating it every time.
249
+ if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
250
+ self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
251
+ return self.mask_value
252
+
253
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
254
+ dtype = query.dtype
255
+ softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
256
+ upcast = dtype != softmax_dtype
257
+
258
+ unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
259
+ scale_factor = unscale**-1
260
+ if self.scale_attn_weights:
261
+ scale_factor /= self.head_dim**0.5
262
+
263
+ # [b, np, sq, sk]
264
+ output_size = (query.size(1),
265
+ query.size(2),
266
+ query.size(0),
267
+ key.size(0))
268
+ attn_view = (output_size[0]*output_size[1], output_size[2], output_size[3])
269
+
270
+ # [sq, b, np, hn] -> [sq, b * np, hn]
271
+ query = query.reshape(output_size[2],
272
+ output_size[0] * output_size[1], -1)
273
+ # [sk, b, np, hn] -> [sk, b * np, hn]
274
+ key = key.reshape(output_size[3],
275
+ output_size[0] * output_size[1], -1)
276
+ attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
277
+ if query.device.type == "cpu":
278
+ # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
279
+ # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
280
+ # but the fix has not been released as of pytorch version 2.0.0.
281
+ attn_weights = torch.zeros_like(attn_weights)
282
+ beta = 1
283
+ else:
284
+ beta = 0
285
+
286
+ attn_weights = torch.baddbmm(attn_weights,
287
+ query.transpose(0, 1),
288
+ key.transpose(0, 1).transpose(1, 2),
289
+ beta=beta, alpha=scale_factor).reshape(output_size)
290
+
291
+ if upcast:
292
+ # Use a fused kernel to prevent a large overhead from casting and scaling.
293
+ # Sub-optimal when the key length is not a multiple of 8.
294
+ if attention_mask is None:
295
+ attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
296
+ else:
297
+ mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
298
+ attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
299
+ else:
300
+ if attention_mask is not None:
301
+ mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
302
+
303
+ # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
304
+ attn_weights = torch.where(attention_mask, attn_weights, mask_value)
305
+
306
+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
307
+
308
+ attn_weights = self.attn_dropout(attn_weights)
309
+
310
+ attn_weights = attn_weights.reshape(attn_view)
311
+
312
+ # value_layer -> context layer.
313
+ # [sk, b, np, hn] --> [b, np, sq, hn]
314
+
315
+ # context layer shape: [b, np, sq, hn]
316
+ output_size = (value.size(1),
317
+ value.size(2),
318
+ query.size(0),
319
+ value.size(3))
320
+
321
+ # change view [sk, b * np, hn]
322
+ value = value.reshape(value.size(0),
323
+ output_size[0] * output_size[1], -1)
324
+ attn_output = torch.bmm(attn_weights, value.transpose(0, 1))
325
+
326
+ # change view [b, np, sq, hn]
327
+ attn_output = attn_output.reshape(*output_size)
328
+ # [b, np, sq, hn] --> [sq, b, np, hn]
329
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous()
330
+
331
+ # [sq, b, np, hn] --> [sq, b, hp]
332
+ attn_output = attn_output.reshape(attn_output.size(0), attn_output.size(1), -1)
333
+
334
+ return attn_output, attn_weights
335
+
336
+ def forward(
337
+ self,
338
+ hidden_states: torch.Tensor,
339
+ layer_past: Optional[torch.Tensor] = None,
340
+ attention_mask: Optional[torch.Tensor] = None,
341
+ position_ids: Optional[torch.LongTensor] = None,
342
+ head_mask: Optional[torch.Tensor] = None,
343
+ encoder_hidden_states: Optional[torch.Tensor] = None,
344
+ encoder_attention_mask: Optional[torch.Tensor] = None,
345
+ use_cache: Optional[bool] = False,
346
+ output_attentions: Optional[bool] = False,
347
+ ) -> Union[
348
+ Tuple[torch.Tensor, Optional[torch.Tensor]],
349
+ Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
350
+ ]:
351
+ if self.group_query_attention:
352
+ query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
353
+ else:
354
+ # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
355
+ # i.e., the memory layout is not the same as GPT2.
356
+ # This makes the concatenation with past_key_value more efficient.
357
+ query, key_value = (
358
+ self.c_attn(hidden_states)
359
+ .reshape(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
360
+ .transpose(1, 2)
361
+ .split((self.head_dim, 2 * self.head_dim), dim=3)
362
+ )
363
+
364
+ query = query.reshape(query.size(0), query.size(1), -1, self.head_dim)
365
+
366
+ key, value = key_value.split((self.head_dim*self.num_query_groups, self.head_dim*self.num_query_groups), dim=-1)
367
+ # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
368
+ key = key.reshape(key.size(0), key.size(1), -1, self.head_dim)
369
+ value = value.reshape(value.size(0), value.size(1), -1, self.head_dim)
370
+
371
+ key = key.repeat_interleave(
372
+ self.num_heads // self.num_query_groups,
373
+ dim = 2
374
+ )
375
+ value = value.repeat_interleave(
376
+ self.num_heads // self.num_query_groups,
377
+ dim = 2
378
+ )
379
+
380
+ if self.position_embedding_type == "rope":
381
+ kv_seq_len = key.shape[-3]
382
+ if layer_past is not None:
383
+ kv_seq_len += layer_past[0].shape[-3]
384
+
385
+ cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
386
+ query = query.transpose(1, 2).contiguous()
387
+ key = key.transpose(1, 2).contiguous()
388
+ query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
389
+ query = query.transpose(1, 2).contiguous()
390
+ key = key.transpose(1, 2).contiguous()
391
+
392
+ if layer_past is not None:
393
+ key = torch.cat((layer_past[0], key), dim=-3)
394
+ value = torch.cat((layer_past[1], value), dim=-3)
395
+ present = (key, value) if use_cache else None
396
+
397
+ attn_output, attn_weights = self._attn(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1), attention_mask, head_mask)
398
+
399
+ attn_output = attn_output.transpose(0, 1).reshape(hidden_states.shape)
400
+ attn_output = self.c_proj(attn_output)
401
+ attn_output = self.resid_dropout(attn_output)
402
+
403
+ outputs = (attn_output, present)
404
+ if output_attentions:
405
+ if self.group_query_attention:
406
+ # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
407
+ attn_weights = attn_weights.transpose(1, 2)
408
+ outputs += (attn_weights,)
409
+
410
+ return outputs # a, present, (attentions)
411
+
412
+
413
+ class CodeShellMLP(nn.Module):
414
+ def __init__(self, intermediate_size, config):
415
+ super().__init__()
416
+ embed_dim = config.hidden_size
417
+ self.c_fc = nn.Linear(embed_dim, intermediate_size)
418
+ self.c_proj = nn.Linear(intermediate_size, embed_dim)
419
+ self.act = ACT2FN[config.activation_function]
420
+ self.dropout = nn.Dropout(config.resid_pdrop)
421
+
422
+ # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
423
+ def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
424
+ hidden_states = self.c_fc(hidden_states)
425
+ hidden_states = self.act(hidden_states)
426
+ hidden_states = self.c_proj(hidden_states)
427
+ hidden_states = self.dropout(hidden_states)
428
+ return hidden_states
429
+
430
+
431
+ class CodeShellBlock(nn.Module):
432
+ def __init__(self, config, layer_idx=None):
433
+ super().__init__()
434
+ hidden_size = config.hidden_size
435
+ self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
436
+
437
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
438
+ self.attn = CodeShellAttention(config, layer_idx=layer_idx)
439
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
440
+
441
+ self.mlp = CodeShellMLP(self.inner_dim, config)
442
+
443
+ def forward(
444
+ self,
445
+ hidden_states: Optional[Tuple[torch.Tensor]],
446
+ layer_past: Optional[torch.Tensor] = None,
447
+ attention_mask: Optional[torch.Tensor] = None,
448
+ position_ids: Optional[torch.LongTensor] = None,
449
+ head_mask: Optional[torch.Tensor] = None,
450
+ encoder_hidden_states: Optional[torch.Tensor] = None,
451
+ encoder_attention_mask: Optional[torch.Tensor] = None,
452
+ use_cache: Optional[bool] = False,
453
+ output_attentions: Optional[bool] = False,
454
+ ) -> Union[
455
+ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
456
+ ]:
457
+ residual = hidden_states
458
+ hidden_states = self.ln_1(hidden_states)
459
+ attn_outputs = self.attn(
460
+ hidden_states,
461
+ layer_past=layer_past,
462
+ attention_mask=attention_mask,
463
+ position_ids=position_ids,
464
+ head_mask=head_mask,
465
+ use_cache=use_cache,
466
+ output_attentions=output_attentions,
467
+ )
468
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
469
+
470
+ outputs = attn_outputs[1:]
471
+ # residual connection
472
+ hidden_states = attn_output + residual
473
+
474
+ residual = hidden_states
475
+ hidden_states = self.ln_2(hidden_states)
476
+ feed_forward_hidden_states = self.mlp(hidden_states)
477
+ # residual connection
478
+ hidden_states = residual + feed_forward_hidden_states
479
+
480
+ if use_cache:
481
+ outputs = (hidden_states,) + outputs
482
+ else:
483
+ outputs = (hidden_states,) + outputs[1:]
484
+
485
+ return outputs # hidden_states, present, (attentions, cross_attentions)
486
+
487
+
488
+ class CodeShellPreTrainedModel(PreTrainedModel):
489
+ """
490
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
491
+ models.
492
+ """
493
+
494
+ config_class = CodeShellConfig
495
+ base_model_prefix = "transformer"
496
+ supports_gradient_checkpointing = True
497
+ _no_split_modules = ["CodeShellBlock"]
498
+ _skip_keys_device_placement = "past_key_values"
499
+
500
+ def __init__(self, *inputs, **kwargs):
501
+ super().__init__(*inputs, **kwargs)
502
+
503
+ def _init_weights(self, module):
504
+ """Initialize the weights."""
505
+ if isinstance(module, (CodeShellMLP, CodeShellAttention)):
506
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
507
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
508
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
509
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
510
+ #
511
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
512
+ module.c_proj.weight.data.normal_(
513
+ mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
514
+ )
515
+ module.c_proj._is_hf_initialized = True
516
+ elif isinstance(module, nn.Linear):
517
+ # Slightly different from the TF version which uses truncated_normal for initialization
518
+ # cf https://github.com/pytorch/pytorch/pull/5617
519
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
520
+ if module.bias is not None:
521
+ module.bias.data.zero_()
522
+ elif isinstance(module, nn.Embedding):
523
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
524
+ if module.padding_idx is not None:
525
+ module.weight.data[module.padding_idx].zero_()
526
+ elif isinstance(module, nn.LayerNorm):
527
+ module.bias.data.zero_()
528
+ module.weight.data.fill_(1.0)
529
+
530
+ # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->CodeShell
531
+ def _set_gradient_checkpointing(self, module, value=False):
532
+ if isinstance(module, CodeShellModel):
533
+ module.gradient_checkpointing = value
534
+
535
+
536
+ GPT_BIGCODE_START_DOCSTRING = r"""
537
+
538
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
539
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
540
+ etc.)
541
+
542
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
543
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
544
+ and behavior.
545
+
546
+ Parameters:
547
+ config ([`CodeShellConfig`]): Model configuration class with all the parameters of the model.
548
+ Initializing with a config file does not load the weights associated with the model, only the
549
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
550
+ """
551
+
552
+ GPT_BIGCODE_INPUTS_DOCSTRING = r"""
553
+ Args:
554
+ input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
555
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
556
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
557
+ sequence tokens in the vocabulary.
558
+
559
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
560
+ `input_ids`.
561
+
562
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
563
+ [`PreTrainedTokenizer.__call__`] for details.
564
+
565
+ [What are input IDs?](../glossary#input-ids)
566
+ past_key_values (`Tuple[torch.Tensor]` of length `config.n_layers`):
567
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
568
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
569
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
570
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
571
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
572
+
573
+ - 1 for tokens that are **not masked**,
574
+ - 0 for tokens that are **masked**.
575
+
576
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
577
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
578
+ `len(past_key_values) + len(input_ids)`
579
+
580
+ [What are attention masks?](../glossary#attention-mask)
581
+ token_type_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
582
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
583
+ 1]`:
584
+
585
+ - 0 corresponds to a *sentence A* token,
586
+ - 1 corresponds to a *sentence B* token.
587
+
588
+ [What are token type IDs?](../glossary#token-type-ids)
589
+ position_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
590
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
591
+ config.max_position_embeddings - 1]`.
592
+
593
+ [What are position IDs?](../glossary#position-ids)
594
+ head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
595
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
596
+
597
+ - 1 indicates the head is **not masked**,
598
+ - 0 indicates the head is **masked**.
599
+
600
+ inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
601
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
602
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
603
+ model's internal embedding lookup matrix.
604
+
605
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
606
+ `past_key_values`).
607
+ use_cache (`bool`, *optional*):
608
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
609
+ `past_key_values`).
610
+ output_attentions (`bool`, *optional*):
611
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
612
+ tensors for more detail.
613
+ output_hidden_states (`bool`, *optional*):
614
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
615
+ more detail.
616
+ return_dict (`bool`, *optional*):
617
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
618
+ """
619
+
620
+
621
+ @add_start_docstrings(
622
+ "The bare GPT_BIGCODE Model transformer outputting raw hidden-states without any specific head on top.",
623
+ GPT_BIGCODE_START_DOCSTRING,
624
+ )
625
+ class CodeShellModel(CodeShellPreTrainedModel):
626
+ def __init__(self, config):
627
+ super().__init__(config)
628
+ self.group_query_attention = config.group_query_attention
629
+ self.num_query_groups = config.num_query_groups
630
+ self.position_embedding_type = config.position_embedding_type
631
+ self.embed_dim = config.hidden_size
632
+
633
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
634
+ if self.position_embedding_type == "learned_absolute":
635
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
636
+ else:
637
+ pass
638
+
639
+ self.drop = nn.Dropout(config.embd_pdrop)
640
+ self.h = nn.ModuleList([CodeShellBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
641
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
642
+
643
+ max_positions = config.max_position_embeddings
644
+ self.register_buffer(
645
+ "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
646
+ )
647
+
648
+ self.gradient_checkpointing = False
649
+
650
+ # Initialize weights and apply final processing
651
+ self.post_init()
652
+
653
+ def get_input_embeddings(self):
654
+ return self.wte
655
+
656
+ def set_input_embeddings(self, new_embeddings):
657
+ self.wte = new_embeddings
658
+
659
+ @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
660
+ def forward(
661
+ self,
662
+ input_ids: Optional[torch.Tensor] = None,
663
+ past_key_values: Optional[List[torch.Tensor]] = None,
664
+ attention_mask: Optional[torch.Tensor] = None,
665
+ token_type_ids: Optional[torch.Tensor] = None,
666
+ position_ids: Optional[torch.Tensor] = None,
667
+ head_mask: Optional[torch.Tensor] = None,
668
+ inputs_embeds: Optional[torch.Tensor] = None,
669
+ encoder_hidden_states: Optional[torch.Tensor] = None,
670
+ encoder_attention_mask: Optional[torch.Tensor] = None,
671
+ use_cache: Optional[bool] = None,
672
+ output_attentions: Optional[bool] = None,
673
+ output_hidden_states: Optional[bool] = None,
674
+ return_dict: Optional[bool] = None,
675
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
676
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
677
+ output_hidden_states = (
678
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
679
+ )
680
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
681
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
682
+
683
+ if input_ids is not None and inputs_embeds is not None:
684
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
685
+ elif input_ids is not None:
686
+ input_shape = input_ids.size()
687
+ input_ids = input_ids.reshape(-1, input_shape[-1])
688
+ batch_size = input_ids.shape[0]
689
+ elif inputs_embeds is not None:
690
+ input_shape = inputs_embeds.size()[:-1]
691
+ batch_size = inputs_embeds.shape[0]
692
+ else:
693
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
694
+
695
+ if batch_size <= 0:
696
+ raise ValueError("batch_size has to be defined and > 0")
697
+
698
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
699
+
700
+ if token_type_ids is not None:
701
+ token_type_ids = token_type_ids.reshape(-1, input_shape[-1])
702
+ if position_ids is not None:
703
+ position_ids = position_ids.reshape(-1, input_shape[-1])
704
+
705
+ if past_key_values is None:
706
+ past_length = 0
707
+ past_key_values = tuple([None] * len(self.h))
708
+ else:
709
+ past_length = past_key_values[0][0].size(-3)
710
+
711
+ if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
712
+ # create position_ids on the fly for batch generation
713
+ position_ids = attention_mask.long().cumsum(-1) - 1
714
+ position_ids.masked_fill_(attention_mask == 0, 1)
715
+ if past_length > 0:
716
+ position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
717
+ elif position_ids is None:
718
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
719
+ position_ids = position_ids.unsqueeze(0).reshape(-1, input_shape[-1])
720
+
721
+ # Self-attention mask.
722
+ query_length = input_shape[-1]
723
+ key_length = past_length + query_length
724
+ self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
725
+
726
+ if attention_mask is not None:
727
+ self_attention_mask = self_attention_mask * attention_mask.reshape(batch_size, 1, -1).to(
728
+ dtype=torch.bool, device=self_attention_mask.device
729
+ )
730
+
731
+ # MQA models: (batch_size, query_length, n_heads, key_length)
732
+ # MHA models: (batch_size, n_heads, query_length, key_length)
733
+ attention_mask = self_attention_mask.unsqueeze(1)
734
+
735
+ encoder_attention_mask = None
736
+
737
+ # Prepare head mask if needed
738
+ # 1.0 in head_mask indicate we keep the head
739
+ # attention_probs has shape bsz x n_heads x N x N
740
+ # head_mask has shape n_layer x batch x n_heads x N x N
741
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
742
+
743
+ if inputs_embeds is None:
744
+ inputs_embeds = self.wte(input_ids)
745
+
746
+ hidden_states = inputs_embeds
747
+ if self.position_embedding_type == "learned_absolute":
748
+ position_embeds = self.wpe(position_ids)
749
+ hidden_states = hidden_states + position_embeds
750
+
751
+ if token_type_ids is not None:
752
+ token_type_embeds = self.wte(token_type_ids)
753
+ hidden_states = hidden_states + token_type_embeds
754
+
755
+ hidden_states = self.drop(hidden_states)
756
+
757
+ output_shape = input_shape + (hidden_states.size(-1),)
758
+
759
+ presents = [] if use_cache else None
760
+ all_self_attentions = () if output_attentions else None
761
+ all_hidden_states = () if output_hidden_states else None
762
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
763
+ if output_hidden_states:
764
+ all_hidden_states = all_hidden_states + (hidden_states,)
765
+
766
+ if self.gradient_checkpointing and self.training:
767
+
768
+ def create_custom_forward(module):
769
+ def custom_forward(*inputs):
770
+ # None for past_key_value
771
+ return module(*inputs, use_cache, output_attentions)
772
+
773
+ return custom_forward
774
+
775
+ outputs = torch.utils.checkpoint.checkpoint(
776
+ create_custom_forward(block),
777
+ hidden_states,
778
+ None,
779
+ attention_mask,
780
+ position_ids,
781
+ head_mask[i],
782
+ encoder_hidden_states,
783
+ encoder_attention_mask,
784
+ )
785
+ else:
786
+ outputs = block(
787
+ hidden_states,
788
+ layer_past=layer_past,
789
+ attention_mask=attention_mask,
790
+ position_ids=position_ids,
791
+ head_mask=head_mask[i],
792
+ encoder_hidden_states=encoder_hidden_states,
793
+ encoder_attention_mask=encoder_attention_mask,
794
+ use_cache=use_cache,
795
+ output_attentions=output_attentions,
796
+ )
797
+
798
+ hidden_states = outputs[0]
799
+ if use_cache:
800
+ presents.append(outputs[1])
801
+
802
+ if output_attentions:
803
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
804
+
805
+ hidden_states = self.ln_f(hidden_states)
806
+ hidden_states = hidden_states.reshape(output_shape)
807
+ # Add last hidden state
808
+ if output_hidden_states:
809
+ all_hidden_states = all_hidden_states + (hidden_states,)
810
+
811
+
812
+ if not return_dict:
813
+ return tuple(
814
+ v
815
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
816
+ if v is not None
817
+ )
818
+
819
+ return BaseModelOutputWithPastAndCrossAttentions(
820
+ last_hidden_state=hidden_states,
821
+ past_key_values=presents,
822
+ hidden_states=all_hidden_states,
823
+ attentions=all_self_attentions,
824
+ )
825
+
826
+
827
+ @add_start_docstrings(
828
+ """
829
+ The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
830
+ embeddings).
831
+ """,
832
+ GPT_BIGCODE_START_DOCSTRING,
833
+ )
834
+ class CodeShellForCausalLM(CodeShellPreTrainedModel):
835
+ _tied_weights_keys = ["lm_head.weight"]
836
+
837
+ def __init__(self, config):
838
+ super().__init__(config)
839
+ self.transformer = CodeShellModel(config)
840
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
841
+
842
+ # Initialize weights and apply final processing
843
+ self.post_init()
844
+
845
+ def quantize(self, bits: int):
846
+ try:
847
+ import bitsandbytes
848
+ from .quantizer import quantize_online
849
+ except ImportError:
850
+ raise ImportError(f"Needs bitsandbytes to run quantize.")
851
+ return quantize_online(self, bits)
852
+
853
+ def get_output_embeddings(self):
854
+ return self.lm_head
855
+
856
+ def set_output_embeddings(self, new_embeddings):
857
+ self.lm_head = new_embeddings
858
+
859
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
860
+ token_type_ids = kwargs.get("token_type_ids", None)
861
+ # only last token for inputs_ids if past is defined in kwargs
862
+ if past_key_values:
863
+ input_ids = input_ids[:, -1].unsqueeze(-1)
864
+ if token_type_ids is not None:
865
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
866
+
867
+ attention_mask = kwargs.get("attention_mask", None)
868
+ position_ids = kwargs.get("position_ids", None)
869
+
870
+ if attention_mask is not None and position_ids is None:
871
+ # create position_ids on the fly for batch generation
872
+ position_ids = attention_mask.long().cumsum(-1) - 1
873
+ position_ids.masked_fill_(attention_mask == 0, 1)
874
+ if past_key_values:
875
+ position_ids = position_ids[:, -1].unsqueeze(-1)
876
+ else:
877
+ position_ids = None
878
+
879
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
880
+ if inputs_embeds is not None and past_key_values is None:
881
+ model_inputs = {"inputs_embeds": inputs_embeds}
882
+ else:
883
+ model_inputs = {"input_ids": input_ids}
884
+
885
+ model_inputs.update(
886
+ {
887
+ "past_key_values": past_key_values,
888
+ "use_cache": kwargs.get("use_cache"),
889
+ "position_ids": position_ids,
890
+ "attention_mask": attention_mask,
891
+ "token_type_ids": token_type_ids,
892
+ }
893
+ )
894
+ return model_inputs
895
+
896
+ @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
897
+ def forward(
898
+ self,
899
+ input_ids: Optional[torch.Tensor] = None,
900
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
901
+ attention_mask: Optional[torch.Tensor] = None,
902
+ token_type_ids: Optional[torch.Tensor] = None,
903
+ position_ids: Optional[torch.Tensor] = None,
904
+ head_mask: Optional[torch.Tensor] = None,
905
+ inputs_embeds: Optional[torch.Tensor] = None,
906
+ encoder_hidden_states: Optional[torch.Tensor] = None,
907
+ encoder_attention_mask: Optional[torch.Tensor] = None,
908
+ labels: Optional[torch.Tensor] = None,
909
+ use_cache: Optional[bool] = None,
910
+ output_attentions: Optional[bool] = None,
911
+ output_hidden_states: Optional[bool] = None,
912
+ return_dict: Optional[bool] = None,
913
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
914
+ r"""
915
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
916
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
917
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
918
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
919
+ """
920
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
921
+
922
+ transformer_outputs = self.transformer(
923
+ input_ids,
924
+ past_key_values=past_key_values,
925
+ attention_mask=attention_mask,
926
+ token_type_ids=token_type_ids,
927
+ position_ids=position_ids,
928
+ head_mask=head_mask,
929
+ inputs_embeds=inputs_embeds,
930
+ encoder_hidden_states=encoder_hidden_states,
931
+ encoder_attention_mask=encoder_attention_mask,
932
+ use_cache=use_cache,
933
+ output_attentions=output_attentions,
934
+ output_hidden_states=output_hidden_states,
935
+ return_dict=return_dict,
936
+ )
937
+ hidden_states = transformer_outputs[0]
938
+ lm_logits = self.lm_head(hidden_states)
939
+ loss = None
940
+ if labels is not None:
941
+ # Shift so that tokens < n predict n
942
+ shift_logits = lm_logits[..., :-1, :].contiguous()
943
+ shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
944
+ # Flatten the tokens
945
+ loss_fct = CrossEntropyLoss()
946
+ loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
947
+
948
+ if not return_dict:
949
+ output = (lm_logits,) + transformer_outputs[1:]
950
+ return ((loss,) + output) if loss is not None else output
951
+
952
+ return CausalLMOutputWithCrossAttentions(
953
+ loss=loss,
954
+ logits=lm_logits,
955
+ past_key_values=transformer_outputs.past_key_values,
956
+ hidden_states=transformer_outputs.hidden_states,
957
+ attentions=transformer_outputs.attentions,
958
+ )
959
+
960
+ @staticmethod
961
+ def _reorder_cache(past_key_values, beam_idx):
962
+ reordered_past = ()
963
+ for layer_past in past_key_values:
964
+ reordered_past += (
965
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
966
+ )
967
+ return reordered_past
quantizer.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ import bitsandbytes as bnb
3
+ from bitsandbytes.nn.modules import Params4bit, Int8Params
4
+ except ImportError:
5
+ pass
6
+ import torch
7
+
8
+ def Params4bitCuda(self, device):
9
+ self.data = self.data.cuda(device)
10
+ self.quant_state[0] = self.quant_state[0].cuda(device)
11
+ self.quant_state[4][0] = self.quant_state[4][0].cuda(device)
12
+ self.quant_state[4][1][0] = self.quant_state[4][1][0].cuda(device)
13
+ self.quant_state[4][1][1] = self.quant_state[4][1][1].cuda(device)
14
+
15
+ self.quant_state[6] = self.quant_state[6].cuda(device)
16
+ return self
17
+
18
+ class Linear4bitOnline(torch.nn.Module):
19
+ def __init__(self, weight, bias, quant_type):
20
+ super().__init__()
21
+ self.weight = Params4bit(
22
+ weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
23
+ )
24
+ self.compute_dtype = None
25
+ #self.weight.cuda(weight.device)
26
+ self.bias = bias
27
+
28
+ def forward(self, x: torch.Tensor):
29
+ # weights are cast automatically as Int8Params, but the bias has to be cast manually
30
+ if self.bias is not None and self.bias.dtype != x.dtype:
31
+ self.bias.data = self.bias.data.to(x.dtype)
32
+
33
+ if getattr(self.weight, "quant_state", None) is None:
34
+ print(
35
+ "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
36
+ )
37
+ inp_dtype = x.dtype
38
+ if self.compute_dtype is not None:
39
+ x = x.to(self.compute_dtype)
40
+
41
+ bias = None if self.bias is None else self.bias.to(self.compute_dtype)
42
+ out = bnb.matmul_4bit(
43
+ x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
44
+ )
45
+
46
+ out = out.to(inp_dtype)
47
+
48
+ return out
49
+
50
+ class Linear8bitLtOnline(torch.nn.Module):
51
+ def __init__(
52
+ self,
53
+ weight,
54
+ bias,
55
+ has_fp16_weights=True,
56
+ memory_efficient_backward=False,
57
+ threshold=0.0,
58
+ index=None,
59
+ ):
60
+ super().__init__()
61
+ assert (
62
+ not memory_efficient_backward
63
+ ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
64
+ self.state = bnb.MatmulLtState()
65
+ self.index = index
66
+
67
+ # Necessary for stacked layers
68
+ self.state.threshold = threshold
69
+ self.state.has_fp16_weights = has_fp16_weights
70
+ self.state.memory_efficient_backward = memory_efficient_backward
71
+ if threshold > 0.0 and not has_fp16_weights:
72
+ self.state.use_pool = True
73
+
74
+ self.weight = Int8Params(
75
+ weight.data,
76
+ has_fp16_weights=has_fp16_weights,
77
+ requires_grad=has_fp16_weights,
78
+ )
79
+ self.bias = bias
80
+
81
+ def init_8bit_state(self):
82
+ self.state.CB = self.weight.CB
83
+ self.state.SCB = self.weight.SCB
84
+ self.weight.CB = None
85
+ self.weight.SCB = None
86
+
87
+ def forward(self, x: torch.Tensor):
88
+ self.state.is_training = self.training
89
+ if self.weight.CB is not None:
90
+ self.init_8bit_state()
91
+
92
+ # weights are cast automatically as Int8Params, but the bias has to be cast manually
93
+ if self.bias is not None and self.bias.dtype != x.dtype:
94
+ self.bias.data = self.bias.data.to(x.dtype)
95
+
96
+ out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
97
+
98
+ if not self.state.has_fp16_weights:
99
+ if self.state.CB is not None and self.state.CxB is not None:
100
+ # we converted 8-bit row major to turing/ampere format in the first inference pass
101
+ # we no longer need the row-major weight
102
+ del self.state.CB
103
+ self.weight.data = self.state.CxB
104
+ return out
105
+
106
+ def quantize_online(model, bits: int):
107
+ def quant(weight, bias=None):
108
+ if bits == 8:
109
+ linear = Linear8bitLtOnline(
110
+ weight,
111
+ bias,
112
+ has_fp16_weights=False,
113
+ threshold=6.0,
114
+ )
115
+ if bias is not None:
116
+ linear.bias = torch.nn.Parameter(bias)
117
+ elif bits == 4:
118
+ linear = Linear4bitOnline(
119
+ weight,
120
+ bias,
121
+ quant_type="nf4", #fp4/nf4
122
+ )
123
+ else:
124
+ raise ValueError("quantize only support 4/8 bit")
125
+ return linear
126
+
127
+ def auto_quant(layer):
128
+ if hasattr(layer,"bias"):
129
+ linear = quant(layer.weight,bias=layer.bias)
130
+ else:
131
+ linear = quant(layer.weight)
132
+ return linear
133
+
134
+ for i,layer in enumerate(model.transformer.h):
135
+ layer.mlp.c_fc = auto_quant(layer.mlp.c_fc)
136
+ layer.mlp.c_proj = auto_quant(layer.mlp.c_proj)
137
+
138
+ layer.attn.c_attn=auto_quant(layer.attn.c_attn)
139
+ layer.attn.c_proj=auto_quant(layer.attn.c_proj)
140
+
141
+ return model