jeffreygo committed on
Commit
83e21b6
1 Parent(s): b197fa6

add source files

configuration_cpmbee.py ADDED
@@ -0,0 +1,132 @@
+ # coding=utf-8
+ # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ CpmBee model configuration"""
+
+ from typing import List, Optional, Tuple, Union
+
+ from ...configuration_utils import PretrainedConfig
+ from ...utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+ CPMBEE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "openbmb/cpm-bee-10b": "https://huggingface.co/openbmb/cpm-bee-10b/resolve/main/config.json",
+     "openbmb/cpm-bee-5b": "https://huggingface.co/openbmb/cpm-bee-5b/resolve/main/config.json",
+     "openbmb/cpm-bee-2b": "https://huggingface.co/openbmb/cpm-bee-2b/resolve/main/config.json",
+     "openbmb/cpm-bee-1b": "https://huggingface.co/openbmb/cpm-bee-1b/resolve/main/config.json",
+     # See all CpmBee models at https://huggingface.co/models?filter=cpmbee
+ }
+
+
+ class CpmBeeConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`CpmBeeModel`]. It is used to instantiate a
+     CPMBee model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of the CPMBee
+     [openbmb/cpm-bee-10b](https://huggingface.co/openbmb/cpm-bee-10b) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 30720):
+             Vocabulary size of the CPMBee model. Defines the number of different tokens that can be represented by the
+             `input` passed when calling [`CpmBeeModel`].
+         hidden_size (`int`, *optional*, defaults to 4096):
+             Dimension of the encoder layers.
+         num_attention_heads (`int`, *optional*, defaults to 64):
+             Number of attention heads in the Transformer encoder.
+         dim_head (`int`, *optional*, defaults to 64):
+             Dimension of attention heads for each attention layer in the Transformer encoder.
+         dim_ff (`int`, *optional*, defaults to 10240):
+             Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+         num_hidden_layers (`int`, *optional*, defaults to 32):
+             Number of layers of the Transformer encoder.
+         dropout_p (`float`, *optional*, defaults to 0.0):
+             The dropout probability for all fully connected layers in the embeddings and encoder.
+         position_bias_num_buckets (`int`, *optional*, defaults to 256):
+             The number of position_bias buckets.
+         position_bias_num_segment_buckets (`int`, *optional*, defaults to 32):
+             The number of segment buckets.
+         position_bias_max_distance (`int`, *optional*, defaults to 2048):
+             The maximum sequence length that this model might ever be used with. Typically set this to something large
+             just in case (e.g., 512 or 1024 or 2048).
+         eps (`float`, *optional*, defaults to 1e-6):
+             The epsilon used by the layer normalization layers.
+         init_std (`float`, *optional*, defaults to 1.0):
+             Initialize parameters with std = init_std.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether to use the past key/value cache to speed up decoding.
+         distance_scale (`float` or `int`, *optional*, defaults to 16):
+             Scale factor applied to the rotary embedding.
+         mask_modules (`list` or `tuple`, *optional*, defaults to `None`):
+             Decides which feed-forward and attention blocks are pruned.
+         half (`bool`, *optional*, defaults to `False`):
+             Whether to keep the model parameters in half precision.
+
+     Example:
+
+     ```python
+     >>> from transformers import CpmBeeModel, CpmBeeConfig
+
+     >>> # Initializing a CPMBee cpm-bee-10b style configuration
+     >>> configuration = CpmBeeConfig()
+
+     >>> # Initializing a model from the cpm-bee-10b style configuration
+     >>> model = CpmBeeModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+     model_type = "cpmbee"
+
+     def __init__(
+         self,
+         vocab_size: int = 30720,
+         hidden_size: int = 4096,
+         num_attention_heads: int = 64,
+         dim_head: int = 64,
+         dim_ff: int = 10240,
+         num_hidden_layers: int = 32,
+         dropout_p: float = 0.0,
+         position_bias_num_buckets: int = 256,
+         position_bias_num_segment_buckets: int = 32,
+         position_bias_max_distance: int = 2048,
+         eps: float = 1e-6,
+         init_std: float = 1.0,
+         use_cache: bool = True,
+         distance_scale: Union[int, float] = 16,
+         mask_modules: Optional[Union[List, Tuple]] = None,
+         half: bool = False,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.position_bias_num_segment_buckets = position_bias_num_segment_buckets
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.dim_head = dim_head
+         self.dim_ff = dim_ff
+         self.num_hidden_layers = num_hidden_layers
+         self.position_bias_num_buckets = position_bias_num_buckets
+         self.position_bias_max_distance = position_bias_max_distance
+         self.dropout_p = dropout_p
+         self.eps = eps
+         self.use_cache = use_cache
+         self.vocab_size = vocab_size
+         self.init_std = init_std
+         self.distance_scale = distance_scale
+         self.half = half
+         self.mask_modules = mask_modules
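
As a complement to the docstring example above, the following sketch builds a deliberately tiny configuration for quick smoke tests. It assumes `CpmBeeConfig` and `CpmBeeModel` are importable from `transformers`, as in the docstring example; the small hyperparameter values are arbitrary and do not correspond to any released checkpoint.

```python
from transformers import CpmBeeConfig, CpmBeeModel

# Arbitrary small values for local testing; released checkpoints such as
# cpm-bee-10b use the much larger defaults listed in CpmBeeConfig.__init__.
tiny_config = CpmBeeConfig(
    vocab_size=1024,
    hidden_size=128,
    num_attention_heads=4,
    dim_head=32,
    dim_ff=256,
    num_hidden_layers=2,
    dropout_p=0.0,
)
model = CpmBeeModel(tiny_config)
print(model.config.hidden_size)  # 128
```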
modeling_cpmbee.py ADDED
@@ -0,0 +1,1944 @@
+ # coding=utf-8
+ # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ PyTorch CpmBee model."""
+ import copy
+ import math
+ from collections import UserDict
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from ...generation.beam_search import BeamHypotheses, BeamSearchScorer
+ from ...generation.streamers import BaseStreamer
+ from ...generation.utils import (
+     GenerationConfig,
+     LogitsProcessorList,
+     StoppingCriteriaList,
+     dist,
+     inspect,
+     is_deepspeed_zero3_enabled,
+     warnings,
+ )
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
+ from ...modeling_utils import PreTrainedModel
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+ from .configuration_cpmbee import CpmBeeConfig
+ from .tokenization_cpmbee import CpmBeeTokenizer
+
+
+ logger = logging.get_logger(__name__)
+
+ _CHECKPOINT_FOR_DOC = "openbmb/cpm-bee-10b"
+ _CONFIG_FOR_DOC = "CpmBeeConfig"
+
+ CPMBEE_PRETRAINED_MODEL_ARCHIVE_LIST = [
+     "openbmb/cpm-bee-10b",
+     "openbmb/cpm-bee-5b",
+     "openbmb/cpm-bee-2b",
+     "openbmb/cpm-bee-1b",
+     # See all CPMBee models at https://huggingface.co/models?filter=cpmbee
+ ]
+
+
+ class CpmBeeLinear(nn.Linear):
+     def __init__(self, dim_in, dim_out, dtype):
+         """
+         Construct a linear layer for CPMBee. It contains a scale operation.
+         """
+         super().__init__(dim_in, dim_out, bias=False)
+         self.dim_in = self.in_features = dim_in
+         self.dim_out = self.out_features = dim_out
+
+         self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
+
+     def forward(self, x: torch.Tensor):
+         """
+         Args:
+             x (`torch.Tensor` of shape `(batch, seq_len, dim_in)`): The input of the linear layer.
+         Returns:
+             `torch.Tensor` of shape `(batch, seq_len, dim_out)`: The output of the linear transform y.
+         """
+         x = nn.functional.linear(x, self.weight)
+         x = x / math.sqrt(self.dim_in)
+         return x
+
+
+ class CpmBeeLayerNorm(nn.Module):
+     """
+     We use Root Mean Square (RMS) Layer Normalization, please see https://arxiv.org/abs/1910.07467 for details.
+     """
+
+     def __init__(self, config: CpmBeeConfig):
+         super().__init__()
+
+         self.eps = config.eps
+         self.dim_norm = config.hidden_size
+         self.weight = nn.Parameter(torch.empty(config.hidden_size, dtype=config.torch_dtype))
+
+     def forward(self, hidden_states: torch.Tensor):
+         """
+         Args:
+             hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
+         """
+         if hidden_states.size(-1) != self.dim_norm:
+             raise AssertionError("hidden_states.size(-1) != self.dim_norm")
+         old_dtype = hidden_states.dtype
+         variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
+         hidden_states = (hidden_states * torch.rsqrt(variance + self.eps)).to(old_dtype) * self.weight
+         return hidden_states
+
+
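
For reference, the operation implemented by `CpmBeeLayerNorm.forward` above is the RMSNorm of the cited paper; with `x` a hidden-state vector of size `d`, `w` the learned weight, and `eps` the stabilizer, it computes:

```latex
\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot w
```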
+ class CpmBeeAttention(nn.Module):
+     def __init__(self, config: CpmBeeConfig):
+         super().__init__()
+         self.dim_model = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.dim_head = config.dim_head
+
+         self.project_q = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)
+         self.project_k = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)
+         self.project_v = CpmBeeLinear(self.dim_model, self.num_heads * self.dim_head, dtype=config.torch_dtype)
+
+         self.attention_out = CpmBeeLinear(self.num_heads * self.dim_head, self.dim_model, dtype=config.torch_dtype)
+
+         self.softmax = torch.nn.Softmax(dim=-1)
+
+         if config.dropout_p is not None:
+             self.dropout = torch.nn.Dropout(p=config.dropout_p)
+         else:
+             self.dropout = None
+
+     def forward(
+         self,
+         hidden_q: torch.Tensor,
+         hidden_kv: torch.Tensor,
+         attention_mask: torch.BoolTensor,
+         position_bias: torch.Tensor,
+         output_attentions: Optional[bool] = False,
+         past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: Optional[bool] = None,
+     ):
+         """
+         Args:
+             hidden_q (`torch.Tensor`):
+                 Input of the transformer block (self-attention block). It can be the raw embedding of a batch of sequences.
+             hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`):
+                 Tensor from which the *key* and *value* projections are computed.
+             attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
+                 Prevents invalid positions from taking part in the self-attention calculation.
+             position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
+                 Provides positional information to the self-attention block.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers.
+             past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
+                 Cached past key and value projection states.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+         """
+         batch_size = hidden_q.size(0)
+         len_q = hidden_q.size(1)
+         len_k = hidden_kv.size(1)
+
+         query = self.project_q(hidden_q)
+         key = self.project_k(hidden_kv)
+         value = self.project_v(hidden_kv)
+
+         query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
+         key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
+         value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
+
+         if past_key_values is not None:
+             key = torch.cat([past_key_values[0], key], dim=-2)
+             value = torch.cat([past_key_values[1], value], dim=-2)
+             len_k = key.size(-2)
+
+         # (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k)
+         score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
+         score = score + position_bias
+
+         score = torch.masked_fill(
+             score,
+             attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
+             torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
+         )
+         score = self.softmax(score)
+
+         score = torch.masked_fill(
+             score,
+             attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
+             torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
+         )
+         if output_attentions:
+             attn_weights = score
+         else:
+             attn_weights = None
+
+         if self.dropout is not None:
+             score = self.dropout(score)
+
+         # (batch_size, num_heads, len_q, len_k) @ (batch_size, num_heads, len_k, dim_head) -> (batch_size, num_heads, len_q, dim_head)
+         score = torch.matmul(score, value)
+
+         score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
+         score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
+
+         score = self.attention_out(score)
+
+         past_key_values = None
+         if use_cache:
+             past_key_values = (key, value)
+
+         return score, attn_weights, past_key_values
+
+
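
As a reading aid for the cache handling in `CpmBeeAttention.forward` above, the standalone sketch below traces only the shape bookkeeping; the tensor sizes are made up for illustration.

```python
import torch

batch, num_heads, dim_head = 2, 4, 32

# Keys cached from 7 previously decoded tokens: (batch, num_heads, past_len, dim_head)
past_key = torch.zeros(batch, num_heads, 7, dim_head)
# Key projection of the single newly decoded token: (batch, num_heads, 1, dim_head)
new_key = torch.zeros(batch, num_heads, 1, dim_head)

# CpmBeeAttention concatenates along the sequence axis (dim=-2), so len_k grows by one
# per decoding step; with use_cache=True the grown (key, value) pair is returned and can
# be passed back in as past_key_values on the next step.
key = torch.cat([past_key, new_key], dim=-2)
print(key.shape)  # torch.Size([2, 4, 8, 32])
```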
+ class CpmBeeSelfAttentionBlock(nn.Module):
+     def __init__(self, config: CpmBeeConfig):
+         super().__init__()
+         self.layernorm_before_attention = CpmBeeLayerNorm(config)
+         self.self_attention = CpmBeeAttention(config)
+         if config.dropout_p:
+             self.dropout = torch.nn.Dropout(config.dropout_p)
+         else:
+             self.dropout = None
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         position_bias: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = False,
+         past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: Optional[bool] = None,
+     ):
+         """
+         Args:
+             hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
+                 Input of the transformer block (self-attention block). It can be the raw embedding of a batch of sequences.
+             attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
+                 Prevents invalid positions from taking part in the self-attention calculation.
+             position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
+                 Provides positional information to the self-attention block.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers.
+             past_key_values (`Tuple[torch.FloatTensor]`, *optional*):
+                 Cached past key and value projection states.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+         """
+         outputs = self.layernorm_before_attention(hidden_states)
+         outputs = self.self_attention(
+             outputs, outputs, attention_mask, position_bias, output_attentions, past_key_values, use_cache
+         )
+
+         outputs, attn_weights, current_key_value = outputs
+
+         if self.dropout is not None:
+             outputs = self.dropout(outputs)
+         hidden_states = (hidden_states + outputs) / 1.05
+
+         return hidden_states, attn_weights, current_key_value
+
+
+ class CpmBeeDenseGatedACT(nn.Module):
+     def __init__(self, config: CpmBeeConfig):
+         super().__init__()
+         self.w_0 = CpmBeeLinear(config.hidden_size, config.dim_ff, dtype=config.torch_dtype)
+         self.w_1 = CpmBeeLinear(config.hidden_size, config.dim_ff, dtype=config.torch_dtype)
+         self.act = torch.nn.GELU()
+
+     def forward(self, hidden_states: torch.Tensor):
+         """Transform an input tensor from one feature space to another via a nonlinear operation.
+
+         Args:
+             hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
+         """
+         gate_score = self.act(self.w_0(hidden_states))
+         hidden_states = self.w_1(hidden_states)
+
+         hidden_states = gate_score * hidden_states
+         return hidden_states
+
+
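
In formula form, `CpmBeeDenseGatedACT.forward` above computes a GELU-gated projection; since `w_0` and `w_1` are `CpmBeeLinear` layers, each product also carries the `1/sqrt(dim_in)` scaling shown earlier:

```latex
h = \mathrm{GELU}\left(x W_0^{\top}\right) \odot \left(x W_1^{\top}\right)
```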
+ class CpmBeeFeedForward(nn.Module):
+     def __init__(self, config: CpmBeeConfig):
+         super().__init__()
+         self.w_in = CpmBeeDenseGatedACT(config)
+         if config.dropout_p is not None:
+             self.dropout = torch.nn.Dropout(config.dropout_p)
+         else:
+             self.dropout = None
+
+         self.w_out = CpmBeeLinear(config.dim_ff, config.hidden_size, dtype=config.torch_dtype)
+
+     def forward(self, hidden_states: torch.Tensor):
+         """
+         Args:
+             hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
+         """
+         hidden_states = self.w_in(hidden_states)
+
+         if self.dropout is not None:
+             hidden_states = self.dropout(hidden_states)
+
+         hidden_states = self.w_out(hidden_states)
+
+         return hidden_states
+
+
+ class CpmBeeFFNBlock(nn.Module):
+     def __init__(self, config: CpmBeeConfig):
+         super().__init__()
+         self.layernorm_before_ffn = CpmBeeLayerNorm(config)
+         self.ffn = CpmBeeFeedForward(config)
+         if config.dropout_p:
+             self.dropout = torch.nn.Dropout(config.dropout_p)
+         else:
+             self.dropout = None
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+     ):
+         """
+         Args:
+             hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
+                 Hidden states before the feed-forward layer.
+         """
+         ln_outputs = self.layernorm_before_ffn(hidden_states)
+         outputs = self.ffn(ln_outputs)
+         if self.dropout is not None:
+             outputs = self.dropout(outputs)
+         hidden_states = (hidden_states + outputs) / 1.05
+         return hidden_states
+
+
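
Note that both residual branches above (`CpmBeeSelfAttentionBlock` and `CpmBeeFFNBlock`) divide the residual sum by 1.05 rather than using a plain addition; for a sub-layer `f` applied to the normalized input this is:

```latex
h \leftarrow \frac{h + f(\mathrm{RMSNorm}(h))}{1.05}
```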
+ class CpmBeeTransformerBlock(nn.Module):
+     def __init__(self, config: CpmBeeConfig, mask_att: bool = False, mask_ffn: bool = False):
+         super().__init__()
+         self.mask_att = mask_att
+         self.mask_ffn = mask_ffn
+
+         if not self.mask_att:
+             self.self_att = CpmBeeSelfAttentionBlock(config)
+         if not self.mask_ffn:
+             self.ffn = CpmBeeFFNBlock(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         position_bias: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = False,
+         past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: Optional[bool] = None,
+     ):
+         """
+         Args:
+             hidden_states (`torch.Tensor`):
+                 Input to the layer, of shape `(batch, seq_len, dim_model)`.
+             attention_mask (`torch.Tensor`):
+                 Prevents invalid positions from taking part in the calculation, of shape `(batch, seq_len, seq_len)`.
+             position_bias (`torch.Tensor`):
+                 Provides position information to the attention mechanism, of shape `(num_heads, seq_len, seq_len)`.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers.
+             past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
+                 Cached past key and value projection states.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+         """
+         if not self.mask_att:
+             hidden_states = self.self_att(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 position_bias=position_bias,
+                 output_attentions=output_attentions,
+                 past_key_values=past_key_values,
+                 use_cache=use_cache,
+             )
+
+             hidden_states, attn_weights, current_key_value = hidden_states
+         else:
+             attn_weights, current_key_value = None, (None, None)
+
+         if not self.mask_ffn:
+             hidden_states = self.ffn(hidden_states)
+
+         return hidden_states, attn_weights, current_key_value
+
+
+ class CpmBeeEncoder(nn.Module):
+     def __init__(self, config: CpmBeeConfig):
+         super().__init__()
+         self.num_layers = config.num_hidden_layers
+         if config.mask_modules is not None:
+             assert len(config.mask_modules) == self.num_layers, "The total number of masks should be equal to num_layers"
+             for mask_module in config.mask_modules:
+                 assert len(mask_module) == 2, "For encoder, each mask should be (mask_att, mask_ffn)"
+         else:
+             config.mask_modules = [(False, False)] * self.num_layers
+
+         self.layers = nn.ModuleList(
+             [
+                 CpmBeeTransformerBlock(
+                     config, mask_att=config.mask_modules[ith][0], mask_ffn=config.mask_modules[ith][1]
+                 )
+                 for ith in range(self.num_layers)
+             ]
+         )
+
+         self.output_layernorm = CpmBeeLayerNorm(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         position_bias: torch.Tensor,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: Optional[bool] = None,
+     ):
+         """
+         Args:
+             hidden_states (`torch.Tensor`):
+                 Input to the layer, of shape `(batch, seq_len, dim_model)`.
+             attention_mask (`torch.Tensor`):
+                 Prevents invalid positions from taking part in the calculation, of shape `(batch, seq_len, seq_len)`.
+             position_bias (`torch.Tensor`):
+                 Provides position information to the attention mechanism, of shape `(num_heads, seq_len, seq_len)`.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers.
+             past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
+                 Cached past key and value projection states.
+             use_cache (`bool`, *optional*):