crumb commited on
Commit
ef49281
·
1 Parent(s): 82f052d

Upload 2 files

Browse files
configuration_switchllama.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ LLaMA model configuration"""
2
+
3
+ # from ...configuration_utils import PretrainedConfig
4
+ # from ...utils import logging
5
+ from transformers.models.llama.configuration_llama import *
6
+
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+ LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
11
+
12
+
13
+ class SwitchLlamaConfig(PretrainedConfig):
14
+ r"""
15
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
16
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
17
+ defaults will yield a similar configuration to that of the LLaMA-7B.
18
+
19
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
20
+ documentation from [`PretrainedConfig`] for more information.
21
+
22
+
23
+ Args:
24
+ vocab_size (`int`, *optional*, defaults to 32000):
25
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
26
+ `inputs_ids` passed when calling [`LlamaModel`]
27
+ hidden_size (`int`, *optional*, defaults to 4096):
28
+ Dimension of the hidden representations.
29
+ intermediate_size (`int`, *optional*, defaults to 11008):
30
+ Dimension of the MLP representations.
31
+ num_hidden_layers (`int`, *optional*, defaults to 32):
32
+ Number of hidden layers in the Transformer encoder.
33
+ num_attention_heads (`int`, *optional*, defaults to 32):
34
+ Number of attention heads for each attention layer in the Transformer encoder.
35
+ num_key_value_heads (`int`, *optional*):
36
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
37
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
38
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
39
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
40
+ by meanpooling all the original heads within that group. For more details checkout [this
41
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
42
+ `num_attention_heads`.
43
+ pretraining_tp (`int`, *optional*, defaults to `1`):
44
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
45
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
46
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
47
+ issue](https://github.com/pytorch/pytorch/issues/76232).
48
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
49
+ The non-linear activation function (function or string) in the decoder.
50
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
51
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
52
+ Llama 2 up to 4096, CodeLlama up to 16384.
53
+ initializer_range (`float`, *optional*, defaults to 0.02):
54
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
55
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
56
+ The epsilon used by the rms normalization layers.
57
+ use_cache (`bool`, *optional*, defaults to `True`):
58
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
59
+ relevant if `config.is_decoder=True`.
60
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
61
+ Whether to tie weight embeddings
62
+ rope_theta (`float`, *optional*, defaults to 10000.0):
63
+ The base period of the RoPE embeddings.
64
+ rope_scaling (`Dict`, *optional*):
65
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
66
+ strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
67
+ is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
68
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
69
+ these scaling strategies behave:
70
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
71
+ experimental feature, subject to breaking API changes in future versions.
72
+
73
+ Example:
74
+
75
+ ```python
76
+ >>> from transformers import LlamaModel, LlamaConfig
77
+
78
+ >>> # Initializing a LLaMA llama-7b style configuration
79
+ >>> configuration = LlamaConfig()
80
+
81
+ >>> # Initializing a model from the llama-7b style configuration
82
+ >>> model = LlamaModel(configuration)
83
+
84
+ >>> # Accessing the model configuration
85
+ >>> configuration = model.config
86
+ ```"""
87
+ model_type = "switchllama"
88
+ keys_to_ignore_at_inference = ["past_key_values"]
89
+
90
+ def __init__(
91
+ self,
92
+ vocab_size=32000,
93
+ hidden_size=4096,
94
+ intermediate_size=11008,
95
+ num_hidden_layers=32,
96
+ num_attention_heads=32,
97
+ num_key_value_heads=None,
98
+ hidden_act="silu",
99
+ max_position_embeddings=2048,
100
+ initializer_range=0.02,
101
+ rms_norm_eps=1e-6,
102
+ use_cache=True,
103
+ pad_token_id=None,
104
+ bos_token_id=1,
105
+ eos_token_id=2,
106
+ pretraining_tp=1,
107
+ tie_word_embeddings=False,
108
+ rope_theta=10000.0,
109
+ rope_scaling=None,
110
+ # extra!
111
+ expert_capacity=64,
112
+ router_bias=False,
113
+ router_jitter_noise=0.01,
114
+ router_ignore_padding_tokens=False,
115
+ num_experts=8,
116
+ dropout_rate=0.01,
117
+ router_aux_loss_coef=0.001,
118
+ router_z_loss_coef=0.001,
119
+ **kwargs,
120
+ ):
121
+ self.router_aux_loss_coef=router_aux_loss_coef
122
+ self.router_z_loss_coef=router_z_loss_coef
123
+ self.dropout_rate = dropout_rate
124
+ self.num_experts = num_experts
125
+ self.router_ignore_padding_tokens = router_ignore_padding_tokens
126
+ self.router_jitter_noise = router_jitter_noise
127
+ self.router_bias = router_bias
128
+ self.expert_capacity = expert_capacity
129
+ self.vocab_size = vocab_size
130
+ self.max_position_embeddings = max_position_embeddings
131
+ self.hidden_size = hidden_size
132
+ self.intermediate_size = intermediate_size
133
+ self.num_hidden_layers = num_hidden_layers
134
+ self.num_attention_heads = num_attention_heads
135
+
136
+ # for backward compatibility
137
+ if num_key_value_heads is None:
138
+ num_key_value_heads = num_attention_heads
139
+
140
+ self.num_key_value_heads = num_key_value_heads
141
+ self.hidden_act = hidden_act
142
+ self.initializer_range = initializer_range
143
+ self.rms_norm_eps = rms_norm_eps
144
+ self.pretraining_tp = pretraining_tp
145
+ self.use_cache = use_cache
146
+ self.rope_theta = rope_theta
147
+ self.rope_scaling = rope_scaling
148
+ self._rope_scaling_validation()
149
+
150
+ super().__init__(
151
+ pad_token_id=pad_token_id,
152
+ bos_token_id=bos_token_id,
153
+ eos_token_id=eos_token_id,
154
+ tie_word_embeddings=tie_word_embeddings,
155
+ **kwargs,
156
+ )
157
+
158
+ def _rope_scaling_validation(self):
159
+ """
160
+ Validate the `rope_scaling` configuration.
161
+ """
162
+ if self.rope_scaling is None:
163
+ return
164
+
165
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
166
+ raise ValueError(
167
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
168
+ f"got {self.rope_scaling}"
169
+ )
170
+ rope_scaling_type = self.rope_scaling.get("type", None)
171
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
172
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
173
+ raise ValueError(
174
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
175
+ )
176
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
177
+ raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
modeling_switchllama.py ADDED
@@ -0,0 +1,735 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copyright idek
2
+
3
+ from transformers.models.llama.modeling_llama import *
4
+ from torch import nn
5
+ import torch
6
+ from configuration_switchllama import SwitchLlamaConfig
7
+
8
+
9
+ def router_z_loss_func(router_logits: torch.Tensor) -> float:
10
+ r"""
11
+ Compute the router z-loss implemented in PyTorch.
12
+
13
+ The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906).
14
+ It encourages router logits to remain small in an effort to improve stability.
15
+
16
+ Args:
17
+ router_logits (`float`):
18
+ Input logits of shape [batch_size, sequence_length, num_experts]
19
+
20
+ Returns:
21
+ Scalar router z-loss.
22
+ """
23
+ num_groups, tokens_per_group, _ = router_logits.shape
24
+ log_z = torch.logsumexp(router_logits, dim=-1)
25
+ z_loss = log_z**2
26
+ return torch.sum(z_loss) / (num_groups * tokens_per_group)
27
+
28
+
29
+ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
30
+ r"""
31
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
32
+
33
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
34
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
35
+ experts is too unbalanced.
36
+
37
+ Args:
38
+ router_probs (`torch.Tensor`):
39
+ Probability assigned to each expert per token. Shape: [batch_size, seqeunce_length, num_experts].
40
+ expert_indices (`torch.Tensor`):
41
+ Indices tensor of shape [batch_size, seqeunce_length] identifying the selected expert for a given token.
42
+
43
+ Returns:
44
+ The auxiliary loss.
45
+ """
46
+ num_experts = router_probs.shape[-1]
47
+
48
+ # cast the expert indices to int64, otherwise one-hot encoding will fail
49
+ if expert_indices.dtype != torch.int64:
50
+ expert_indices = expert_indices.to(torch.int64)
51
+
52
+ if len(expert_indices.shape) == 2:
53
+ expert_indices = expert_indices.unsqueeze(2)
54
+
55
+ expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)
56
+
57
+ # For a given token, determine if it was routed to a given expert.
58
+ expert_mask = torch.max(expert_mask, axis=-2).values
59
+
60
+ # cast to float32 otherwise mean will fail
61
+ expert_mask = expert_mask.to(torch.float32)
62
+ tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
63
+
64
+ router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2)
65
+ return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
66
+
67
+
68
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
69
+ def _make_causal_mask(
70
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
71
+ ):
72
+ """
73
+ Make causal mask used for bi-directional self-attention.
74
+ """
75
+ bsz, tgt_len = input_ids_shape
76
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
77
+ mask_cond = torch.arange(mask.size(-1), device=device)
78
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
79
+ mask = mask.to(dtype)
80
+
81
+ if past_key_values_length > 0:
82
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
83
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
84
+
85
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
86
+ """
87
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
88
+ """
89
+ bsz, src_len = mask.size()
90
+ tgt_len = tgt_len if tgt_len is not None else src_len
91
+
92
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
93
+
94
+ inverted_mask = 1.0 - expanded_mask
95
+
96
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
97
+
98
+
99
+
100
+ class SwitchLlamaTop1Router(nn.Module):
101
+ """
102
+ Router using tokens choose top-1 experts assignment.
103
+
104
+ This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE
105
+ (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then
106
+ routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee that each
107
+ token is processed by an expert**, or that each expert receives at least one token.
108
+
109
+ """
110
+
111
+ def __init__(self, config: SwitchLlamaConfig):
112
+ super().__init__()
113
+ self.num_experts = config.num_experts
114
+ self.expert_capacity = config.expert_capacity
115
+ self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
116
+ self.jitter_noise = config.router_jitter_noise
117
+ self.ignore_padding_tokens = config.router_ignore_padding_tokens
118
+
119
+ def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
120
+ r"""
121
+ Computes router probabilities from input hidden states.
122
+
123
+ Args:
124
+ hidden_states (`torch.Tensor`):
125
+ (batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
126
+ Returns:
127
+ router_probabilities (`torch.Tensor`):
128
+ Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
129
+ token and expert. Used for routing tokens to experts.
130
+ router_logits (`torch.Tensor`):
131
+ Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
132
+ This is used later for computing router z-loss.
133
+ """
134
+ if self.jitter_noise > 0:
135
+ # Get the lower and upper bound of the uniform distribution
136
+ # Adapted from: https://stackoverflow.com/questions/44328530/how-to-get-a-uniform-distribution-in-a-range-r1-r2-in-pytorch
137
+ distrib_lower_bound = 1.0 - self.jitter_noise
138
+ distrib_upper_bound = 1.0 + self.jitter_noise
139
+
140
+ uniform_distrib = torch.rand(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
141
+ uniform_distrib = uniform_distrib * (distrib_lower_bound - distrib_upper_bound)
142
+
143
+ uniform_distrib = uniform_distrib + distrib_upper_bound
144
+ # Multiply the token inputs by the uniform distribution - adding some noise
145
+ hidden_states *= uniform_distrib
146
+
147
+ # Shape: [num_groups, tokens_per_group, num_experts]
148
+ router_logits = self.classifier(hidden_states)
149
+
150
+ # Apply Softmax
151
+ router_probabilities = nn.functional.softmax(router_logits, dim=-1)
152
+ return router_probabilities, router_logits
153
+
154
+ def forward(self, hidden_states: torch.Tensor) -> Tuple:
155
+ r"""
156
+ Generic forward function for every Router class. Each Router expects to have the same input hidden states
157
+ (`hidden_states`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the
158
+ number of tokens the Router will send to each expert, some Routers can send up to few tokens to each expert.
159
+
160
+ Each Router works as the following: it expects the hidden states for each token, gets the `router_probs` and
161
+ `router_logits` from the `router_weights`. This will assign for each token, the raw probability to be assigned
162
+ to an expert. Then each Router class will have to define its own `_compute_routing_instructions`.
163
+
164
+ Args:
165
+ hidden_states (`torch.Tensor`) :
166
+ [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
167
+ Returns:
168
+ Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the expert index, the router probs
169
+ and the router logits. The router probabilities and logits are required to compute the loss.
170
+ """
171
+ router_probs, router_logits = self._compute_router_probabilities(hidden_states)
172
+
173
+ expert_index = torch.argmax(router_probs, dim=-1)
174
+ expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)
175
+
176
+ # Mask tokens outside expert capacity. Sum over each sequence
177
+ token_priority = torch.cumsum(expert_index, dim=-2)
178
+ # mask if the token routed to to the expert will overflow
179
+ expert_capacity_mask = token_priority <= self.expert_capacity
180
+ expert_index = expert_index * expert_capacity_mask
181
+
182
+ router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
183
+ return expert_index, router_probs, router_logits
184
+
185
+ class SwitchLlamaSparseMLP(nn.Module):
186
+ r"""
187
+ Implementation of the Switch Transformers Sparse MLP module.
188
+ """
189
+
190
+ def __init__(self, config: SwitchLlamaConfig, expert_class: nn.Module = LlamaMLP):
191
+ super().__init__()
192
+ # Step 1: Get the correct router according to its class
193
+ self.router = SwitchLlamaTop1Router(config)
194
+
195
+ # Step 2: Get the experts
196
+ self.experts = nn.ModuleDict()
197
+ for idx in range(config.num_experts):
198
+ self.experts[f"expert_{idx}"] = expert_class(config)
199
+
200
+ def forward(self, hidden_states):
201
+ r"""
202
+ Hold on, this will be slightly tricky to understand In the correct order, a MoE layer does the following:
203
+
204
+ 1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)`
205
+ and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
206
+ hidden states : they are broadcasted to the hidden states values (can be interpreted as a scaling factor).
207
+
208
+ 2- Dispatch the tokens to its associated experts. We do a classic for loop over the experts and assign for each
209
+ expert the corresponding hidden states.
210
+
211
+ """
212
+ # Step 1: Get the router_mask from the router as wel as the probabilities
213
+ router_mask, router_probs, router_logits = self.router(hidden_states)
214
+ expert_index = torch.argmax(router_mask, dim=-1)
215
+
216
+ # The routers introduced might not always map all the tokens, to a router, which means that some hidden states
217
+ # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the seleced ones.
218
+
219
+ next_states = hidden_states.clone()
220
+ for idx, expert in enumerate(self.experts.values()):
221
+ token_indices = router_mask[:, :, idx].bool()
222
+ next_states[token_indices] = expert(hidden_states[token_indices])
223
+
224
+ hidden_states = router_probs * next_states
225
+ return hidden_states, (router_logits, expert_index)
226
+
227
+ class SwitchLlamaLayerFF(nn.Module):
228
+ r"""
229
+ Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts module.
230
+
231
+ Parameters:
232
+ config : ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model.
233
+ Initializing with a config file does not load the weights associated with the model, only the
234
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
235
+ is_sparse (`bool`):
236
+ Whether the MLP layer is a `Sparse` layer (contains a Mixture of Experts) or not
237
+ """
238
+
239
+ def __init__(self, config: SwitchLlamaConfig, is_sparse=True):
240
+ super().__init__()
241
+ self.is_sparse = is_sparse
242
+
243
+ # Check if it is a sparse layer, if not then it is a dense layer
244
+ if not self.is_sparse:
245
+ self.mlp = LlamaMLP(config)
246
+ else:
247
+ self.mlp = SwitchLlamaSparseMLP(config)
248
+
249
+ # self.layer_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
250
+ self.dropout = nn.Dropout(config.dropout_rate)
251
+
252
+ def forward(self, hidden_states, output_router_logits=False):
253
+ # forwarded_states = self.layer_norm(hidden_states)
254
+ forwarded_states = self.mlp(hidden_states)
255
+
256
+ if isinstance(forwarded_states, tuple):
257
+ forwarded_states, router_tuple = forwarded_states
258
+ else:
259
+ router_tuple = None
260
+
261
+ output = hidden_states + self.dropout(forwarded_states)
262
+
263
+ if output_router_logits and router_tuple is not None:
264
+ output = (output, router_tuple)
265
+
266
+ return output
267
+
268
+ class SwitchLlamaDecoderLayer(nn.Module):
269
+ def __init__(self, config: SwitchLlamaConfig):
270
+ super().__init__()
271
+ self.hidden_size = config.hidden_size
272
+ self.self_attn = LlamaAttention(config=config)
273
+ self.mlp = SwitchLlamaLayerFF(config)
274
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
275
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
276
+
277
+ def forward(
278
+ self,
279
+ hidden_states: torch.Tensor,
280
+ attention_mask: Optional[torch.Tensor] = None,
281
+ position_ids: Optional[torch.LongTensor] = None,
282
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
283
+ output_attentions: Optional[bool] = False,
284
+ use_cache: Optional[bool] = False,
285
+ output_router_logits = True
286
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
287
+ """
288
+ Args:
289
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
290
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
291
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
292
+ output_attentions (`bool`, *optional*):
293
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
294
+ returned tensors for more detail.
295
+ use_cache (`bool`, *optional*):
296
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
297
+ (see `past_key_values`).
298
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
299
+ """
300
+
301
+ residual = hidden_states
302
+
303
+ hidden_states = self.input_layernorm(hidden_states)
304
+
305
+ # Self Attention
306
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
307
+ hidden_states=hidden_states,
308
+ attention_mask=attention_mask,
309
+ position_ids=position_ids,
310
+ past_key_value=past_key_value,
311
+ output_attentions=output_attentions,
312
+ use_cache=use_cache,
313
+ )
314
+ hidden_states = residual + hidden_states
315
+
316
+ # Fully Connected
317
+ residual = hidden_states
318
+ hidden_states = self.post_attention_layernorm(hidden_states)
319
+ hidden_states = self.mlp(hidden_states, output_router_logits=output_router_logits)
320
+ if type(hidden_states)==tuple:
321
+ hidden_states, router_tuple = hidden_states
322
+ else:
323
+ router_tuple = (torch.tensor([0], device=hidden_states.device),)
324
+ hidden_states = residual + hidden_states
325
+
326
+ outputs = (hidden_states,)
327
+
328
+ if output_attentions:
329
+ outputs += (self_attn_weights,)
330
+
331
+ if use_cache:
332
+ outputs += (present_key_value,)
333
+
334
+ # if output_router_logits:
335
+ # outputs += (router_tuple,)
336
+ return outputs + (router_tuple,)
337
+
338
+ class SwitchLlamaPreTrainedModel(PreTrainedModel):
339
+ config_class = SwitchLlamaConfig
340
+ base_model_prefix = "model"
341
+ supports_gradient_checkpointing = True
342
+ _no_split_modules = ["SwitchLlamaDecoderLayer"]
343
+ _skip_keys_device_placement = "past_key_values"
344
+
345
+ def _init_weights(self, module):
346
+ std = self.config.initializer_range
347
+ if isinstance(module, nn.Linear):
348
+ module.weight.data.normal_(mean=0.0, std=std)
349
+ if module.bias is not None:
350
+ module.bias.data.zero_()
351
+ elif isinstance(module, nn.Embedding):
352
+ module.weight.data.normal_(mean=0.0, std=std)
353
+ if module.padding_idx is not None:
354
+ module.weight.data[module.padding_idx].zero_()
355
+
356
+ def _set_gradient_checkpointing(self, module, value=False):
357
+ if isinstance(module, LlamaModel):
358
+ module.gradient_checkpointing = value
359
+
360
+
361
+ class SwitchLlamaModel(SwitchLlamaPreTrainedModel):
362
+ """
363
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
364
+
365
+ Args:
366
+ config: SwitchLlamaConfig
367
+ """
368
+
369
+ def __init__(self, config: SwitchLlamaConfig):
370
+ super().__init__(config)
371
+ self.padding_idx = config.pad_token_id
372
+ self.vocab_size = config.vocab_size
373
+
374
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
375
+ self.layers = nn.ModuleList([SwitchLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
376
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
377
+
378
+ self.gradient_checkpointing = False
379
+ # Initialize weights and apply final processing
380
+ self.post_init()
381
+
382
+ def get_input_embeddings(self):
383
+ return self.embed_tokens
384
+
385
+ def set_input_embeddings(self, value):
386
+ self.embed_tokens = value
387
+
388
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
389
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
390
+ # create causal mask
391
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
392
+ combined_attention_mask = None
393
+ if input_shape[-1] > 1:
394
+ combined_attention_mask = _make_causal_mask(
395
+ input_shape,
396
+ inputs_embeds.dtype,
397
+ device=inputs_embeds.device,
398
+ past_key_values_length=past_key_values_length,
399
+ )
400
+
401
+ if attention_mask is not None:
402
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
403
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
404
+ inputs_embeds.device
405
+ )
406
+ combined_attention_mask = (
407
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
408
+ )
409
+
410
+ return combined_attention_mask
411
+
412
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
413
+ def forward(
414
+ self,
415
+ input_ids: torch.LongTensor = None,
416
+ attention_mask: Optional[torch.Tensor] = None,
417
+ position_ids: Optional[torch.LongTensor] = None,
418
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
419
+ inputs_embeds: Optional[torch.FloatTensor] = None,
420
+ use_cache: Optional[bool] = None,
421
+ output_attentions: Optional[bool] = None,
422
+ output_hidden_states: Optional[bool] = None,
423
+ return_dict: Optional[bool] = None,
424
+ output_router_logits = False
425
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
426
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
427
+ output_hidden_states = (
428
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
429
+ )
430
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
431
+
432
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
433
+ all_router_probs = () if output_router_logits else None
434
+ # retrieve input_ids and inputs_embeds
435
+ if input_ids is not None and inputs_embeds is not None:
436
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
437
+ elif input_ids is not None:
438
+ batch_size, seq_length = input_ids.shape
439
+ elif inputs_embeds is not None:
440
+ batch_size, seq_length, _ = inputs_embeds.shape
441
+ else:
442
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
443
+
444
+ seq_length_with_past = seq_length
445
+ past_key_values_length = 0
446
+
447
+ if past_key_values is not None:
448
+ past_key_values_length = past_key_values[0][0].shape[2]
449
+ seq_length_with_past = seq_length_with_past + past_key_values_length
450
+
451
+ if position_ids is None:
452
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
453
+ position_ids = torch.arange(
454
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
455
+ )
456
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
457
+ else:
458
+ position_ids = position_ids.view(-1, seq_length).long()
459
+
460
+ if inputs_embeds is None:
461
+ inputs_embeds = self.embed_tokens(input_ids)
462
+ # embed positions
463
+ if attention_mask is None:
464
+ attention_mask = torch.ones(
465
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
466
+ )
467
+ attention_mask = self._prepare_decoder_attention_mask(
468
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
469
+ )
470
+
471
+ hidden_states = inputs_embeds
472
+
473
+ if self.gradient_checkpointing and self.training:
474
+ if use_cache:
475
+ logger.warning_once(
476
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
477
+ )
478
+ use_cache = False
479
+
480
+ # decoder layers
481
+ all_hidden_states = () if output_hidden_states else None
482
+ all_self_attns = () if output_attentions else None
483
+ next_decoder_cache = () if use_cache else None
484
+
485
+ for idx, decoder_layer in enumerate(self.layers):
486
+ if output_hidden_states:
487
+ all_hidden_states += (hidden_states,)
488
+
489
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
490
+
491
+ if self.gradient_checkpointing and self.training:
492
+
493
+ def create_custom_forward(module):
494
+ def custom_forward(*inputs):
495
+ # None for past_key_value
496
+ return module(*inputs, past_key_value, output_attentions)
497
+
498
+ return custom_forward
499
+
500
+ layer_outputs = torch.utils.checkpoint.checkpoint(
501
+ create_custom_forward(decoder_layer),
502
+ hidden_states,
503
+ attention_mask,
504
+ position_ids,
505
+ )
506
+ else:
507
+ layer_outputs = decoder_layer(
508
+ hidden_states,
509
+ attention_mask=attention_mask,
510
+ position_ids=position_ids,
511
+ past_key_value=past_key_value,
512
+ output_attentions=output_attentions,
513
+ use_cache=use_cache,
514
+ output_router_logits=output_router_logits
515
+ )
516
+
517
+ hidden_states = layer_outputs[0]
518
+ router_probs = layer_outputs[-1]
519
+
520
+ if use_cache:
521
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
522
+
523
+ if output_attentions:
524
+ all_self_attns += (layer_outputs[1],)
525
+
526
+ if output_router_logits:
527
+ all_router_probs = all_router_probs + (router_probs,)
528
+ hidden_states = self.norm(hidden_states)
529
+
530
+ # add hidden states from the last decoder layer
531
+ if output_hidden_states:
532
+ all_hidden_states += (hidden_states,)
533
+
534
+ next_cache = next_decoder_cache if use_cache else None
535
+ if not return_dict:
536
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
537
+
538
+ from transformers.models.switch_transformers.modeling_switch_transformers import MoEModelOutputWithPastAndCrossAttentions
539
+ return MoEModelOutputWithPastAndCrossAttentions(
540
+ last_hidden_state=hidden_states,
541
+ past_key_values=next_cache,
542
+ hidden_states=all_hidden_states,
543
+ attentions=all_self_attns,
544
+ router_probs=all_router_probs,
545
+ )
546
+
547
+ class SwitchLlamaForCausalLM(SwitchLlamaPreTrainedModel):
548
+ _tied_weights_keys = ["lm_head.weight"]
549
+
550
+ def __init__(self, config):
551
+ super().__init__(config)
552
+ self.model = SwitchLlamaModel(config)
553
+ self.vocab_size = config.vocab_size
554
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
555
+
556
+ self.router_z_loss_coef = config.router_z_loss_coef
557
+ self.router_aux_loss_coef = config.router_aux_loss_coef
558
+ # Initialize weights and apply final processing
559
+ self.post_init()
560
+ def _unpack_router_logits(self, router_outputs):
561
+ total_router_logits = []
562
+ total_expert_indexes = []
563
+ for router_output in router_outputs:
564
+ if len(router_output[0].shape) > 1:
565
+ router_logits, expert_indexes = router_output
566
+ total_router_logits.append(router_logits)
567
+ total_expert_indexes.append(expert_indexes)
568
+ return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
569
+
570
+
571
+ def get_input_embeddings(self):
572
+ return self.model.embed_tokens
573
+
574
+ def set_input_embeddings(self, value):
575
+ self.model.embed_tokens = value
576
+
577
+ def get_output_embeddings(self):
578
+ return self.lm_head
579
+
580
+ def set_output_embeddings(self, new_embeddings):
581
+ self.lm_head = new_embeddings
582
+
583
+ def set_decoder(self, decoder):
584
+ self.model = decoder
585
+
586
+ def get_decoder(self):
587
+ return self.model
588
+
589
+ def forward(
590
+ self,
591
+ input_ids: torch.LongTensor = None,
592
+ attention_mask: Optional[torch.Tensor] = None,
593
+ position_ids: Optional[torch.LongTensor] = None,
594
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
595
+ inputs_embeds: Optional[torch.FloatTensor] = None,
596
+ labels: Optional[torch.LongTensor] = None,
597
+ use_cache: Optional[bool] = None,
598
+ output_attentions: Optional[bool] = None,
599
+ output_hidden_states: Optional[bool] = None,
600
+ return_dict: Optional[bool] = None,
601
+ output_router_logits = False,
602
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
603
+ r"""
604
+ Args:
605
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
606
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
607
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
608
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
609
+
610
+ Returns:
611
+
612
+ Example:
613
+
614
+ ```python
615
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
616
+
617
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
618
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
619
+
620
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
621
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
622
+
623
+ >>> # Generate
624
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
625
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
626
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
627
+ ```"""
628
+
629
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
630
+ output_hidden_states = (
631
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
632
+ )
633
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
634
+
635
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
636
+ outputs = self.model(
637
+ input_ids=input_ids,
638
+ attention_mask=attention_mask,
639
+ position_ids=position_ids,
640
+ past_key_values=past_key_values,
641
+ inputs_embeds=inputs_embeds,
642
+ use_cache=use_cache,
643
+ output_attentions=output_attentions,
644
+ output_hidden_states=output_hidden_states,
645
+ return_dict=return_dict,
646
+ output_router_logits=output_router_logits
647
+ )
648
+
649
+ hidden_states = outputs[0]
650
+ if self.config.pretraining_tp > 1:
651
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
652
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
653
+ logits = torch.cat(logits, dim=-1)
654
+ else:
655
+ logits = self.lm_head(hidden_states)
656
+ logits = logits.float()
657
+
658
+ loss = None
659
+ decoder_z_loss = None
660
+ decoder_aux_loss = None
661
+
662
+ if output_router_logits:
663
+ decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(outputs[-1])
664
+ decoder_z_loss = router_z_loss_func(decoder_router_logits)
665
+ decoder_router_probs = nn.Softmax(dim=-1)(decoder_router_logits)
666
+ decoder_aux_loss = load_balancing_loss_func(decoder_router_probs, decoder_expert_indexes)
667
+
668
+ if labels is not None:
669
+ # Shift so that tokens < n predict n
670
+ shift_logits = logits[..., :-1, :].contiguous()
671
+ shift_labels = labels[..., 1:].contiguous()
672
+ # Flatten the tokens
673
+ loss_fct = CrossEntropyLoss()
674
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
675
+ shift_labels = shift_labels.view(-1)
676
+ # Enable model parallelism
677
+ shift_labels = shift_labels.to(shift_logits.device)
678
+ loss = loss_fct(shift_logits, shift_labels)
679
+
680
+ ##########################
681
+ if output_router_logits:
682
+ z_loss = self.router_z_loss_coef * decoder_z_loss
683
+ aux_loss = self.router_aux_loss_coef * decoder_aux_loss
684
+ loss = loss + z_loss + aux_loss
685
+ #########################
686
+ if not return_dict:
687
+ output = (logits,) + outputs[1:]
688
+ return (loss,) + output if loss is not None else output
689
+
690
+ return CausalLMOutputWithPast(
691
+ loss=loss,
692
+ logits=logits,
693
+ past_key_values=outputs.past_key_values,
694
+ hidden_states=outputs.hidden_states,
695
+ attentions=outputs.attentions,
696
+ )
697
+
698
+ def prepare_inputs_for_generation(
699
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
700
+ ):
701
+ if past_key_values:
702
+ input_ids = input_ids[:, -1:]
703
+
704
+ position_ids = kwargs.get("position_ids", None)
705
+ if attention_mask is not None and position_ids is None:
706
+ # create position_ids on the fly for batch generation
707
+ position_ids = attention_mask.long().cumsum(-1) - 1
708
+ position_ids.masked_fill_(attention_mask == 0, 1)
709
+ if past_key_values:
710
+ position_ids = position_ids[:, -1].unsqueeze(-1)
711
+
712
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
713
+ if inputs_embeds is not None and past_key_values is None:
714
+ model_inputs = {"inputs_embeds": inputs_embeds}
715
+ else:
716
+ model_inputs = {"input_ids": input_ids}
717
+
718
+ model_inputs.update(
719
+ {
720
+ "position_ids": position_ids,
721
+ "past_key_values": past_key_values,
722
+ "use_cache": kwargs.get("use_cache"),
723
+ "attention_mask": attention_mask,
724
+ }
725
+ )
726
+ return model_inputs
727
+
728
+ @staticmethod
729
+ def _reorder_cache(past_key_values, beam_idx):
730
+ reordered_past = ()
731
+ for layer_past in past_key_values:
732
+ reordered_past += (
733
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
734
+ )
735
+ return reordered_past