vanilla1116 commited on
Commit
95af11e
1 Parent(s): 561feb6

releasing model anah-20b

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
config.json CHANGED
@@ -2,8 +2,9 @@
2
  "architectures": [
3
  "InternLM2ForCausalLM"
4
  ],
 
5
  "auto_map": {
6
- "AutoConfig": "configuration_internlm.InternLMConfig",
7
  "AutoModelForCausalLM": "modeling_internlm2.InternLM2ForCausalLM",
8
  "AutoModel": "modeling_internlm2.InternLM2ForCausalLM"
9
  },
@@ -15,20 +16,20 @@
15
  "initializer_range": 0.02,
16
  "intermediate_size": 16384,
17
  "max_position_embeddings": 32768,
18
- "model_type": "internlm",
19
  "num_attention_heads": 48,
20
  "num_hidden_layers": 48,
21
  "num_key_value_heads": 8,
22
  "pad_token_id": 2,
23
  "rms_norm_eps": 1e-05,
24
  "rope_scaling": {
25
- "factor": 1.0,
26
  "type": "dynamic"
27
  },
28
  "rope_theta": 1000000,
29
  "tie_word_embeddings": false,
30
  "torch_dtype": "bfloat16",
31
- "transformers_version": "4.25.1",
32
  "use_cache": true,
33
  "vocab_size": 92544
34
  }
 
2
  "architectures": [
3
  "InternLM2ForCausalLM"
4
  ],
5
+ "attn_implementation": "eager",
6
  "auto_map": {
7
+ "AutoConfig": "configuration_internlm2.InternLM2Config",
8
  "AutoModelForCausalLM": "modeling_internlm2.InternLM2ForCausalLM",
9
  "AutoModel": "modeling_internlm2.InternLM2ForCausalLM"
10
  },
 
16
  "initializer_range": 0.02,
17
  "intermediate_size": 16384,
18
  "max_position_embeddings": 32768,
19
+ "model_type": "internlm2",
20
  "num_attention_heads": 48,
21
  "num_hidden_layers": 48,
22
  "num_key_value_heads": 8,
23
  "pad_token_id": 2,
24
  "rms_norm_eps": 1e-05,
25
  "rope_scaling": {
26
+ "factor": 2.0,
27
  "type": "dynamic"
28
  },
29
  "rope_theta": 1000000,
30
  "tie_word_embeddings": false,
31
  "torch_dtype": "bfloat16",
32
+ "transformers_version": "4.39.3",
33
  "use_cache": true,
34
  "vocab_size": 92544
35
  }
configuration_internlm.py → configuration_internlm2.py RENAMED
@@ -1,10 +1,7 @@
1
  # coding=utf-8
2
- # Copyright (c) InternLM. All rights reserved.
3
  #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
  #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
10
  # you may not use this file except in compliance with the License.
@@ -17,21 +14,22 @@
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
- """ InternLM model configuration"""
21
 
22
  from transformers.configuration_utils import PretrainedConfig
23
  from transformers.utils import logging
24
 
25
  logger = logging.get_logger(__name__)
26
 
27
- INTERNLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
28
 
29
 
30
- class InternLMConfig(PretrainedConfig):
 
31
  r"""
32
- This is the configuration class to store the configuration of a [`InternLMModel`]. It is used to instantiate
33
- an InternLM model according to the specified arguments, defining the model architecture. Instantiating a
34
- configuration with the defaults will yield a similar configuration to that of the InternLM-7B.
35
 
36
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
  documentation from [`PretrainedConfig`] for more information.
@@ -39,8 +37,8 @@ class InternLMConfig(PretrainedConfig):
39
 
40
  Args:
41
  vocab_size (`int`, *optional*, defaults to 32000):
42
- Vocabulary size of the InternLM model. Defines the number of different tokens that can be represented by the
43
- `inputs_ids` passed when calling [`InternLMModel`]
44
  hidden_size (`int`, *optional*, defaults to 4096):
45
  Dimension of the hidden representations.
46
  intermediate_size (`int`, *optional*, defaults to 11008):
@@ -73,19 +71,8 @@ class InternLMConfig(PretrainedConfig):
73
  Whether to tie weight embeddings
74
  Example:
75
 
76
- ```python
77
- >>> from transformers import InternLMModel, InternLMConfig
78
-
79
- >>> # Initializing a InternLM internlm-7b style configuration
80
- >>> configuration = InternLMConfig()
81
-
82
- >>> # Initializing a model from the internlm-7b style configuration
83
- >>> model = InternLMModel(configuration)
84
-
85
- >>> # Accessing the model configuration
86
- >>> configuration = model.config
87
- ```"""
88
- model_type = "internlm"
89
  _auto_class = "AutoConfig"
90
 
91
  def __init__( # pylint: disable=W0102
@@ -108,6 +95,7 @@ class InternLMConfig(PretrainedConfig):
108
  bias=True,
109
  rope_theta=10000,
110
  rope_scaling=None,
 
111
  **kwargs,
112
  ):
113
  self.vocab_size = vocab_size
@@ -129,6 +117,10 @@ class InternLMConfig(PretrainedConfig):
129
  self.rope_theta = rope_theta
130
  self.rope_scaling = rope_scaling
131
  self._rope_scaling_validation()
 
 
 
 
