add source files
- configuration_cpmbee.py +132 -0
- modeling_cpmbee.py +1944 -0
- test_modeling_cpmbee.py +183 -0
- test_tokenization_cpmbee.py +187 -0
- tokenization_cpmbee.py +868 -0
configuration_cpmbee.py
ADDED
@@ -0,0 +1,132 @@
# coding=utf-8
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" CpmBee model configuration"""

from typing import List, Optional, Tuple, Union

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

CPMBEE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "openbmb/cpm-bee-10b": "https://huggingface.co/openbmb/cpm-bee-10b/resolve/main/config.json",
    "openbmb/cpm-bee-5b": "https://huggingface.co/openbmb/cpm-bee-5b/resolve/main/config.json",
    "openbmb/cpm-bee-2b": "https://huggingface.co/openbmb/cpm-bee-2b/resolve/main/config.json",
    "openbmb/cpm-bee-1b": "https://huggingface.co/openbmb/cpm-bee-1b/resolve/main/config.json",
    # See all CpmBee models at https://huggingface.co/models?filter=cpmbee
}


class CpmBeeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`CpmBeeModel`]. It is used to instantiate a
    CPMBee model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the CPMBee
    [openbmb/cpm-bee-10b](https://huggingface.co/openbmb/cpm-bee-10b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 30720):
            Vocabulary size of the CPMBee model. Defines the number of different tokens that can be represented by the
            `input` passed when calling [`CpmBeeModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the encoder layers.
        num_attention_heads (`int`, *optional*, defaults to 64):
            Number of attention heads in the Transformer encoder.
        dim_head (`int`, *optional*, defaults to 64):
            Dimension of attention heads for each attention layer in the Transformer encoder.
        dim_ff (`int`, *optional*, defaults to 10240):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of layers of the Transformer encoder.
        dropout_p (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings and encoder.
        position_bias_num_buckets (`int`, *optional*, defaults to 256):
            The number of position_bias buckets.
        position_bias_num_segment_buckets (`int`, *optional*, defaults to 32):
            The number of segment buckets.
        position_bias_max_distance (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        init_std (`float`, *optional*, defaults to 1.0):
            Initialize parameters with std = init_std.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether to use cache.
        distance_scale (`float` or `int`, *optional*, defaults to 16):
            Scale the rotary embedding.
        mask_modules (`list` or `tuple`, *optional*, defaults to `None`):
            Decides which feedforward block or attention block is pruned.
        half (`bool`, *optional*, defaults to `False`):
            Decides whether the model parameters are half-precision or not.

    Example:

    ```python
    >>> from transformers import CpmBeeModel, CpmBeeConfig

    >>> # Initializing a CPMBee cpm-bee-10b style configuration
    >>> configuration = CpmBeeConfig()

    >>> # Initializing a model from the cpm-bee-10b style configuration
    >>> model = CpmBeeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "cpmbee"

    def __init__(
        self,
        vocab_size: int = 30720,
        hidden_size: int = 4096,
        num_attention_heads: int = 64,
        dim_head: int = 64,
        dim_ff: int = 10240,
        num_hidden_layers: int = 32,
        dropout_p: float = 0.0,
        position_bias_num_buckets: int = 256,
        position_bias_num_segment_buckets: int = 32,
        position_bias_max_distance: int = 2048,
        eps: float = 1e-6,
        init_std: float = 1.0,
        use_cache: bool = True,
        distance_scale: Union[int, float] = 16,
        mask_modules: Optional[Union[List, Tuple]] = None,
        half: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.position_bias_num_segment_buckets = position_bias_num_segment_buckets
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.dim_head = dim_head
        self.dim_ff = dim_ff
        self.num_hidden_layers = num_hidden_layers
        self.position_bias_num_buckets = position_bias_num_buckets
        self.position_bias_max_distance = position_bias_max_distance
        self.dropout_p = dropout_p
        self.eps = eps
        self.use_cache = use_cache
        self.vocab_size = vocab_size
        self.init_std = init_std
        self.distance_scale = distance_scale
        self.half = half
        self.mask_modules = mask_modules
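As a quick orientation before the modeling code below, here is a minimal sketch of how this configuration can be exercised. The tiny hyperparameter values are hypothetical, chosen only for a fast smoke test, and the import assumes `CpmBeeConfig`/`CpmBeeModel` are registered in `transformers` once these files land.

```python
from transformers import CpmBeeConfig, CpmBeeModel

# Hypothetical, deliberately tiny configuration for a smoke test (not a released checkpoint).
config = CpmBeeConfig(
    vocab_size=1024,
    hidden_size=128,
    num_attention_heads=4,
    dim_head=32,
    dim_ff=256,
    num_hidden_layers=2,
)

model = CpmBeeModel(config)
print(sum(p.numel() for p in model.parameters()))  # rough parameter count of the toy model
```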
modeling_cpmbee.py
ADDED
@@ -0,0 +1,1944 @@
# coding=utf-8
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch CpmBee model."""
import copy
import math
from collections import UserDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...generation.beam_search import BeamHypotheses, BeamSearchScorer
from ...generation.streamers import BaseStreamer
from ...generation.utils import (
    GenerationConfig,
    LogitsProcessorList,
    StoppingCriteriaList,
    dist,
    inspect,
    is_deepspeed_zero3_enabled,
    warnings,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_cpmbee import CpmBeeConfig
from .tokenization_cpmbee import CpmBeeTokenizer


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "openbmb/cpm-bee-10b"
_CONFIG_FOR_DOC = "CpmBeeConfig"

CPMBEE_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "openbmb/cpm-bee-10b",
    "openbmb/cpm-bee-5b",
    "openbmb/cpm-bee-2b",
    "openbmb/cpm-bee-1b",
    # See all CPMBee models at https://huggingface.co/models?filter=cpmbee
]


class CpmBeeLinear(nn.Linear):
    def __init__(self, dim_in, dim_out, dtype):
        """
        Construct a scaled linear layer for CPMBee: the output of the linear transform is divided by sqrt(dim_in).
        """
        super().__init__(dim_in, dim_out, bias=False)
        self.dim_in = self.in_features = dim_in
        self.dim_out = self.out_features = dim_out

        self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (`torch.Tensor` of shape `(batch, seq_len, dim_in)`): The input of the linear layer.
        Returns:
            `torch.Tensor` of shape `(batch, seq_len, dim_out)`: The output of the linear transform y.
        """
        x = nn.functional.linear(x, self.weight)
        x = x / math.sqrt(self.dim_in)
        return x
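As a side note, the scaling in `CpmBeeLinear` can be checked on its own; a minimal standalone sketch (not part of modeling_cpmbee.py):

```python
import math

import torch

# The layer computes y = x @ W.T / sqrt(dim_in); reproduce that with plain tensor ops.
dim_in, dim_out = 8, 4
weight = torch.randn(dim_out, dim_in)
x = torch.randn(2, 3, dim_in)

y = torch.nn.functional.linear(x, weight) / math.sqrt(dim_in)
assert y.shape == (2, 3, dim_out)
```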
class CpmBeeLayerNorm(nn.Module):
    """
    We use Root Mean Square (RMS) Layer Normalization, please see https://arxiv.org/abs/1910.07467 for details.
    """

    def __init__(self, config: CpmBeeConfig):
        super().__init__()

        self.eps = config.eps
        self.dim_norm = config.hidden_size
        self.weight = nn.Parameter(torch.empty(config.hidden_size, dtype=config.torch_dtype))

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
        """
        if hidden_states.size(-1) != self.dim_norm:
            raise AssertionError("hidden_states.size(-1) != self.dim_norm")
        old_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        hidden_states = (hidden_states * torch.rsqrt(variance + self.eps)).to(old_dtype) * self.weight
        return hidden_states
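For intuition, `CpmBeeLayerNorm` is an RMSNorm: it divides each feature vector by the root mean square of its entries (no mean subtraction, no bias) and then applies a learned per-channel scale. A minimal standalone sketch of the same arithmetic (not part of the file):

```python
import torch


def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same computation as CpmBeeLayerNorm.forward: x / sqrt(mean(x**2) + eps) * weight,
    # with the statistics computed in float32 for numerical stability.
    old_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    return (hidden_states * torch.rsqrt(variance + eps)).to(old_dtype) * weight


x = torch.randn(2, 5, 16)
out = rms_norm(x, torch.ones(16))
assert out.shape == x.shape
```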
class CpmBeeAttention(nn.Module):
    def __init__(self, config: CpmBeeConfig):
        super().__init__()
        self.dim_model = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dim_head = config.dim_head

        self.project_q = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)
        self.project_k = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)
        self.project_v = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)

        self.attention_out = CpmBeeLinear(self.num_heads * self.dim_head, self.dim_model, dtype=config.torch_dtype)

        self.softmax = torch.nn.Softmax(dim=-1)

        if config.dropout_p is not None:
            self.dropout = torch.nn.Dropout(p=config.dropout_p)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_q: torch.Tensor,
        hidden_kv: torch.Tensor,
        attention_mask: torch.BoolTensor,
        position_bias: torch.Tensor,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_q (`torch.Tensor`):
                Input of the transformer block (self-attention block). It can be the raw embedding of a batch of
                sequences.
            hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`):
                Tensor from which the *key* and *value* projections are computed, of shape `(batch, len_k, dim_model)`.
            attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Avoid invalid areas to participate in the calculation of self-attention.
            position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Provide positional information to self-attention block.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Cached past key and value projection states.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        """
        batch_size = hidden_q.size(0)
        len_q = hidden_q.size(1)
        len_k = hidden_kv.size(1)

        query = self.project_q(hidden_q)
        key = self.project_k(hidden_kv)
        value = self.project_v(hidden_kv)

        query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
        key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
        value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)

        if past_key_values is not None:
            key = torch.cat([past_key_values[0], key], dim=-2)
            value = torch.cat([past_key_values[1], value], dim=-2)
            len_k = key.size(-2)

        # (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k)
        score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
        score = score + position_bias

        score = torch.masked_fill(
            score,
            attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
            torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
        )
        score = self.softmax(score)

        score = torch.masked_fill(
            score,
            attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
            torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
        )
        if output_attentions:
            attn_weights = score
        else:
            attn_weights = None

        if self.dropout is not None:
            score = self.dropout(score)

        # (batch_size, num_heads, len_q, len_k) @ (batch_size, num_heads, len_k, dim_head) -> (batch_size, num_heads, len_q, dim_head)
        score = torch.matmul(score, value)

        score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
        score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)

        score = self.attention_out(score)

        past_key_values = None
        if use_cache:
            past_key_values = (key, value)

        return score, attn_weights, past_key_values
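Note that the attention above applies the boolean mask twice: once with `-inf` before the softmax and once with `0` after it, which keeps masked positions at exactly zero weight (and, among other things, prevents NaNs from fully masked query rows leaking downstream). A small standalone sketch of that pattern (toy shapes, not part of the file):

```python
import torch

# Toy scores: (batch=1, heads=1, len_q=2, len_k=2); the second key is masked everywhere.
score = torch.randn(1, 1, 2, 2)
mask = torch.tensor([[[True, False], [True, False]]])  # (batch, len_q, len_k)

score = score.masked_fill(mask.view(1, 1, 2, 2) == torch.tensor(False), float("-inf"))
probs = torch.softmax(score, dim=-1)
probs = probs.masked_fill(mask.view(1, 1, 2, 2) == torch.tensor(False), 0.0)
print(probs)  # masked key positions carry exactly zero attention weight
```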
class CpmBeeSelfAttentionBlock(nn.Module):
    def __init__(self, config: CpmBeeConfig):
        super().__init__()
        self.layernorm_before_attention = CpmBeeLayerNorm(config)
        self.self_attention = CpmBeeAttention(config)
        if config.dropout_p:
            self.dropout = torch.nn.Dropout(config.dropout_p)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Input of the transformer block (self-attention block). It can be the raw embedding of a batch of
                sequences.
            attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Avoid invalid areas to participate in the calculation of self-attention.
            position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Provide positional information to self-attention block.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            past_key_values (`Tuple(torch.FloatTensor)`, *optional*):
                Cached past key and value projection states.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        """
        outputs = self.layernorm_before_attention(hidden_states)
        outputs = self.self_attention(
            outputs, outputs, attention_mask, position_bias, output_attentions, past_key_values, use_cache
        )

        outputs, attn_weights, current_key_value = outputs

        if self.dropout is not None:
            outputs = self.dropout(outputs)
        hidden_states = (hidden_states + outputs) / 1.05

        return hidden_states, attn_weights, current_key_value


class CpmBeeDenseGatedACT(nn.Module):
    def __init__(self, config: CpmBeeConfig):
        super().__init__()
        self.w_0 = CpmBeeLinear(config.hidden_size, config.dim_ff, dtype=config.torch_dtype)
        self.w_1 = CpmBeeLinear(config.hidden_size, config.dim_ff, dtype=config.torch_dtype)
        self.act = torch.nn.GELU()

    def forward(self, hidden_states: torch.Tensor):
        """Transform an input tensor from one feature space to another via a nonlinear operation

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
        """
        gate_score = self.act(self.w_0(hidden_states))
        hidden_states = self.w_1(hidden_states)

        hidden_states = gate_score * hidden_states
        return hidden_states


class CpmBeeFeedForward(nn.Module):
    def __init__(self, config: CpmBeeConfig):
        super().__init__()
        self.w_in = CpmBeeDenseGatedACT(config)
        if config.dropout_p is not None:
            self.dropout = torch.nn.Dropout(config.dropout_p)
        else:
            self.dropout = None

        self.w_out = CpmBeeLinear(config.dim_ff, config.hidden_size, dtype=config.torch_dtype)

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
        """
        hidden_states = self.w_in(hidden_states)

        if self.dropout is not None:
            hidden_states = self.dropout(hidden_states)

        hidden_states = self.w_out(hidden_states)

        return hidden_states
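Taken together, `CpmBeeDenseGatedACT` and `CpmBeeFeedForward` form a gated-GELU feed-forward: `w_out(GELU(w_0(x)) * w_1(x))`, where each `w_*` is the scaled `CpmBeeLinear` defined above. A standalone sketch of the same dataflow with plain `nn.Linear` layers (the 1/sqrt(dim_in) scaling is omitted here for brevity):

```python
import torch
import torch.nn as nn

hidden_size, dim_ff = 16, 64
w_0 = nn.Linear(hidden_size, dim_ff, bias=False)
w_1 = nn.Linear(hidden_size, dim_ff, bias=False)
w_out = nn.Linear(dim_ff, hidden_size, bias=False)
act = nn.GELU()

x = torch.randn(2, 5, hidden_size)
# Gated feed-forward: the GELU branch gates the linear branch elementwise, then projects back down.
y = w_out(act(w_0(x)) * w_1(x))
assert y.shape == x.shape
```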
class CpmBeeFFNBlock(nn.Module):
    def __init__(self, config: CpmBeeConfig):
        super().__init__()
        self.layernorm_before_ffn = CpmBeeLayerNorm(config)
        self.ffn = CpmBeeFeedForward(config)
        if config.dropout_p:
            self.dropout = torch.nn.Dropout(config.dropout_p)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Hidden states before feed forward layer.
        """
        ln_outputs = self.layernorm_before_ffn(hidden_states)
        outputs = self.ffn(ln_outputs)
        if self.dropout is not None:
            outputs = self.dropout(outputs)
        hidden_states = (hidden_states + outputs) / 1.05
        return hidden_states


class CpmBeeTransformerBlock(nn.Module):
    def __init__(self, config: CpmBeeConfig, mask_att: bool = False, mask_ffn: bool = False):
        super().__init__()
        self.mask_att = mask_att
        self.mask_ffn = mask_ffn

        if not self.mask_att:
            self.self_att = CpmBeeSelfAttentionBlock(config)
        if not self.mask_ffn:
            self.ffn = CpmBeeFFNBlock(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
            position_bias (`torch.Tensor`):
                Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Cached past key and value projection states
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        """
        if not self.mask_att:
            hidden_states = self.self_att(
                hidden_states,
                attention_mask=attention_mask,
                position_bias=position_bias,
                output_attentions=output_attentions,
                past_key_values=past_key_values,
                use_cache=use_cache,
            )

            hidden_states, attn_weights, current_key_value = hidden_states
        else:
            attn_weights, current_key_value = None, (None, None)

        if not self.mask_ffn:
            hidden_states = self.ffn(hidden_states)

        return hidden_states, attn_weights, current_key_value


class CpmBeeEncoder(nn.Module):
    def __init__(self, config: CpmBeeConfig):
        super().__init__()
        self.num_layers = config.num_hidden_layers
        if config.mask_modules is not None:
            assert len(config.mask_modules) == self.num_layers, "The total number of masks should equal to num_layers"
            for mask_module in config.mask_modules:
                assert len(mask_module) == 2, "For encoder, each mask should be (mask_att, mask_ffn)"
        else:
            config.mask_modules = [(False, False)] * self.num_layers

        self.layers = nn.ModuleList(
            [
                CpmBeeTransformerBlock(
                    config, mask_att=config.mask_modules[ith][0], mask_ffn=config.mask_modules[ith][1]
                )
                for ith in range(self.num_layers)
            ]
        )

        self.output_layernorm = CpmBeeLayerNorm(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
            position_bias (`torch.Tensor`):
                Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Cached past key and value projection states
            use_cache (`bool`, *optional*):