132
  super().__init__(
133
  pad_token_id=pad_token_id,
134
  bos_token_id=bos_token_id,
 
1
  # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
  #
4
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
 
 
 
5
  #
6
  # Licensed under the Apache License, Version 2.0 (the "License");
7
  # you may not use this file except in compliance with the License.
 
14
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
  # See the License for the specific language governing permissions and
16
  # limitations under the License.
17
+ """ InternLM2 model configuration"""
18
 
19
  from transformers.configuration_utils import PretrainedConfig
20
  from transformers.utils import logging
21
 
22
  logger = logging.get_logger(__name__)
23
 
24
+ INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
25
 
26
 
27
+ # Modified from transformers.model.llama.configuration_llama.LlamaConfig
28
+ class InternLM2Config(PretrainedConfig):
29
  r"""
30
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
31
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
32
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
33
 
34
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
  documentation from [`PretrainedConfig`] for more information.
 
37
 
38
  Args:
39
  vocab_size (`int`, *optional*, defaults to 32000):
40
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`InternLM2Model`]
42
  hidden_size (`int`, *optional*, defaults to 4096):
43
  Dimension of the hidden representations.
44
  intermediate_size (`int`, *optional*, defaults to 11008):
 
71
  Whether to tie weight embeddings
72
  Example:
73
 
74
+ """
75
+ model_type = "internlm2"
 
 
 
 
 
 
 
 
 
 
 
76
  _auto_class = "AutoConfig"
77
 
78
  def __init__( # pylint: disable=W0102
 
95
  bias=True,
96
  rope_theta=10000,
97
  rope_scaling=None,
98
+ attn_implementation="eager",
99
  **kwargs,
100
  ):
101
  self.vocab_size = vocab_size
 
117
  self.rope_theta = rope_theta
118
  self.rope_scaling = rope_scaling
119
  self._rope_scaling_validation()
120
+
121
+ self.attn_implementation = attn_implementation
122
+ if self.attn_implementation is None:
123
+ self.attn_implementation = "eager"
124
  super().__init__(
125
  pad_token_id=pad_token_id,
126
  bos_token_id=bos_token_id,
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 2,
6
+ "transformers_version": "4.39.3"
7
+ }
pytorch_model-00004-of-00004.bin → model-00001-of-00004.safetensors RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:43f18616e240c0ef9204a7a49c0f0e70efb3b86f636803ec0bf44b53d3a759ed
3
- size 9920388243
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f75e4a4b81917979e5061bf65a33af8be09212206cae30b5c874182462df45d
3
+ size 9895166840
pytorch_model-00001-of-00004.bin → model-00002-of-00004.safetensors RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:26a91e2e826ef5c62c7db6e91c2caf67edc9f6da165e97f4322452071bae5feb
3
- size 9895185011
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ae8c7e4f65e28ed6ed6930c4b85933fc2e156e2963b6fd974ca549a91fb35e9
3
+ size 9965995968
pytorch_model-00002-of-00004.bin → model-00003-of-00004.safetensors RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d49702d830821de917cd321f72aae5a224602b630a740634b8dca2f37f5467cb
3
- size 9966016061
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:debacdd08681006ba1b191d27043621f5f5940b417891628e840dc3b1ed53d73
3
+ size 9940805472
pytorch_model-00003-of-00004.bin → model-00004-of-00004.safetensors RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6afc3a495274a83f3a56a38bdbb256e1a1d769996e2d4e229a2e675d50f44a5f
3
- size 9940825423
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c432da1cd14a158dfba9fe1977e978cd596537199e33896e02bbbafed6b1f23
3
+ size 9920369776
model.safetensors.index.json ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 39722299392
4
+ },
5
+ "weight_map": {
6
+ "model.layers.0.attention.wo.weight": "model-00001-of-00004.safetensors",
7
+ "model.layers.0.attention.wqkv.weight": "model-00001-of-00004.safetensors",
8
+ "model.layers.0.attention_norm.weight": "model-00001-of-00004.safetensors",
9
+ "model.layers.0.feed_forward.w1.weight": "model-00001-of-00004.safetensors",
10
+ "model.layers.0.feed_forward.w2.weight": "model-00001-of-00004.safetensors",
11
+ "model.layers.0.feed_forward.w3.weight": "model-00001-of-00004.safetensors",
12
+ "model.layers.0.ffn_norm.weight": "model-00001-of-00004.safetensors",
13
+ "model.layers.1.attention.wo.weight": "model-00001-of-00004.safetensors",
14
+ "model.layers.1.attention.wqkv.weight": "model-00001-of-00004.safetensors",
15
+ "model.layers.1.attention_norm.weight": "model-00001-of-00004.safetensors",
16
+ "model.layers.1.feed_forward.w1.weight": "model-00001-of-00004.safetensors",
17
+ "model.layers.1.feed_forward.w2.weight": "model-00001-of-00004.safetensors",
18
+ "model.layers.1.feed_forward.w3.weight": "model-00001-of-00004.safetensors",
19
+ "model.layers.1.ffn_norm.weight": "model-00001-of-00004.safetensors",
20
+ "model.layers.10.attention.wo.weight": "model-00001-of-00004.safetensors",
21
+ "model.layers.10.attention.wqkv.weight": "model-00001-of-00004.safetensors",
22
+ "model.layers.10.attention_norm.weight": "model-00001-of-00004.safetensors",
23
+ "model.layers.10.feed_forward.w1.weight": "model-00001-of-00004.safetensors",
24
+ "model.layers.10.feed_forward.w2.weight": "model-00001-of-00004.safetensors",
25
+ "model.layers.10.feed_forward.w3.weight": "model-00001-of-00004.safetensors",
26
+ "model.layers.10.ffn_norm.weight": "model-00001-of-00004.safetensors",
27
+ "model.layers.11.attention.wo.weight": "model-00001-of-00004.safetensors",
28
+ "model.layers.11.attention.wqkv.weight": "model-00001-of-00004.safetensors",
29
+ "model.layers.11.attention_norm.weight": "model-00002-of-00004.safetensors",
30
+ "model.layers.11.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
31
+ "model.layers.11.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
32
+ "model.layers.11.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
33
+ "model.layers.11.ffn_norm.weight": "model-00002-of-00004.safetensors",
34
+ "model.layers.12.attention.wo.weight": "model-00002-of-00004.safetensors",
35
+ "model.layers.12.attention.wqkv.weight": "model-00002-of-00004.safetensors",
36
+ "model.layers.12.attention_norm.weight": "model-00002-of-00004.safetensors",
37
+ "model.layers.12.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
38
+ "model.layers.12.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
39
+ "model.layers.12.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
40
+ "model.layers.12.ffn_norm.weight": "model-00002-of-00004.safetensors",
41
+ "model.layers.13.attention.wo.weight": "model-00002-of-00004.safetensors",
42
+ "model.layers.13.attention.wqkv.weight": "model-00002-of-00004.safetensors",
43
+ "model.layers.13.attention_norm.weight": "model-00002-of-00004.safetensors",
44
+ "model.layers.13.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
45
+ "model.layers.13.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
46
+ "model.layers.13.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
47
+ "model.layers.13.ffn_norm.weight": "model-00002-of-00004.safetensors",
48
+ "model.layers.14.attention.wo.weight": "model-00002-of-00004.safetensors",
49
+ "model.layers.14.attention.wqkv.weight": "model-00002-of-00004.safetensors",
50
+ "model.layers.14.attention_norm.weight": "model-00002-of-00004.safetensors",
51
+ "model.layers.14.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
52
+ "model.layers.14.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
53
+ "model.layers.14.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
54
+ "model.layers.14.ffn_norm.weight": "model-00002-of-00004.safetensors",
55
+ "model.layers.15.attention.wo.weight": "model-00002-of-00004.safetensors",
56
+ "model.layers.15.attention.wqkv.weight": "model-00002-of-00004.safetensors",
57
+ "model.layers.15.attention_norm.weight": "model-00002-of-00004.safetensors",
58
+ "model.layers.15.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
59
+ "model.layers.15.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
60
+ "model.layers.15.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
61
+ "model.layers.15.ffn_norm.weight": "model-00002-of-00004.safetensors",
62
+ "model.layers.16.attention.wo.weight": "model-00002-of-00004.safetensors",
63
+ "model.layers.16.attention.wqkv.weight": "model-00002-of-00004.safetensors",
64
+ "model.layers.16.attention_norm.weight": "model-00002-of-00004.safetensors",
65
+ "model.layers.16.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
66
+ "model.layers.16.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
67
+ "model.layers.16.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
68
+ "model.layers.16.ffn_norm.weight": "model-00002-of-00004.safetensors",
69
+ "model.layers.17.attention.wo.weight": "model-00002-of-00004.safetensors",
70
+ "model.layers.17.attention.wqkv.weight": "model-00002-of-00004.safetensors",
71
+ "model.layers.17.attention_norm.weight": "model-00002-of-00004.safetensors",
72
+ "model.layers.17.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
73
+ "model.layers.17.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
74
+ "model.layers.17.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
75
+ "model.layers.17.ffn_norm.weight": "model-00002-of-00004.safetensors",
76
+ "model.layers.18.attention.wo.weight": "model-00002-of-00004.safetensors",
77
+ "model.layers.18.attention.wqkv.weight": "model-00002-of-00004.safetensors",
78
+ "model.layers.18.attention_norm.weight": "model-00002-of-00004.safetensors",
79
+ "model.layers.18.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
80
+ "model.layers.18.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
81
+ "model.layers.18.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
82
+ "model.layers.18.ffn_norm.weight": "model-00002-of-00004.safetensors",
83
+ "model.layers.19.attention.wo.weight": "model-00002-of-00004.safetensors",
84
+ "model.layers.19.attention.wqkv.weight": "model-00002-of-00004.safetensors",
85
+ "model.layers.19.attention_norm.weight": "model-00002-of-00004.safetensors",
86
+ "model.layers.19.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
87
+ "model.layers.19.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
88
+ "model.layers.19.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
89
+ "model.layers.19.ffn_norm.weight": "model-00002-of-00004.safetensors",
90
+ "model.layers.2.attention.wo.weight": "model-00001-of-00004.safetensors",
91
+ "model.layers.2.attention.wqkv.weight": "model-00001-of-00004.safetensors",
92
+ "model.layers.2.attention_norm.weight": "model-00001-of-00004.safetensors",
93
+ "model.layers.2.feed_forward.w1.weight": "model-00001-of-00004.safetensors",
94
+ "model.layers.2.feed_forward.w2.weight": "model-00001-of-00004.safetensors",
95
+ "model.layers.2.feed_forward.w3.weight": "model-00001-of-00004.safetensors",
96
+ "model.layers.2.ffn_norm.weight": "model-00001-of-00004.safetensors",
97
+ "model.layers.20.attention.wo.weight": "model-00002-of-00004.safetensors",
98
+ "model.layers.20.attention.wqkv.weight": "model-00002-of-00004.safetensors",
99
+ "model.layers.20.attention_norm.weight": "model-00002-of-00004.safetensors",
100
+ "model.layers.20.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
101
+ "model.layers.20.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
102
+ "model.layers.20.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
103
+ "model.layers.20.ffn_norm.weight": "model-00002-of-00004.safetensors",
104
+ "model.layers.21.attention.wo.weight": "model-00002-of-00004.safetensors",
105
+ "model.layers.21.attention.wqkv.weight": "model-00002-of-00004.safetensors",
106
+ "model.layers.21.attention_norm.weight": "model-00002-of-00004.safetensors",
107
+ "model.layers.21.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
108
+ "model.layers.21.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
109
+ "model.layers.21.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
110
+ "model.layers.21.ffn_norm.weight": "model-00002-of-00004.safetensors",
111
+ "model.layers.22.attention.wo.weight": "model-00002-of-00004.safetensors",
112
+ "model.layers.22.attention.wqkv.weight": "model-00002-of-00004.safetensors",
113
+ "model.layers.22.attention_norm.weight": "model-00002-of-00004.safetensors",
114
+ "model.layers.22.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
115
+ "model.layers.22.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
116
+ "model.layers.22.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
117
+ "model.layers.22.ffn_norm.weight": "model-00002-of-00004.safetensors",
118
+ "model.layers.23.attention.wo.weight": "model-00002-of-00004.safetensors",
119
+ "model.layers.23.attention.wqkv.weight": "model-00002-of-00004.safetensors",
120
+ "model.layers.23.attention_norm.weight": "model-00002-of-00004.safetensors",
121
+ "model.layers.23.feed_forward.w1.weight": "model-00002-of-00004.safetensors",
122
+ "model.layers.23.feed_forward.w2.weight": "model-00002-of-00004.safetensors",
123
+ "model.layers.23.feed_forward.w3.weight": "model-00002-of-00004.safetensors",
124
+ "model.layers.23.ffn_norm.weight": "model-00002-of-00004.safetensors",
125
+ "model.layers.24.attention.wo.weight": "model-00003-of-00004.safetensors",
126
+ "model.layers.24.attention.wqkv.weight": "model-00003-of-00004.safetensors",
127
+ "model.layers.24.attention_norm.weight": "model-00003-of-00004.safetensors",
128
+ "model.layers.24.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
129
+ "model.layers.24.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
130
+ "model.layers.24.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
131
+ "model.layers.24.ffn_norm.weight": "model-00003-of-00004.safetensors",
132
+ "model.layers.25.attention.wo.weight": "model-00003-of-00004.safetensors",
133
+ "model.layers.25.attention.wqkv.weight": "model-00003-of-00004.safetensors",
134
+ "model.layers.25.attention_norm.weight": "model-00003-of-00004.safetensors",
135
+ "model.layers.25.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
136
+ "model.layers.25.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
137
+ "model.layers.25.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
138
+ "model.layers.25.ffn_norm.weight": "model-00003-of-00004.safetensors",
139
+ "model.layers.26.attention.wo.weight": "model-00003-of-00004.safetensors",
140
+ "model.layers.26.attention.wqkv.weight": "model-00003-of-00004.safetensors",
141
+ "model.layers.26.attention_norm.weight": "model-00003-of-00004.safetensors",
142
+ "model.layers.26.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
143
+ "model.layers.26.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
144
+ "model.layers.26.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
145
+ "model.layers.26.ffn_norm.weight": "model-00003-of-00004.safetensors",
146
+ "model.layers.27.attention.wo.weight": "model-00003-of-00004.safetensors",
147
+ "model.layers.27.attention.wqkv.weight": "model-00003-of-00004.safetensors",
148
+ "model.layers.27.attention_norm.weight": "model-00003-of-00004.safetensors",
149
+ "model.layers.27.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
150
+ "model.layers.27.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
151
+ "model.layers.27.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
152
+ "model.layers.27.ffn_norm.weight": "model-00003-of-00004.safetensors",
153
+ "model.layers.28.attention.wo.weight": "model-00003-of-00004.safetensors",
154
+ "model.layers.28.attention.wqkv.weight": "model-00003-of-00004.safetensors",
155
+ "model.layers.28.attention_norm.weight": "model-00003-of-00004.safetensors",
156
+ "model.layers.28.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
157
+ "model.layers.28.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
158
+ "model.layers.28.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
159
+ "model.layers.28.ffn_norm.weight": "model-00003-of-00004.safetensors",
160
+ "model.layers.29.attention.wo.weight": "model-00003-of-00004.safetensors",
161
+ "model.layers.29.attention.wqkv.weight": "model-00003-of-00004.safetensors",
162
+ "model.layers.29.attention_norm.weight": "model-00003-of-00004.safetensors",
163
+ "model.layers.29.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
164
+ "model.layers.29.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
165
+ "model.layers.29.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
166
+ "model.layers.29.ffn_norm.weight": "model-00003-of-00004.safetensors",
167
+ "model.layers.3.attention.wo.weight": "model-00001-of-00004.safetensors",
168
+ "model.layers.3.attention.wqkv.weight": "model-00001-of-00004.safetensors",
169
+ "model.layers.3.attention_norm.weight": "model-00001-of-00004.safetensors",
170
+ "model.layers.3.feed_forward.w1.weight": "model-00001-of-00004.safetensors",
171
+ "model.layers.3.feed_forward.w2.weight": "model-00001-of-00004.safetensors",
172
+ "model.layers.3.feed_forward.w3.weight": "model-00001-of-00004.safetensors",
173
+ "model.layers.3.ffn_norm.weight": "model-00001-of-00004.safetensors",
174
+ "model.layers.30.attention.wo.weight": "model-00003-of-00004.safetensors",
175
+ "model.layers.30.attention.wqkv.weight": "model-00003-of-00004.safetensors",
176
+ "model.layers.30.attention_norm.weight": "model-00003-of-00004.safetensors",
177
+ "model.layers.30.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
178
+ "model.layers.30.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
179
+ "model.layers.30.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
180
+ "model.layers.30.ffn_norm.weight": "model-00003-of-00004.safetensors",
181
+ "model.layers.31.attention.wo.weight": "model-00003-of-00004.safetensors",
182
+ "model.layers.31.attention.wqkv.weight": "model-00003-of-00004.safetensors",
183
+ "model.layers.31.attention_norm.weight": "model-00003-of-00004.safetensors",
184
+ "model.layers.31.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
185
+ "model.layers.31.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
186
+ "model.layers.31.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
187
+ "model.layers.31.ffn_norm.weight": "model-00003-of-00004.safetensors",
188
+ "model.layers.32.attention.wo.weight": "model-00003-of-00004.safetensors",
189
+ "model.layers.32.attention.wqkv.weight": "model-00003-of-00004.safetensors",
190
+ "model.layers.32.attention_norm.weight": "model-00003-of-00004.safetensors",
191
+ "model.layers.32.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
192
+ "model.layers.32.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
193
+ "model.layers.32.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
194
+ "model.layers.32.ffn_norm.weight": "model-00003-of-00004.safetensors",
195
+ "model.layers.33.attention.wo.weight": "model-00003-of-00004.safetensors",
196
+ "model.layers.33.attention.wqkv.weight": "model-00003-of-00004.safetensors",
197
+ "model.layers.33.attention_norm.weight": "model-00003-of-00004.safetensors",
198
+ "model.layers.33.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
199
+ "model.layers.33.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
200
+ "model.layers.33.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
201
+ "model.layers.33.ffn_norm.weight": "model-00003-of-00004.safetensors",
202
+ "model.layers.34.attention.wo.weight": "model-00003-of-00004.safetensors",
203
+ "model.layers.34.attention.wqkv.weight": "model-00003-of-00004.safetensors",
204
+ "model.layers.34.attention_norm.weight": "model-00003-of-00004.safetensors",
205
+ "model.layers.34.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
206
+ "model.layers.34.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
207
+ "model.layers.34.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
208
+ "model.layers.34.ffn_norm.weight": "model-00003-of-00004.safetensors",
209
+ "model.layers.35.attention.wo.weight": "model-00003-of-00004.safetensors",
210
+ "model.layers.35.attention.wqkv.weight": "model-00003-of-00004.safetensors",
211
+ "model.layers.35.attention_norm.weight": "model-00003-of-00004.safetensors",
212
+ "model.layers.35.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
213
+ "model.layers.35.feed_forward.w2.weight": "model-00003-of-00004.safetensors",
214
+ "model.layers.35.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
215
+ "model.layers.35.ffn_norm.weight": "model-00003-of-00004.safetensors",
216
+ "model.layers.36.attention.wo.weight": "model-00003-of-00004.safetensors",
217
+ "model.layers.36.attention.wqkv.weight": "model-00003-of-00004.safetensors",
218
+ "model.layers.36.attention_norm.weight": "model-00004-of-00004.safetensors",
219
+ "model.layers.36.feed_forward.w1.weight": "model-00003-of-00004.safetensors",
220
+ "model.layers.36.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
221
+ "model.layers.36.feed_forward.w3.weight": "model-00003-of-00004.safetensors",
222
+ "model.layers.36.ffn_norm.weight": "model-00004-of-00004.safetensors",
223
+ "model.layers.37.attention.wo.weight": "model-00004-of-00004.safetensors",
224
+ "model.layers.37.attention.wqkv.weight": "model-00004-of-00004.safetensors",
225
+ "model.layers.37.attention_norm.weight": "model-00004-of-00004.safetensors",
226
+ "model.layers.37.feed_forward.w1.weight": "model-00004-of-00004.safetensors",
227
+ "model.layers.37.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
228
+ "model.layers.37.feed_forward.w3.weight": "model-00004-of-00004.safetensors",
229
+ "model.layers.37.ffn_norm.weight": "model-00004-of-00004.safetensors",
230
+ "model.layers.38.attention.wo.weight": "model-00004-of-00004.safetensors",
231
+ "model.layers.38.attention.wqkv.weight": "model-00004-of-00004.safetensors",
232
+ "model.layers.38.attention_norm.weight": "model-00004-of-00004.safetensors",
233
+ "model.layers.38.feed_forward.w1.weight": "model-00004-of-00004.safetensors",
234
+ "model.layers.38.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
235
+ "model.layers.38.feed_forward.w3.weight": "model-00004-of-00004.safetensors",
236
+ "model.layers.38.ffn_norm.weight": "model-00004-of-00004.safetensors",
237
+ "model.layers.39.attention.wo.weight": "model-00004-of-00004.safetensors",
238
+ "model.layers.39.attention.wqkv.weight": "model-00004-of-00004.safetensors",
239
+ "model.layers.39.attention_norm.weight": "model-00004-of-00004.safetensors",
240
+ "model.layers.39.feed_forward.w1.weight": "model-00004-of-00004.safetensors",
241
+ "model.layers.39.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
242
+ "model.layers.39.feed_forward.w3.weight": "model-00004-of-00004.safetensors",
243
+ "model.layers.39.ffn_norm.weight": "model-00004-of-00004.safetensors",
244
+ "model.layers.4.attention.wo.weight": "model-00001-of-00004.safetensors",
245
+ "model.layers.4.attention.wqkv.weight": "model-00001-of-00004.safetensors",
246
+ "model.layers.4.attention_norm.weight": "model-00001-of-00004.safetensors",
247
+ "model.layers.4.feed_forward.w1.weight": "model-00001-of-00004.safetensors",
248
+ "model.layers.4.feed_forward.w2.weight": "model-00001-of-00004.safetensors",
249
+ "model.layers.4.feed_forward.w3.weight": "model-00001-of-00004.safetensors",
250
+ "model.layers.4.ffn_norm.weight": "model-00001-of-00004.safetensors",
251
+ "model.layers.40.attention.wo.weight": "model-00004-of-00004.safetensors",
252
+ "model.layers.40.attention.wqkv.weight": "model-00004-of-00004.safetensors",
253
+ "model.layers.40.attention_norm.weight": "model-00004-of-00004.safetensors",
254
+ "model.layers.40.feed_forward.w1.weight": "model-00004-of-00004.safetensors",
255
+ "model.layers.40.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
256
+ "model.layers.40.feed_forward.w3.weight": "model-00004-of-00004.safetensors",
257
+ "model.layers.40.ffn_norm.weight": "model-00004-of-00004.safetensors",
258
+ "model.layers.41.attention.wo.weight": "model-00004-of-00004.safetensors",
259
+ "model.layers.41.attention.wqkv.weight": "model-00004-of-00004.safetensors",
260
+ "model.layers.41.attention_norm.weight": "model-00004-of-00004.safetensors",
261
+ "model.layers.41.feed_forward.w1.weight": "model-00004-of-00004.safetensors",
262
+ "model.layers.41.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
263
+ "model.layers.41.feed_forward.w3.weight": "model-00004-of-00004.safetensors",
264
+ "model.layers.41.ffn_norm.weight": "model-00004-of-00004.safetensors",
265
+ "model.layers.42.attention.wo.weight": "model-00004-of-00004.safetensors",
266
+ "model.layers.42.attention.wqkv.weight": "model-00004-of-00004.safetensors",
267
+ "model.layers.42.attention_norm.weight": "model-00004-of-00004.safetensors",
268
+ "model.layers.42.feed_forward.w1.weight": "model-00004-of-00004.safetensors",
269
+ "model.layers.42.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
270
+ "model.layers.42.feed_forward.w3.weight": "model-00004-of-00004.safetensors",
271
+ "model.layers.42.ffn_norm.weight": "model-00004-of-00004.safetensors",
272
+ "model.layers.43.attention.wo.weight": "model-00004-of-00004.safetensors",
273
+ "model.layers.43.attention.wqkv.weight": "model-00004-of-00004.safetensors",
274
+ "model.layers.43.attention_norm.weight": "model-00004-of-00004.safetensors",
275
+ "model.layers.43.feed_forward.w1.weight": "model-00004-of-00004.safetensors",
276
+ "model.layers.43.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
277
+ "model.layers.43.feed_forward.w3.weight": "model-00004-of-00004.safetensors",
278
+ "model.layers.43.ffn_norm.weight": "model-00004-of-00004.safetensors",
279
+ "model.layers.44.attention.wo.weight": "model-00004-of-00004.safetensors",
280
+ "model.layers.44.attention.wqkv.weight": "model-00004-of-00004.safetensors",
281
+ "model.layers.44.attention_norm.weight": "model-00004-of-00004.safetensors",
282
+ "model.layers.44.feed_forward.w1.weight": "model-00004-of-00004.safetensors",
283
+ "model.layers.44.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
284
+ "model.layers.44.feed_forward.w3.weight": "model-00004-of-00004.safetensors",
285
+ "model.layers.44.ffn_norm.weight": "model-00004-of-00004.safetensors",
286
+ "model.layers.45.attention.wo.weight": "model-00004-of-00004.safetensors",
287
+ "model.layers.45.attention.wqkv.weight": "model-00004-of-00004.safetensors",
288
+ "model.layers.45.attention_norm.weight": "model-00004-of-00004.safetensors",
289
+ "model.layers.45.feed_forward.w1.weight": "model-00004-of-00004.safetensors",
290
+ "model.layers.45.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
291
+ "model.layers.45.feed_forward.w3.weight": "model-00004-of-00004.safetensors",
292
+ "model.layers.45.ffn_norm.weight": "model-00004-of-00004.safetensors",
293
+ "model.layers.46.attention.wo.weight": "model-00004-of-00004.safetensors",
294
+ "model.layers.46.attention.wqkv.weight": "model-00004-of-00004.safetensors",
295
+ "model.layers.46.attention_norm.weight": "model-00004-of-00004.safetensors",
296
+ "model.layers.46.feed_forward.w1.weight": "model-00004-of-00004.safetensors",
297
+ "model.layers.46.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
298
+ "model.layers.46.feed_forward.w3.weight": "model-00004-of-00004.safetensors",
299
+ "model.layers.46.ffn_norm.weight": "model-00004-of-00004.safetensors",
300
+ "model.layers.47.attention.wo.weight": "model-00004-of-00004.safetensors",
301
+ "model.layers.47.attention.wqkv.weight": "model-00004-of-00004.safetensors",
302
+ "model.layers.47.attention_norm.weight": "model-00004-of-00004.safetensors",
303
+ "model.layers.47.feed_forward.w1.weight": "model-00004-of-00004.safetensors",
304
+ "model.layers.47.feed_forward.w2.weight": "model-00004-of-00004.safetensors",
305
+ "model.layers.47.feed_forward.w3.weight": "model-00004-of-00004.safetensors",
306
+ "model.layers.47.ffn_norm.weight": "model-00004-of-00004.safetensors",
307
+ "model.layers.5.attention.wo.weight": "model-00001-of-00004.safetensors",
308
+ "model.layers.5.attention.wqkv.weight": "model-00001-of-00004.safetensors",
309
+ "model.layers.5.attention_norm.weight": "model-00001-of-00004.safetensors",
310
+ "model.layers.5.feed_forward.w1.weight": "model-00001-of-00004.safetensors",
311
+ "model.layers.5.feed_forward.w2.weight": "model-00001-of-00004.safetensors",
312
+ "model.layers.5.feed_forward.w3.weight": "model-00001-of-00004.safetensors",
313
+ "model.layers.5.ffn_norm.weight": "model-00001-of-00004.safetensors",
314
+ "model.layers.6.attention.wo.weight": "model-00001-of-00004.safetensors",
315
+ "model.layers.6.attention.wqkv.weight": "model-00001-of-00004.safetensors",
316
+ "model.layers.6.attention_norm.weight": "model-00001-of-00004.safetensors",
317
+ "model.layers.6.feed_forward.w1.weight": "model-00001-of-00004.safetensors",
318
+ "model.layers.6.feed_forward.w2.weight": "model-00001-of-00004.safetensors",
319
+ "model.layers.6.feed_forward.w3.weight": "model-00001-of-00004.safetensors",
320
+ "model.layers.6.ffn_norm.weight": "model-00001-of-00004.safetensors",
321
+ "model.layers.7.attention.wo.weight": "model-00001-of-00004.safetensors",
322
+ "model.layers.7.attention.wqkv.weight": "model-00001-of-00004.safetensors",
323
+ "model.layers.7.attention_norm.weight": "model-00001-of-00004.safetensors",
324
+ "model.layers.7.feed_forward.w1.weight": "model-00001-of-00004.safetensors",
325
+ "model.layers.7.feed_forward.w2.weight": "model-00001-of-00004.safetensors",
326
+ "model.layers.7.feed_forward.w3.weight": "model-00001-of-00004.safetensors",
327
+ "model.layers.7.ffn_norm.weight": "model-00001-of-00004.safetensors",
328
+ "model.layers.8.attention.wo.weight": "model-00001-of-00004.safetensors",
329
+ "model.layers.8.attention.wqkv.weight": "model-00001-of-00004.safetensors",
330
+ "model.layers.8.attention_norm.weight": "model-00001-of-00004.safetensors",
331
+ "model.layers.8.feed_forward.w1.weight": "model-00001-of-00004.safetensors",
332
+ "model.layers.8.feed_forward.w2.weight": "model-00001-of-00004.safetensors",
333
+ "model.layers.8.feed_forward.w3.weight": "model-00001-of-00004.safetensors",
334
+ "model.layers.8.ffn_norm.weight": "model-00001-of-00004.safetensors",
335
+ "model.layers.9.attention.wo.weight": "model-00001-of-00004.safetensors",
336
+ "model.layers.9.attention.wqkv.weight": "model-00001-of-00004.safetensors",
337
+ "model.layers.9.attention_norm.weight": "model-00001-of-00004.safetensors",
338
+ "model.layers.9.feed_forward.w1.weight": "model-00001-of-00004.safetensors",
339
+ "model.layers.9.feed_forward.w2.weight": "model-00001-of-00004.safetensors",
340
+ "model.layers.9.feed_forward.w3.weight": "model-00001-of-00004.safetensors",
341
+ "model.layers.9.ffn_norm.weight": "model-00001-of-00004.safetensors",
342
+ "model.norm.weight": "model-00004-of-00004.safetensors",
343
+ "model.tok_embeddings.weight": "model-00001-of-00004.safetensors",
344
+ "output.weight": "model-00004-of-00004.safetensors"
345
+ }
346
+ }
modeling_internlm2.py CHANGED
@@ -1,10 +1,6 @@
1
- # coding=utf-8
2
- # # Copyright (c) InternLM. All rights reserved.
3
  #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
  #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
10
  # you may not use this file except in compliance with the License.
@@ -25,6 +21,7 @@ import warnings
25
  from typing import List, Optional, Tuple, Union
26
 
27
  import torch
 
28
  import torch.utils.checkpoint
29
  from einops import rearrange
30
  from torch import nn
@@ -48,12 +45,37 @@ try:
48
  except: # noqa # pylint: disable=bare-except
49
  BaseStreamer = None
50
 
51
- from .configuration_internlm import InternLMConfig as InternLM2Config
52
 
53
  logger = logging.get_logger(__name__)
54
 
55
  _CONFIG_FOR_DOC = "InternLM2Config"
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
59
  def _make_causal_mask(
@@ -88,6 +110,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
88
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
89
 
90
 
 
91
  class InternLM2RMSNorm(nn.Module):
92
  def __init__(self, hidden_size, eps=1e-6):
93
  """
@@ -105,6 +128,7 @@ class InternLM2RMSNorm(nn.Module):
105
  return self.weight * hidden_states.to(input_dtype)
106
 
107
 
 
108
  class InternLM2RotaryEmbedding(nn.Module):
109
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
110
  super().__init__()
@@ -133,7 +157,7 @@ class InternLM2RotaryEmbedding(nn.Module):
133
  def forward(self, x, seq_len=None):
134
  # x: [bs, num_attention_heads, seq_len, head_size]
135
  if seq_len > self.max_seq_len_cached:
136
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
137
 
138
  return (
139
  self.cos_cached[:seq_len].to(dtype=x.dtype),
@@ -141,6 +165,7 @@ class InternLM2RotaryEmbedding(nn.Module):
141
  )
142
 
143
 
 
144
  class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
145
  """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
146
 
@@ -160,6 +185,7 @@ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
160
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
161
 
162
 
 
163
  class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
164
  """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
165
  Credits to the Reddit users /u/bloc97 and /u/emozilla.
@@ -188,6 +214,7 @@ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
188
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
189
 
190
 
 
191
  def rotate_half(x):
192
  """Rotates half the hidden dims of the input."""
193
  x1 = x[..., : x.shape[-1] // 2]
@@ -195,22 +222,13 @@ def rotate_half(x):
195
  return torch.cat((-x2, x1), dim=-1)
196
 
197
 
198
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
199
- # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
200
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
201
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
202
- cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
203
- sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
204
- if q.size(2) == 1:
205
- q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
206
- else:
207
- q_embed = (q * cos) + (rotate_half(q) * sin)
208
-
209
- if k.size(2) == 1:
210
- k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
211
- else:
212
- k_embed = (k * cos) + (rotate_half(k) * sin)
213
-
214
  return q_embed, k_embed
215
 
216
 
@@ -231,6 +249,7 @@ class InternLM2MLP(nn.Module):
231
  return down_proj
232
 
233
 
 
234
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
235
  """
236
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -243,6 +262,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
243
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
244
 
245
 
 
246
  class InternLM2Attention(nn.Module):
247
  """Multi-headed attention from 'Attention Is All You Need' paper"""
248
 
@@ -287,10 +307,17 @@ class InternLM2Attention(nn.Module):
287
  self.head_dim,
288
  max_position_embeddings=self.max_position_embeddings,
289
  base=self.config.rope_theta,
290
- scaling_factor=scaling_factor
 
 
 
 
 
 
 
291
  )
292
  else:
293
- raise ValueError("Currently we only support rotary embedding's type being 'dynamic'.")
294
  return self.rotary_emb
295
 
296
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -384,6 +411,7 @@ class InternLM2Attention(nn.Module):
384
  return attn_output, attn_weights, past_key_value
385
 
386
 
 
387
  class InternLM2FlashAttention2(InternLM2Attention):
388
  """
389
  InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
@@ -420,9 +448,8 @@ class InternLM2FlashAttention2(InternLM2Attention):
420
  qkv_states = rearrange(
421
  qkv_states,
422
  "b q (h gs d) -> b q h gs d",
423
- gs=self.num_heads + 2 * self.num_key_value_heads,
424
  d=self.head_dim,
425
- q=q_len,
426
  )
427
 
428
  query_states = qkv_states[..., : self.num_key_value_groups, :]
@@ -430,6 +457,10 @@ class InternLM2FlashAttention2(InternLM2Attention):
430
  key_states = qkv_states[..., -2, :]
431
  value_states = qkv_states[..., -1, :]
432
 
 
 
 
 
433
  kv_seq_len = key_states.shape[-2]
434
  if past_key_value is not None:
435
  kv_seq_len += past_key_value[0].shape[-2]
@@ -449,36 +480,9 @@ class InternLM2FlashAttention2(InternLM2Attention):
449
  key_states = key_states.transpose(1, 2)
450
  value_states = value_states.transpose(1, 2)
451
 
452
- dropout_rate = 0.0 if not self.training else self.attention_dropout
453
-
454
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
455
- # therefore the input hidden states gets silently casted in float32. Hence, we need
456
- # cast them back in the correct dtype just to be sure everything works as expected.
457
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
458
- # in fp32. (InternLM2RMSNorm handles it correctly)
459
-
460
- input_dtype = query_states.dtype
461
- if input_dtype == torch.float32:
462
- # Handle the case where the model is quantized
463
- if hasattr(self.config, "_pre_quantization_dtype"):
464
- target_dtype = self.config._pre_quantization_dtype
465
- else:
466
- target_dtype = self.q_proj.weight.dtype
467
-
468
- logger.warning_once(
469
- f"The input hidden states seems to be silently casted in float32, this might be related to"
470
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back "
471
- f"the input in {target_dtype}."
472
- )
473
-
474
- query_states = query_states.to(target_dtype)
475
- key_states = key_states.to(target_dtype)
476
- value_states = value_states.to(target_dtype)
477
-
478
  attn_output = self._flash_attention_forward(
479
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
480
  )
481
-
482
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
483
  attn_output = self.wo(attn_output)
484
 
@@ -487,16 +491,112 @@ class InternLM2FlashAttention2(InternLM2Attention):
487
 
488
  return attn_output, attn_weights, past_key_value
489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
  class InternLM2DecoderLayer(nn.Module):
492
  def __init__(self, config: InternLM2Config):
493
  super().__init__()
494
  self.hidden_size = config.hidden_size
495
- self.attention = (
496
- InternLM2Attention(config=config)
497
- if not getattr(config, "_flash_attn_2_enabled", False)
498
- else InternLM2FlashAttention2(config=config)
499
- )
500
  self.feed_forward = InternLM2MLP(config)
501
  self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
502
  self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -581,6 +681,7 @@ InternLM2_START_DOCSTRING = r"""
581
  """
582
 
583
 
 
584
  @add_start_docstrings(
585
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
586
  InternLM2_START_DOCSTRING,
@@ -591,7 +692,6 @@ class InternLM2PreTrainedModel(PreTrainedModel):
591
  supports_gradient_checkpointing = True
592
  _no_split_modules = ["InternLM2DecoderLayer"]
593
  _skip_keys_device_placement = "past_key_values"
594
- _supports_flash_attn_2 = True
595
 
596
  def _init_weights(self, module):
597
  std = self.config.initializer_range
@@ -670,6 +770,7 @@ InternLM2_INPUTS_DOCSTRING = r"""
670
  """
671
 
672
 
 
673
  @add_start_docstrings(
674
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
675
  InternLM2_START_DOCSTRING,
@@ -688,8 +789,10 @@ class InternLM2Model(InternLM2PreTrainedModel):
688
  super().__init__(config)
689
  self.padding_idx = config.pad_token_id
690
  self.vocab_size = config.vocab_size
 
691
 
692
  self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
693
  self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
694
  self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
695
 
@@ -703,7 +806,6 @@ class InternLM2Model(InternLM2PreTrainedModel):
703
  def set_input_embeddings(self, value):
704
  self.tok_embeddings = value
705
 
706
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
707
  def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
708
  # create causal mask
709
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -748,6 +850,9 @@ class InternLM2Model(InternLM2PreTrainedModel):
748
 
749
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
750
 
 
 
 
751
  # retrieve input_ids and inputs_embeds
752
  if input_ids is not None and inputs_embeds is not None:
753
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -773,14 +878,18 @@ class InternLM2Model(InternLM2PreTrainedModel):
773
 
774
  if inputs_embeds is None:
775
  inputs_embeds = self.tok_embeddings(input_ids)
776
- # embed positions
777
- if attention_mask is None:
778
- attention_mask = torch.ones(
779
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
 
 
 
 
 
 
 
780
  )
781
- attention_mask = self._prepare_decoder_attention_mask(
782
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
783
- )
784
 
785
  # embed positions
786
  hidden_states = inputs_embeds
@@ -854,6 +963,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
854
  )
855
 
856
 
 
857
  class InternLM2ForCausalLM(InternLM2PreTrainedModel):
858
  _auto_class = "AutoModelForCausalLM"
859
 
@@ -1023,12 +1133,16 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1023
  )
1024
  return reordered_past
1025
 
1026
- def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
1027
- prompt = ""
 
 
 
 
 
1028
  for record in history:
1029
- prompt += f"""[UNUSED_TOKEN_146]user\n{record[0]}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n{record[1]}[UNUSED_TOKEN_145]\n"""
1030
- prompt += f"""[UNUSED_TOKEN_146]user\n{query}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"""
1031
- print(prompt)
1032
  return tokenizer([prompt], return_tensors="pt")
1033
 
1034
  @torch.no_grad()
@@ -1042,10 +1156,15 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1042
  do_sample: bool = True,
1043
  temperature: float = 0.8,
1044
  top_p: float = 0.8,
 
 
 
1045
  **kwargs,
1046
  ):
1047
- inputs = self.build_inputs(tokenizer, query, history)
1048
  inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
 
 
1049
  outputs = self.generate(
1050
  **inputs,
1051
  streamer=streamer,
@@ -1053,11 +1172,12 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1053
  do_sample=do_sample,
1054
  temperature=temperature,
1055
  top_p=top_p,
 
1056
  **kwargs,
1057
  )
1058
  outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1059
  response = tokenizer.decode(outputs, skip_special_tokens=True)
1060
- response = response.split("[UNUSED_TOKEN_145]")[0]
1061
  history = history + [(query, response)]
1062
  return response, history
1063
 
@@ -1095,6 +1215,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1095
  self.query = query
1096
  self.history = history
1097
  self.response = ""
 
1098
  self.received_inputs = False
1099
  self.queue.put((self.response, history + [(self.query, self.response)]))
1100
 
@@ -1109,11 +1230,15 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1109
  self.received_inputs = True
1110
  return
1111
 
1112
- token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
1113
- if token.strip() != "[UNUSED_TOKEN_145]":
 
1114
  self.response = self.response + token
1115
  history = self.history + [(self.query, self.response)]
1116
  self.queue.put((self.response, history))
 
 
 
1117
 
1118
  def end(self):
1119
  self.queue.put(None)
@@ -1143,6 +1268,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1143
  return consumer()
1144
 
1145
 
 
1146
  @add_start_docstrings(
1147
  """
1148
  The InternLM2 Model transformer with a sequence classification head on top (linear layer).
 
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
 
2
  #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
 
 
 
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
 
21
  from typing import List, Optional, Tuple, Union
22
 
23
  import torch
24
+ import torch.nn.functional as F
25
  import torch.utils.checkpoint
26
  from einops import rearrange
27
  from torch import nn
 
45
  except: # noqa # pylint: disable=bare-except
46
  BaseStreamer = None
47
 
48
+ from .configuration_internlm2 import InternLM2Config
49
 
50
  logger = logging.get_logger(__name__)
51
 
52
  _CONFIG_FOR_DOC = "InternLM2Config"
53
 
54
+ flash_attn_func, flash_attn_varlen_func = None, None
55
+ pad_input, index_first_axis, unpad_input = None, None, None
56
+ def _import_flash_attn():
57
+ global flash_attn_func, flash_attn_varlen_func
58
+ global pad_input, index_first_axis, unpad_input
59
+ try:
60
+ from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
61
+ from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
62
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
63
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
64
+ except ImportError:
65
+ raise ImportError("flash_attn is not installed.")
66
+
67
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
+ def _get_unpad_data(attention_mask):
69
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
70
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
72
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
73
+ return (
74
+ indices,
75
+ cu_seqlens,
76
+ max_seqlen_in_batch,
77
+ )
78
+
79
 
80
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
81
  def _make_causal_mask(
 
110
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
111
 
112
 
113
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
114
  class InternLM2RMSNorm(nn.Module):
115
  def __init__(self, hidden_size, eps=1e-6):
116
  """
 
128
  return self.weight * hidden_states.to(input_dtype)
129
 
130
 
131
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
132
  class InternLM2RotaryEmbedding(nn.Module):
133
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
134
  super().__init__()
 
157
  def forward(self, x, seq_len=None):
158
  # x: [bs, num_attention_heads, seq_len, head_size]
159
  if seq_len > self.max_seq_len_cached:
160
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
161
 
162
  return (
163
  self.cos_cached[:seq_len].to(dtype=x.dtype),
 
165
  )
166
 
167
 
168
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
169
  class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
170
  """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
171
 
 
185
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
186
 
187
 
188
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
189
  class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
190
  """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
191
  Credits to the Reddit users /u/bloc97 and /u/emozilla.
 
214
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
215
 
216
 
217
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
218
  def rotate_half(x):
219
  """Rotates half the hidden dims of the input."""
220
  x1 = x[..., : x.shape[-1] // 2]
 
222
  return torch.cat((-x2, x1), dim=-1)
223
 
224
 
225
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
226
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
227
+ """Applies Rotary Position Embedding to the query and key tensors."""
228
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
229
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
230
+ q_embed = (q * cos) + (rotate_half(q) * sin)
231
+ k_embed = (k * cos) + (rotate_half(k) * sin)
 
 
 
 
 
 
 
 
 
232
  return q_embed, k_embed
233
 
234
 
 
249
  return down_proj
250
 
251
 
252
+ # Copied from transformers.model.llama.modeling_llama.repeat_kv
253
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
254
  """
255
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 
262
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
263
 
264
 
265
+ # Modified from transformers.model.llama.modeling_llama.LlamaAttention
266
  class InternLM2Attention(nn.Module):
267
  """Multi-headed attention from 'Attention Is All You Need' paper"""
268
 
 
307
  self.head_dim,
308
  max_position_embeddings=self.max_position_embeddings,
309
  base=self.config.rope_theta,
310
+ scaling_factor=scaling_factor,
311
+ )
312
+ elif scaling_type == "linear":
313
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
314
+ self.head_dim,
315
+ max_position_embeddings=self.max_position_embeddings,
316
+ base=self.config.rope_theta,
317
+ scaling_factor=scaling_factor,
318
  )
319
  else:
320
+ raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
321
  return self.rotary_emb
322
 
323
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
 
411
  return attn_output, attn_weights, past_key_value
412
 
413
 
414
+ # Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
415
  class InternLM2FlashAttention2(InternLM2Attention):
416
  """
417
  InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
 
448
  qkv_states = rearrange(
449
  qkv_states,
450
  "b q (h gs d) -> b q h gs d",
451
+ gs=2 + self.num_key_value_groups,
452
  d=self.head_dim,
 
453
  )
454
 
455
  query_states = qkv_states[..., : self.num_key_value_groups, :]
 
457
  key_states = qkv_states[..., -2, :]
458
  value_states = qkv_states[..., -1, :]
459
 
460
+ query_states = query_states.transpose(1, 2)
461
+ key_states = key_states.transpose(1, 2)
462
+ value_states = value_states.transpose(1, 2)
463
+
464
  kv_seq_len = key_states.shape[-2]
465
  if past_key_value is not None:
466
  kv_seq_len += past_key_value[0].shape[-2]
 
480
  key_states = key_states.transpose(1, 2)
481
  value_states = value_states.transpose(1, 2)
482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  attn_output = self._flash_attention_forward(
484
+ query_states, key_states, value_states, attention_mask, q_len
485
  )
 
486
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
487
  attn_output = self.wo(attn_output)
488
 
 
491
 
492
  return attn_output, attn_weights, past_key_value
493
 
494
+ def _flash_attention_forward(
495
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
496
+ ):
497
+ """
498
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
499
+ first unpad the input, then computes the attention scores and pad the final attention scores.
500
+
501
+ Args:
502
+ query_states (`torch.Tensor`):
503
+ Input query states to be passed to Flash Attention API
504
+ key_states (`torch.Tensor`):
505
+ Input key states to be passed to Flash Attention API
506
+ value_states (`torch.Tensor`):
507
+ Input value states to be passed to Flash Attention API
508
+ attention_mask (`torch.Tensor`):
509
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
510
+ position of padding tokens and 1 for the position of non-padding tokens.
511
+ dropout (`int`, *optional*):
512
+ Attention dropout
513
+ softmax_scale (`float`, *optional*):
514
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
515
+ """
516
+ # Contains at least one padding token in the sequence
517
+ causal = self.is_causal and query_length != 1
518
+ if attention_mask is not None:
519
+ batch_size = query_states.shape[0]
520
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
521
+ query_states, key_states, value_states, attention_mask, query_length
522
+ )
523
+
524
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
525
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
526
+
527
+ attn_output_unpad = flash_attn_varlen_func(
528
+ query_states,
529
+ key_states,
530
+ value_states,
531
+ cu_seqlens_q=cu_seqlens_q,
532
+ cu_seqlens_k=cu_seqlens_k,
533
+ max_seqlen_q=max_seqlen_in_batch_q,
534
+ max_seqlen_k=max_seqlen_in_batch_k,
535
+ dropout_p=dropout,
536
+ softmax_scale=softmax_scale,
537
+ causal=causal,
538
+ )
539
 
540
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
541
+ else:
542
+ attn_output = flash_attn_func(
543
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
544
+ )
545
+
546
+ return attn_output
547
+
548
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
549
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
550
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
551
+
552
+ key_layer = index_first_axis(
553
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
554
+ )
555
+ value_layer = index_first_axis(
556
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
557
+ )
558
+
559
+ if query_length == kv_seq_len:
560
+ query_layer = index_first_axis(
561
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
562
+ )
563
+ cu_seqlens_q = cu_seqlens_k
564
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
565
+ indices_q = indices_k
566
+ elif query_length == 1:
567
+ max_seqlen_in_batch_q = 1
568
+ cu_seqlens_q = torch.arange(
569
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
570
+ ) # There is a memcpy here, that is very bad.
571
+ indices_q = cu_seqlens_q[:-1]
572
+ query_layer = query_layer.squeeze(1)
573
+ else:
574
+ # The -q_len: slice assumes left padding.
575
+ attention_mask = attention_mask[:, -query_length:]
576
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
577
+
578
+ return (
579
+ query_layer,
580
+ key_layer,
581
+ value_layer,
582
+ indices_q.to(torch.int64),
583
+ (cu_seqlens_q, cu_seqlens_k),
584
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
585
+ )
586
+
587
+ INTERNLM2_ATTENTION_CLASSES = {
588
+ "eager": InternLM2Attention,
589
+ "flash_attention_2": InternLM2FlashAttention2,
590
+ }
591
+
592
+ # Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
593
  class InternLM2DecoderLayer(nn.Module):
594
  def __init__(self, config: InternLM2Config):
595
  super().__init__()
596
  self.hidden_size = config.hidden_size
597
+
598
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
599
+
 
 
600
  self.feed_forward = InternLM2MLP(config)
601
  self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
602
  self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
681
  """
682
 
683
 
684
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
685
  @add_start_docstrings(
686
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
687
  InternLM2_START_DOCSTRING,
 
692
  supports_gradient_checkpointing = True
693
  _no_split_modules = ["InternLM2DecoderLayer"]
694
  _skip_keys_device_placement = "past_key_values"
 
695
 
696
  def _init_weights(self, module):
697
  std = self.config.initializer_range
 
770
  """
771
 
772
 
773
+ # Modified from transformers.model.llama.modeling_llama.LlamaModel
774
  @add_start_docstrings(
775
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
776
  InternLM2_START_DOCSTRING,
 
789
  super().__init__(config)
790
  self.padding_idx = config.pad_token_id
791
  self.vocab_size = config.vocab_size
792
+ self.config = config
793
 
794
  self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
795
+
796
  self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
797
  self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
798
 
 
806
  def set_input_embeddings(self, value):
807
  self.tok_embeddings = value
808
 
 
809
  def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
810
  # create causal mask
811
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
 
850
 
851
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
852
 
853
+ if self.config.attn_implementation == "flash_attention_2":
854
+ _import_flash_attn()
855
+
856
  # retrieve input_ids and inputs_embeds
857
  if input_ids is not None and inputs_embeds is not None:
858
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 
878
 
879
  if inputs_embeds is None:
880
  inputs_embeds = self.tok_embeddings(input_ids)
881
+
882
+ if self.config.attn_implementation == "flash_attention_2":
883
+ # 2d mask is passed through the layers
884
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
885
+ else:
886
+ if attention_mask is None:
887
+ attention_mask = torch.ones(
888
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
889
+ )
890
+ attention_mask = self._prepare_decoder_attention_mask(
891
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
892
  )
 
 
 
893
 
894
  # embed positions
895
  hidden_states = inputs_embeds
 
963
  )
964
 
965
 
966
+ # Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
967
  class InternLM2ForCausalLM(InternLM2PreTrainedModel):
968
  _auto_class = "AutoModelForCausalLM"
969
 
 
1133
  )
1134
  return reordered_past
1135
 
1136
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
1137
+ if tokenizer.add_bos_token:
1138
+ prompt = ""
1139
+ else:
1140
+ prompt = tokenizer.bos_token
1141
+ if meta_instruction:
1142
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1143
  for record in history:
1144
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1145
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
 
1146
  return tokenizer([prompt], return_tensors="pt")
1147
 
1148
  @torch.no_grad()
 
1156
  do_sample: bool = True,
1157
  temperature: float = 0.8,
1158
  top_p: float = 0.8,
1159
+ meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1160
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1161
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
1162
  **kwargs,
1163
  ):
1164
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1165
  inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1166
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1167
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
1168
  outputs = self.generate(
1169
  **inputs,
1170
  streamer=streamer,
 
1172
  do_sample=do_sample,
1173
  temperature=temperature,
1174
  top_p=top_p,
1175
+ eos_token_id=eos_token_id,
1176
  **kwargs,
1177
  )
1178
  outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1179
  response = tokenizer.decode(outputs, skip_special_tokens=True)
1180
+ response = response.split("<|im_end|>")[0]
1181
  history = history + [(query, response)]
1182
  return response, history
1183
 
 
1215
  self.query = query
1216
  self.history = history
1217
  self.response = ""
1218
+ self.cache = []
1219
  self.received_inputs = False
1220
  self.queue.put((self.response, history + [(self.query, self.response)]))
1221
 
 
1230
  self.received_inputs = True
1231
  return
1232
 
1233
+ self.cache.extend(value.tolist())
1234
+ token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
1235
+ if token.strip() != "<|im_end|>":
1236
  self.response = self.response + token
1237
  history = self.history + [(self.query, self.response)]
1238
  self.queue.put((self.response, history))
1239
+ self.cache = []
1240
+ else:
1241
+ self.end()
1242
 
1243
  def end(self):
1244
  self.queue.put(None)
 
1268
  return consumer()
1269
 
1270
 
1271
+ # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1272
  @add_start_docstrings(
1273
  """
1274
  The InternLM2 Model transformer with a sequence classification head on top (linear layer).
pytorch_model.bin.index.json DELETED
@@ -1,346 +0,0 @@
1
- {
2
- "metadata": {
3
- "total_size": 39722299392
4
- },
5
- "weight_map": {
6
- "model.layers.0.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
7
- "model.layers.0.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
8
- "model.layers.0.attention_norm.weight": "pytorch_model-00001-of-00004.bin",
9
- "model.layers.0.feed_forward.w1.weight": "pytorch_model-00001-of-00004.bin",
10
- "model.layers.0.feed_forward.w2.weight": "pytorch_model-00001-of-00004.bin",
11
- "model.layers.0.feed_forward.w3.weight": "pytorch_model-00001-of-00004.bin",
12
- "model.layers.0.ffn_norm.weight": "pytorch_model-00001-of-00004.bin",
13
- "model.layers.1.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
14
- "model.layers.1.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
15
- "model.layers.1.attention_norm.weight": "pytorch_model-00001-of-00004.bin",
16
- "model.layers.1.feed_forward.w1.weight": "pytorch_model-00001-of-00004.bin",
17
- "model.layers.1.feed_forward.w2.weight": "pytorch_model-00001-of-00004.bin",
18
- "model.layers.1.feed_forward.w3.weight": "pytorch_model-00001-of-00004.bin",
19
- "model.layers.1.ffn_norm.weight": "pytorch_model-00001-of-00004.bin",
20
- "model.layers.10.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
21
- "model.layers.10.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
22
- "model.layers.10.attention_norm.weight": "pytorch_model-00001-of-00004.bin",
23
- "model.layers.10.feed_forward.w1.weight": "pytorch_model-00001-of-00004.bin",
24
- "model.layers.10.feed_forward.w2.weight": "pytorch_model-00001-of-00004.bin",
25
- "model.layers.10.feed_forward.w3.weight": "pytorch_model-00001-of-00004.bin",
26
- "model.layers.10.ffn_norm.weight": "pytorch_model-00001-of-00004.bin",
27
- "model.layers.11.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
28
- "model.layers.11.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
29
- "model.layers.11.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
30
- "model.layers.11.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
31
- "model.layers.11.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
32
- "model.layers.11.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
33
- "model.layers.11.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
34
- "model.layers.12.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
35
- "model.layers.12.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
36
- "model.layers.12.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
37
- "model.layers.12.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
38
- "model.layers.12.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
39
- "model.layers.12.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
40
- "model.layers.12.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
41
- "model.layers.13.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
42
- "model.layers.13.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
43
- "model.layers.13.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
44
- "model.layers.13.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
45
- "model.layers.13.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
46
- "model.layers.13.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
47
- "model.layers.13.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
48
- "model.layers.14.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
49
- "model.layers.14.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
50
- "model.layers.14.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
51
- "model.layers.14.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
52
- "model.layers.14.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
53
- "model.layers.14.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
54
- "model.layers.14.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
55
- "model.layers.15.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
56
- "model.layers.15.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
57
- "model.layers.15.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
58
- "model.layers.15.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
59
- "model.layers.15.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
60
- "model.layers.15.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
61
- "model.layers.15.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
62
- "model.layers.16.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
63
- "model.layers.16.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
64
- "model.layers.16.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
65
- "model.layers.16.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
66
- "model.layers.16.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
67
- "model.layers.16.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
68
- "model.layers.16.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
69
- "model.layers.17.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
70
- "model.layers.17.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
71
- "model.layers.17.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
72
- "model.layers.17.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
73
- "model.layers.17.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
74
- "model.layers.17.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
75
- "model.layers.17.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
76
- "model.layers.18.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
77
- "model.layers.18.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
78
- "model.layers.18.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
79
- "model.layers.18.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
80
- "model.layers.18.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
81
- "model.layers.18.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
82
- "model.layers.18.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
83
- "model.layers.19.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
84
- "model.layers.19.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
85
- "model.layers.19.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
86
- "model.layers.19.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
87
- "model.layers.19.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
88
- "model.layers.19.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
89
- "model.layers.19.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
90
- "model.layers.2.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
91
- "model.layers.2.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
92
- "model.layers.2.attention_norm.weight": "pytorch_model-00001-of-00004.bin",
93
- "model.layers.2.feed_forward.w1.weight": "pytorch_model-00001-of-00004.bin",
94
- "model.layers.2.feed_forward.w2.weight": "pytorch_model-00001-of-00004.bin",
95
- "model.layers.2.feed_forward.w3.weight": "pytorch_model-00001-of-00004.bin",
96
- "model.layers.2.ffn_norm.weight": "pytorch_model-00001-of-00004.bin",
97
- "model.layers.20.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
98
- "model.layers.20.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
99
- "model.layers.20.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
100
- "model.layers.20.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
101
- "model.layers.20.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
102
- "model.layers.20.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
103
- "model.layers.20.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
104
- "model.layers.21.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
105
- "model.layers.21.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
106
- "model.layers.21.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
107
- "model.layers.21.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
108
- "model.layers.21.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
109
- "model.layers.21.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
110
- "model.layers.21.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
111
- "model.layers.22.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
112
- "model.layers.22.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
113
- "model.layers.22.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
114
- "model.layers.22.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
115
- "model.layers.22.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
116
- "model.layers.22.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
117
- "model.layers.22.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
118
- "model.layers.23.attention.wo.weight": "pytorch_model-00002-of-00004.bin",
119
- "model.layers.23.attention.wqkv.weight": "pytorch_model-00002-of-00004.bin",
120
- "model.layers.23.attention_norm.weight": "pytorch_model-00002-of-00004.bin",
121
- "model.layers.23.feed_forward.w1.weight": "pytorch_model-00002-of-00004.bin",
122
- "model.layers.23.feed_forward.w2.weight": "pytorch_model-00002-of-00004.bin",
123
- "model.layers.23.feed_forward.w3.weight": "pytorch_model-00002-of-00004.bin",
124
- "model.layers.23.ffn_norm.weight": "pytorch_model-00002-of-00004.bin",
125
- "model.layers.24.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
126
- "model.layers.24.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
127
- "model.layers.24.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
128
- "model.layers.24.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
129
- "model.layers.24.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
130
- "model.layers.24.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
131
- "model.layers.24.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
132
- "model.layers.25.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
133
- "model.layers.25.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
134
- "model.layers.25.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
135
- "model.layers.25.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
136
- "model.layers.25.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
137
- "model.layers.25.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
138
- "model.layers.25.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
139
- "model.layers.26.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
140
- "model.layers.26.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
141
- "model.layers.26.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
142
- "model.layers.26.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
143
- "model.layers.26.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
144
- "model.layers.26.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
145
- "model.layers.26.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
146
- "model.layers.27.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
147
- "model.layers.27.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
148
- "model.layers.27.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
149
- "model.layers.27.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
150
- "model.layers.27.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
151
- "model.layers.27.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
152
- "model.layers.27.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
153
- "model.layers.28.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
154
- "model.layers.28.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
155
- "model.layers.28.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
156
- "model.layers.28.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
157
- "model.layers.28.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
158
- "model.layers.28.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
159
- "model.layers.28.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
160
- "model.layers.29.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
161
- "model.layers.29.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
162
- "model.layers.29.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
163
- "model.layers.29.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
164
- "model.layers.29.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
165
- "model.layers.29.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
166
- "model.layers.29.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
167
- "model.layers.3.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
168
- "model.layers.3.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
169
- "model.layers.3.attention_norm.weight": "pytorch_model-00001-of-00004.bin",
170
- "model.layers.3.feed_forward.w1.weight": "pytorch_model-00001-of-00004.bin",
171
- "model.layers.3.feed_forward.w2.weight": "pytorch_model-00001-of-00004.bin",
172
- "model.layers.3.feed_forward.w3.weight": "pytorch_model-00001-of-00004.bin",
173
- "model.layers.3.ffn_norm.weight": "pytorch_model-00001-of-00004.bin",
174
- "model.layers.30.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
175
- "model.layers.30.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
176
- "model.layers.30.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
177
- "model.layers.30.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
178
- "model.layers.30.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
179
- "model.layers.30.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
180
- "model.layers.30.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
181
- "model.layers.31.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
182
- "model.layers.31.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
183
- "model.layers.31.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
184
- "model.layers.31.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
185
- "model.layers.31.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
186
- "model.layers.31.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
187
- "model.layers.31.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
188
- "model.layers.32.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
189
- "model.layers.32.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
190
- "model.layers.32.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
191
- "model.layers.32.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
192
- "model.layers.32.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
193
- "model.layers.32.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
194
- "model.layers.32.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
195
- "model.layers.33.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
196
- "model.layers.33.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
197
- "model.layers.33.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
198
- "model.layers.33.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
199
- "model.layers.33.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
200
- "model.layers.33.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
201
- "model.layers.33.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
202
- "model.layers.34.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
203
- "model.layers.34.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
204
- "model.layers.34.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
205
- "model.layers.34.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
206
- "model.layers.34.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
207
- "model.layers.34.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
208
- "model.layers.34.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
209
- "model.layers.35.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
210
- "model.layers.35.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
211
- "model.layers.35.attention_norm.weight": "pytorch_model-00003-of-00004.bin",
212
- "model.layers.35.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
213
- "model.layers.35.feed_forward.w2.weight": "pytorch_model-00003-of-00004.bin",
214
- "model.layers.35.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
215
- "model.layers.35.ffn_norm.weight": "pytorch_model-00003-of-00004.bin",
216
- "model.layers.36.attention.wo.weight": "pytorch_model-00003-of-00004.bin",
217
- "model.layers.36.attention.wqkv.weight": "pytorch_model-00003-of-00004.bin",
218
- "model.layers.36.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
219
- "model.layers.36.feed_forward.w1.weight": "pytorch_model-00003-of-00004.bin",
220
- "model.layers.36.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
221
- "model.layers.36.feed_forward.w3.weight": "pytorch_model-00003-of-00004.bin",
222
- "model.layers.36.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
223
- "model.layers.37.attention.wo.weight": "pytorch_model-00004-of-00004.bin",
224
- "model.layers.37.attention.wqkv.weight": "pytorch_model-00004-of-00004.bin",
225
- "model.layers.37.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
226
- "model.layers.37.feed_forward.w1.weight": "pytorch_model-00004-of-00004.bin",
227
- "model.layers.37.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
228
- "model.layers.37.feed_forward.w3.weight": "pytorch_model-00004-of-00004.bin",
229
- "model.layers.37.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
230
- "model.layers.38.attention.wo.weight": "pytorch_model-00004-of-00004.bin",
231
- "model.layers.38.attention.wqkv.weight": "pytorch_model-00004-of-00004.bin",
232
- "model.layers.38.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
233
- "model.layers.38.feed_forward.w1.weight": "pytorch_model-00004-of-00004.bin",
234
- "model.layers.38.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
235
- "model.layers.38.feed_forward.w3.weight": "pytorch_model-00004-of-00004.bin",
236
- "model.layers.38.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
237
- "model.layers.39.attention.wo.weight": "pytorch_model-00004-of-00004.bin",
238
- "model.layers.39.attention.wqkv.weight": "pytorch_model-00004-of-00004.bin",
239
- "model.layers.39.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
240
- "model.layers.39.feed_forward.w1.weight": "pytorch_model-00004-of-00004.bin",
241
- "model.layers.39.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
242
- "model.layers.39.feed_forward.w3.weight": "pytorch_model-00004-of-00004.bin",
243
- "model.layers.39.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
244
- "model.layers.4.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
245
- "model.layers.4.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
246
- "model.layers.4.attention_norm.weight": "pytorch_model-00001-of-00004.bin",
247
- "model.layers.4.feed_forward.w1.weight": "pytorch_model-00001-of-00004.bin",
248
- "model.layers.4.feed_forward.w2.weight": "pytorch_model-00001-of-00004.bin",
249
- "model.layers.4.feed_forward.w3.weight": "pytorch_model-00001-of-00004.bin",
250
- "model.layers.4.ffn_norm.weight": "pytorch_model-00001-of-00004.bin",
251
- "model.layers.40.attention.wo.weight": "pytorch_model-00004-of-00004.bin",
252
- "model.layers.40.attention.wqkv.weight": "pytorch_model-00004-of-00004.bin",
253
- "model.layers.40.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
254
- "model.layers.40.feed_forward.w1.weight": "pytorch_model-00004-of-00004.bin",
255
- "model.layers.40.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
256
- "model.layers.40.feed_forward.w3.weight": "pytorch_model-00004-of-00004.bin",
257
- "model.layers.40.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
258
- "model.layers.41.attention.wo.weight": "pytorch_model-00004-of-00004.bin",
259
- "model.layers.41.attention.wqkv.weight": "pytorch_model-00004-of-00004.bin",
260
- "model.layers.41.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
261
- "model.layers.41.feed_forward.w1.weight": "pytorch_model-00004-of-00004.bin",
262
- "model.layers.41.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
263
- "model.layers.41.feed_forward.w3.weight": "pytorch_model-00004-of-00004.bin",
264
- "model.layers.41.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
265
- "model.layers.42.attention.wo.weight": "pytorch_model-00004-of-00004.bin",
266
- "model.layers.42.attention.wqkv.weight": "pytorch_model-00004-of-00004.bin",
267
- "model.layers.42.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
268
- "model.layers.42.feed_forward.w1.weight": "pytorch_model-00004-of-00004.bin",
269
- "model.layers.42.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
270
- "model.layers.42.feed_forward.w3.weight": "pytorch_model-00004-of-00004.bin",
271
- "model.layers.42.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
272
- "model.layers.43.attention.wo.weight": "pytorch_model-00004-of-00004.bin",
273
- "model.layers.43.attention.wqkv.weight": "pytorch_model-00004-of-00004.bin",
274
- "model.layers.43.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
275
- "model.layers.43.feed_forward.w1.weight": "pytorch_model-00004-of-00004.bin",
276
- "model.layers.43.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
277
- "model.layers.43.feed_forward.w3.weight": "pytorch_model-00004-of-00004.bin",
278
- "model.layers.43.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
279
- "model.layers.44.attention.wo.weight": "pytorch_model-00004-of-00004.bin",
280
- "model.layers.44.attention.wqkv.weight": "pytorch_model-00004-of-00004.bin",
281
- "model.layers.44.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
282
- "model.layers.44.feed_forward.w1.weight": "pytorch_model-00004-of-00004.bin",
283
- "model.layers.44.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
284
- "model.layers.44.feed_forward.w3.weight": "pytorch_model-00004-of-00004.bin",
285
- "model.layers.44.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
286
- "model.layers.45.attention.wo.weight": "pytorch_model-00004-of-00004.bin",
287
- "model.layers.45.attention.wqkv.weight": "pytorch_model-00004-of-00004.bin",
288
- "model.layers.45.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
289
- "model.layers.45.feed_forward.w1.weight": "pytorch_model-00004-of-00004.bin",
290
- "model.layers.45.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
291
- "model.layers.45.feed_forward.w3.weight": "pytorch_model-00004-of-00004.bin",
292
- "model.layers.45.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
293
- "model.layers.46.attention.wo.weight": "pytorch_model-00004-of-00004.bin",
294
- "model.layers.46.attention.wqkv.weight": "pytorch_model-00004-of-00004.bin",
295
- "model.layers.46.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
296
- "model.layers.46.feed_forward.w1.weight": "pytorch_model-00004-of-00004.bin",
297
- "model.layers.46.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
298
- "model.layers.46.feed_forward.w3.weight": "pytorch_model-00004-of-00004.bin",
299
- "model.layers.46.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
300
- "model.layers.47.attention.wo.weight": "pytorch_model-00004-of-00004.bin",
301
- "model.layers.47.attention.wqkv.weight": "pytorch_model-00004-of-00004.bin",
302
- "model.layers.47.attention_norm.weight": "pytorch_model-00004-of-00004.bin",
303
- "model.layers.47.feed_forward.w1.weight": "pytorch_model-00004-of-00004.bin",
304
- "model.layers.47.feed_forward.w2.weight": "pytorch_model-00004-of-00004.bin",
305
- "model.layers.47.feed_forward.w3.weight": "pytorch_model-00004-of-00004.bin",
306
- "model.layers.47.ffn_norm.weight": "pytorch_model-00004-of-00004.bin",
307
- "model.layers.5.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
308
- "model.layers.5.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
309
- "model.layers.5.attention_norm.weight": "pytorch_model-00001-of-00004.bin",
310
- "model.layers.5.feed_forward.w1.weight": "pytorch_model-00001-of-00004.bin",
311
- "model.layers.5.feed_forward.w2.weight": "pytorch_model-00001-of-00004.bin",
312
- "model.layers.5.feed_forward.w3.weight": "pytorch_model-00001-of-00004.bin",
313
- "model.layers.5.ffn_norm.weight": "pytorch_model-00001-of-00004.bin",
314
- "model.layers.6.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
315
- "model.layers.6.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
316
- "model.layers.6.attention_norm.weight": "pytorch_model-00001-of-00004.bin",
317
- "model.layers.6.feed_forward.w1.weight": "pytorch_model-00001-of-00004.bin",
318
- "model.layers.6.feed_forward.w2.weight": "pytorch_model-00001-of-00004.bin",
319
- "model.layers.6.feed_forward.w3.weight": "pytorch_model-00001-of-00004.bin",
320
- "model.layers.6.ffn_norm.weight": "pytorch_model-00001-of-00004.bin",
321
- "model.layers.7.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
322
- "model.layers.7.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
323
- "model.layers.7.attention_norm.weight": "pytorch_model-00001-of-00004.bin",
324
- "model.layers.7.feed_forward.w1.weight": "pytorch_model-00001-of-00004.bin",
325
- "model.layers.7.feed_forward.w2.weight": "pytorch_model-00001-of-00004.bin",
326
- "model.layers.7.feed_forward.w3.weight": "pytorch_model-00001-of-00004.bin",
327
- "model.layers.7.ffn_norm.weight": "pytorch_model-00001-of-00004.bin",
328
- "model.layers.8.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
329
- "model.layers.8.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
330
- "model.layers.8.attention_norm.weight": "pytorch_model-00001-of-00004.bin",
331
- "model.layers.8.feed_forward.w1.weight": "pytorch_model-00001-of-00004.bin",
332
- "model.layers.8.feed_forward.w2.weight": "pytorch_model-00001-of-00004.bin",
333
- "model.layers.8.feed_forward.w3.weight": "pytorch_model-00001-of-00004.bin",
334
- "model.layers.8.ffn_norm.weight": "pytorch_model-00001-of-00004.bin",
335
- "model.layers.9.attention.wo.weight": "pytorch_model-00001-of-00004.bin",
336
- "model.layers.9.attention.wqkv.weight": "pytorch_model-00001-of-00004.bin",
337
- "model.layers.9.attention_norm.weight": "pytorch_model-00001-of-00004.bin",
338
- "model.layers.9.feed_forward.w1.weight": "pytorch_model-00001-of-00004.bin",
339
- "model.layers.9.feed_forward.w2.weight": "pytorch_model-00001-of-00004.bin",
340
- "model.layers.9.feed_forward.w3.weight": "pytorch_model-00001-of-00004.bin",
341
- "model.layers.9.ffn_norm.weight": "pytorch_model-00001-of-00004.bin",
342
- "model.norm.weight": "pytorch_model-00004-of-00004.bin",
343
- "model.tok_embeddings.weight": "pytorch_model-00001-of-00004.bin",
344
- "output.weight": "pytorch_model-00004-of-00004.bin"
345
- }
346
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
special_tokens_map.json CHANGED
@@ -1,6 +1,38 @@
1
  {
2
- "bos_token": "<s>",
3
- "eos_token": "</s>",
4
- "pad_token": "</s>",
5
- "unk_token": "<unk>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  }
 
1
  {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|action_start|>",
6
+ "<|action_end|>",
7
+ "<|interpreter|>",
8
+ "<|plugin|>"
9
+ ],
10
+ "bos_token": {
11
+ "content": "<s>",
12
+ "lstrip": false,
13
+ "normalized": false,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ },
17
+ "eos_token": {
18
+ "content": "</s>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "</s>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ "unk_token": {
32
+ "content": "<unk>",
33
+ "lstrip": false,
34
+ "normalized": false,
35
+ "rstrip": false,
36
+ "single_word": false
37
+ }
38
  }
tokenization_internlm.py → tokenization_internlm2.py RENAMED
@@ -1,10 +1,7 @@
1
  # coding=utf-8
2
- # Copyright (c) InternLM. All rights reserved.
3
  #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
  #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
10
  # you may not use this file except in compliance with the License.
@@ -18,7 +15,7 @@
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
 
21
- """Tokenization classes for IntermLM."""
22
  import os
23
  from shutil import copyfile
24
  from typing import Any, Dict, List, Optional, Tuple
@@ -34,9 +31,10 @@ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
34
  PRETRAINED_VOCAB_FILES_MAP = {}
35
 
36
 
37
- class InternLMTokenizer(PreTrainedTokenizer):
 
38
  """
39
- Construct a InternLM tokenizer. Based on byte-level Byte-Pair-Encoding.
40
 
41
  Args:
42
  vocab_file (`str`):
@@ -79,8 +77,6 @@ class InternLMTokenizer(PreTrainedTokenizer):
79
  **kwargs,
80
  )
81
 
82
- """ Initialization"""
83
-
84
  @property
85
  def no_prefix_space_tokens(self):
86
  if self._no_prefix_space_tokens is None:
 
1
  # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
  #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
 
 
 
5
  #
6
  # Licensed under the Apache License, Version 2.0 (the "License");
7
  # you may not use this file except in compliance with the License.
 
15
  # See the License for the specific language governing permissions and
16
  # limitations under the License.
17
 
18
+ """Tokenization classes for InternLM."""
19
  import os
20
  from shutil import copyfile
21
  from typing import Any, Dict, List, Optional, Tuple
 
31
  PRETRAINED_VOCAB_FILES_MAP = {}
32
 
33
 
34
+ # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
35
+ class InternLM2Tokenizer(PreTrainedTokenizer):
36
  """
37
+ Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
38
 
39
  Args:
40
  vocab_file (`str`):
 
77
  **kwargs,
78
  )
79
 
 
 
80
  @property
81
  def no_prefix_space_tokens(self):
82
  if self._no_prefix_space_tokens is None:
tokenization_internlm2_fast.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization Fast class for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, Optional, Tuple
22
+
23
+ from tokenizers import processors, decoders, Tokenizer, normalizers
24
+ from tokenizers.models import BPE
25
+
26
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
27
+ from transformers.utils import logging
28
+
29
+ from transformers.convert_slow_tokenizer import (
30
+ SLOW_TO_FAST_CONVERTERS,
31
+ SpmConverter,
32
+ SentencePieceExtractor,
33
+ )
34
+
35
+ from .tokenization_internlm2 import InternLM2Tokenizer
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
40
+
41
+ # Modified from transformers.convert_slow_tokenizer.LlamaConverter
42
+ class InternLM2Converter(SpmConverter):
43
+ handle_byte_fallback = True
44
+
45
+ def vocab(self, proto):
46
+ vocab = [
47
+ ("<unk>", 0.0),
48
+ ("<s>", 0.0),
49
+ ("</s>", 0.0),
50
+ ]
51
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
52
+ return vocab
53
+
54
+ def unk_id(self, proto):
55
+ unk_id = 0
56
+ return unk_id
57
+
58
+ def decoder(self, replacement, add_prefix_space):
59
+ decoders_sequence = [
60
+ decoders.Replace("▁", " "),
61
+ decoders.ByteFallback(),
62
+ decoders.Fuse(),
63
+ ]
64
+ if self.proto.normalizer_spec.add_dummy_prefix:
65
+ decoders_sequence.append(decoders.Strip(content=" ", left=1))
66
+ return decoders.Sequence(decoders_sequence)
67
+
68
+ def tokenizer(self, proto):
69
+ model_type = proto.trainer_spec.model_type
70
+ vocab_scores = self.vocab(proto)
71
+ # special tokens
72
+ added_tokens = self.original_tokenizer.added_tokens_decoder
73
+ for i in range(len(vocab_scores)):
74
+ piece, score = vocab_scores[i]
75
+ if i in added_tokens:
76
+ vocab_scores[i] = (added_tokens[i].content, score)
77
+ if model_type == 1:
78
+ raise RuntimeError("InternLM2 is supposed to be a BPE model!")
79
+
80
+ elif model_type == 2:
81
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
82
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
83
+ tokenizer = Tokenizer(
84
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
85
+ )
86
+ tokenizer.add_special_tokens(
87
+ [ added_token for index, added_token in added_tokens.items()]
88
+ )
89
+ else:
90
+ raise Exception(
91
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
92
+ )
93
+
94
+ return tokenizer
95
+
96
+ def normalizer(self, proto):
97
+ normalizers_list = []
98
+ if proto.normalizer_spec.add_dummy_prefix:
99
+ normalizers_list.append(normalizers.Prepend(prepend="▁"))
100
+ normalizers_list.append(normalizers.Replace(pattern=" ", content="▁"))
101
+ return normalizers.Sequence(normalizers_list)
102
+
103
+ def pre_tokenizer(self, replacement, add_prefix_space):
104
+ return None
105
+
106
+ SLOW_TO_FAST_CONVERTERS["InternLM2Tokenizer"] = InternLM2Converter
107
+
108
+
109
+ # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
110
+ class InternLM2TokenizerFast(PreTrainedTokenizerFast):
111
+ vocab_files_names = VOCAB_FILES_NAMES
112
+ slow_tokenizer_class = InternLM2Tokenizer
113
+ padding_side = "left"
114
+ model_input_names = ["input_ids", "attention_mask"]
115
+ _auto_class = "AutoTokenizer"
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_file,
120
+ unk_token="<unk>",
121
+ bos_token="<s>",
122
+ eos_token="</s>",
123
+ pad_token="</s>",
124
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
125
+ add_bos_token=True,
126
+ add_eos_token=False,
127
+ decode_with_prefix_space=False,
128
+ clean_up_tokenization_spaces=False,
129
+ **kwargs,
130
+ ):
131
+ super().__init__(
132
+ vocab_file=vocab_file,
133
+ unk_token=unk_token,
134
+ bos_token=bos_token,
135
+ eos_token=eos_token,
136
+ pad_token=pad_token,
137
+ sp_model_kwargs=sp_model_kwargs,
138
+ add_bos_token=add_bos_token,
139
+ add_eos_token=add_eos_token,
140
+ decode_with_prefix_space=decode_with_prefix_space,
141
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
142
+ **kwargs,
143
+ )
144
+ self._add_bos_token = add_bos_token
145
+ self._add_eos_token = add_eos_token
146
+ self.update_post_processor()
147
+ self.vocab_file = vocab_file
148
+
149
+ @property
150
+ def can_save_slow_tokenizer(self) -> bool:
151
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
152
+
153
+ def update_post_processor(self):
154
+ """
155
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
156
+ """
157
+ bos = self.bos_token
158
+ bos_token_id = self.bos_token_id
159
+ if bos is None and self.add_bos_token:
160
+ raise ValueError("add_bos_token = True but bos_token = None")
161
+
162
+ eos = self.eos_token
163
+ eos_token_id = self.eos_token_id
164
+ if eos is None and self.add_eos_token:
165
+ raise ValueError("add_eos_token = True but eos_token = None")
166
+
167
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
168
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
169
+
170
+ special_tokens = []
171
+ if self.add_bos_token:
172
+ special_tokens.append((bos, bos_token_id))
173
+ if self.add_eos_token:
174
+ special_tokens.append((eos, eos_token_id))
175
+ self._tokenizer.post_processor = processors.TemplateProcessing(
176
+ single=single, pair=pair, special_tokens=special_tokens
177
+ )
178
+
179
+ @property
180
+ def add_eos_token(self):
181
+ return self._add_eos_token
182
+
183
+ @property
184
+ def add_bos_token(self):
185
+ return self._add_bos_token
186
+
187
+ @add_eos_token.setter
188
+ def add_eos_token(self, value):
189
+ self._add_eos_token = value
190
+ self.update_post_processor()
191
+
192
+ @add_bos_token.setter
193
+ def add_bos_token(self, value):
194
+ self._add_bos_token = value
195
+ self.update_post_processor()
196
+
197
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
198
+ if not self.can_save_slow_tokenizer:
199
+ raise ValueError(
200
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
201
+ "tokenizer."
202
+ )
203
+
204
+ if not os.path.isdir(save_directory):
205
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
206
+ return
207
+ out_vocab_file = os.path.join(
208
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
209
+ )
210
+
211
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
212
+ copyfile(self.vocab_file, out_vocab_file)
213
+
214
+ return (out_vocab_file,)
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c1f649fe0ac36a053b6beca3ac1c0a170ec1f2b4d99acda1e4ffc78715a7bf1
3
+ size 5753097
tokenizer_config.json CHANGED
@@ -1,15 +1,102 @@
1
  {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  "auto_map": {
3
  "AutoTokenizer": [
4
- "tokenization_internlm.InternLMTokenizer",
5
- null
6
  ]
7
  },
8
  "bos_token": "<s>",
 
9
  "clean_up_tokenization_spaces": false,
 
10
  "eos_token": "</s>",
11
  "model_max_length": 1000000000000000019884624838656,
12
  "pad_token": "</s>",
13
- "tokenizer_class": "InternLMTokenizer",
 
14
  "unk_token": "<unk>"
15
  }
 
1
  {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "92538": {
30
+ "content": "<|plugin|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "92539": {
38
+ "content": "<|interpreter|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "92540": {
46
+ "content": "<|action_end|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "92541": {
54
+ "content": "<|action_start|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "92542": {
62
+ "content": "<|im_end|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "92543": {
70
+ "content": "<|im_start|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ }
77
+ },
78
+ "additional_special_tokens": [
79
+ "<|im_start|>",
80
+ "<|im_end|>",
81
+ "<|action_start|>",
82
+ "<|action_end|>",
83
+ "<|interpreter|>",
84
+ "<|plugin|>"
85
+ ],
86
  "auto_map": {
87
  "AutoTokenizer": [
88
+ "tokenization_internlm2.InternLM2Tokenizer",
89
+ "tokenization_internlm2_fast.InternLM2TokenizerFast"
90
  ]
91
  },
92
  "bos_token": "<s>",
93
+ "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
94
  "clean_up_tokenization_spaces": false,
95
+ "decode_with_prefix_space": false,
96
  "eos_token": "</s>",
97
  "model_max_length": 1000000000000000019884624838656,
98
  "pad_token": "</s>",
99
+ "sp_model_kwargs": null,
100
+ "tokenizer_class": "InternLM2Tokenizer",
101
  "unk_token": "<unk>"
102
  }