Qubitium committed on
Commit
51673c7
1 Parent(s): 517efb2

Create modeling_dbrx.py

Files changed (1)
  1. modeling_dbrx.py +1447 -0
modeling_dbrx.py ADDED
@@ -0,0 +1,1447 @@
1
+ """PyTorch Dbrx model."""
2
+
3
+ import math
4
+ import warnings
5
+ from copy import deepcopy
6
+ from functools import partial
7
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
14
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
15
+ from transformers.modeling_outputs import (MoeCausalLMOutputWithPast,
16
+ MoeModelOutputWithPast)
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import is_flash_attn_2_available, logging
19
+
20
+ from .configuration_dbrx import DbrxAttentionConfig, DbrxConfig, DbrxFFNConfig
21
+
22
+ if is_flash_attn_2_available():
23
+ try:
24
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
25
+ from flash_attn.bert_padding import pad_input # noqa
26
+ from flash_attn.bert_padding import index_first_axis, unpad_input
27
+ except ImportError:
28
+ pass
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+ _CONFIG_FOR_DOC = 'DbrxConfig'
33
+
34
+ #############################################################################
35
+ # Copied from LlamaRotaryEmbedding
36
+ #############################################################################
37
+
38
+
39
+ class DbrxRotaryEmbedding(nn.Module):
40
+
41
+ def __init__(self,
42
+ dim: int,
43
+ max_position_embeddings: int = 2048,
44
+ base: float = 10000.0,
45
+ scaling_factor: float = 1.0):
46
+ super().__init__()
47
+ self.scaling_factor = scaling_factor
48
+ self.dim = dim
49
+ self.max_position_embeddings = max_position_embeddings
50
+ self.base = base
51
+ inv_freq = 1.0 / (self.base**(
52
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
53
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
54
+ # For BC we register cos and sin cached
55
+ self.max_seq_len_cached = max_position_embeddings
56
+
57
+ @torch.no_grad()
58
+ def forward(
59
+ self, x: torch.Tensor, position_ids: torch.LongTensor
60
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
61
+ # x: [bs, num_attention_heads, seq_len, head_size]
62
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
63
+ position_ids.shape[0], -1, 1)
64
+ position_ids_expanded = position_ids[:, None, :].float()
65
+ # Force float32 since bfloat16 loses precision on long contexts
66
+ # See https://github.com/huggingface/transformers/pull/29285
67
+ device_type = x.device.type
68
+ device_type = device_type if isinstance(
69
+ device_type, str) and device_type != 'mps' else 'cpu'
70
+ with torch.autocast(device_type=device_type, enabled=False):
71
+ freqs = (inv_freq_expanded.float()
72
+ @ position_ids_expanded.float()).transpose(1, 2)
73
+ emb = torch.cat((freqs, freqs), dim=-1)
74
+ cos = emb.cos()
75
+ sin = emb.sin()
76
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
77
+
78
+
79
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
80
+ """Rotates half the hidden dims of the input."""
81
+ x1 = x[..., :x.shape[-1] // 2]
82
+ x2 = x[..., x.shape[-1] // 2:]
83
+ return torch.cat((-x2, x1), dim=-1)
84
+
85
+
86
+ def apply_rotary_pos_emb(
87
+ q: torch.Tensor,
88
+ k: torch.Tensor,
89
+ cos: torch.Tensor,
90
+ sin: torch.Tensor,
91
+ unsqueeze_dim: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
92
+ """Applies Rotary Position Embedding to the query and key tensors.
93
+
94
+ Args:
95
+ q (`torch.Tensor`): The query tensor.
96
+ k (`torch.Tensor`): The key tensor.
97
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
98
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
99
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
100
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and
101
+ sin so that they can be properly broadcasted to the dimensions of q and k. For example, note
102
+ that cos and sin have the shape [batch_size, seq_len, head_dim]. Then, if q and
103
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
104
+ cos and sin broadcastable to the shapes of q and k. Similarly, if q and k have
105
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
106
+
107
+ Returns:
108
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
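+
+ Example (an illustrative sketch only; the shapes below are arbitrary assumptions,
+ using the `DbrxRotaryEmbedding` defined in this module):
+
+ ```python
+ >>> q = torch.randn(1, 8, 16, 64) # [batch, heads, seq_len, head_dim]
+ >>> k = torch.randn(1, 8, 16, 64)
+ >>> rope = DbrxRotaryEmbedding(dim=64)
+ >>> cos, sin = rope(q, torch.arange(16)[None, :])
+ >>> q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
+ >>> q_rot.shape
+ torch.Size([1, 8, 16, 64])
+ ```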
109
+ """
110
+ cos = cos.unsqueeze(unsqueeze_dim)
111
+ sin = sin.unsqueeze(unsqueeze_dim)
112
+ q_embed = (q * cos) + (rotate_half(q) * sin)
113
+ k_embed = (k * cos) + (rotate_half(k) * sin)
114
+ return q_embed, k_embed
115
+
116
+
117
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
118
+ """Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
119
+
120
+ The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
121
+ (batch, num_attention_heads, seqlen, head_dim)
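+
+ Example (illustrative shapes only, chosen for this sketch):
+
+ ```python
+ >>> kv = torch.randn(2, 4, 10, 64) # 4 key/value heads
+ >>> repeat_kv(kv, n_rep=3).shape # each KV head repeated for 3 query heads
+ torch.Size([2, 12, 10, 64])
+ ```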
122
+ """
123
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
124
+ if n_rep == 1:
125
+ return hidden_states
126
+ hidden_states = hidden_states[:, :,
127
+ None, :, :].expand(batch, num_key_value_heads,
128
+ n_rep, slen, head_dim)
129
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
130
+ head_dim)
131
+
132
+
133
+ #############################################################################
134
+
135
+ #############################################################################
136
+ # Modified from modeling_mixtral
137
+ #############################################################################
138
+
139
+
140
+ def load_balancing_loss_func(
141
+ gate_logits: torch.Tensor,
142
+ num_experts: int,
143
+ top_k: int,
144
+ attention_mask: Optional[torch.Tensor],
145
+ ) -> torch.Tensor:
146
+ r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
147
+
148
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
149
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
150
+ experts is too unbalanced.
151
+
152
+ Args:
153
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
154
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
155
+ shape [batch_size X sequence_length, num_experts].
156
+ num_experts (`int`):
157
+ Number of experts.
158
+ top_k (`int`):
159
+ The number of experts each token is routed to.
160
+ attention_mask (`torch.Tensor`, None):
161
+ The attention_mask used in forward function
162
+ shape [batch_size X sequence_length] if not None.
163
+
164
+ Returns:
165
+ The auxiliary loss.
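+
+ Example (a minimal sketch with made-up sizes: 2 layers, 8 tokens, 4 experts):
+
+ ```python
+ >>> gate_logits = tuple(torch.randn(8, 4) for _ in range(2))
+ >>> loss = load_balancing_loss_func(gate_logits, num_experts=4, top_k=1, attention_mask=None)
+ >>> loss.shape # scalar auxiliary loss
+ torch.Size([])
+ ```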
166
+ """
167
+ if gate_logits is None or not isinstance(gate_logits, tuple):
168
+ return torch.tensor(0.0)
169
+
170
+ if isinstance(gate_logits, tuple):
171
+ compute_device = gate_logits[0].device
172
+ concatenated_gate_logits = torch.cat(
173
+ [layer_gate.to(compute_device) for layer_gate in gate_logits],
174
+ dim=0)
175
+
176
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits,
177
+ dim=-1)
178
+
179
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
180
+
181
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
182
+
183
+ if attention_mask is None:
184
+ # Compute the percentage of tokens routed to each expert
185
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
186
+
187
+ # Compute the average probability of routing to these experts
188
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
189
+ else:
190
+ batch_size, sequence_length = attention_mask.shape
191
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (
192
+ batch_size * sequence_length)
193
+
194
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
195
+ expert_attention_mask = (attention_mask[None, :, :, None, None].expand(
196
+ (num_hidden_layers, batch_size, sequence_length, top_k,
197
+ num_experts)).reshape(-1, top_k, num_experts).to(compute_device))
198
+
199
+ # Compute the percentage of tokens routed to each expert
200
+ tokens_per_expert = torch.sum(
201
+ expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
202
+ expert_attention_mask, dim=0)
203
+
204
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
205
+ router_per_expert_attention_mask = (
206
+ attention_mask[None, :, :, None].expand(
207
+ (num_hidden_layers, batch_size, sequence_length,
208
+ num_experts)).reshape(-1, num_experts).to(compute_device))
209
+
210
+ # Compute the average probability of routing to these experts
211
+ router_prob_per_expert = torch.sum(
212
+ routing_weights * router_per_expert_attention_mask,
213
+ dim=0) / torch.sum(router_per_expert_attention_mask, dim=0)
214
+
215
+ overall_loss = torch.sum(tokens_per_expert *
216
+ router_prob_per_expert.unsqueeze(0))
217
+ return overall_loss * num_experts
218
+
219
+
220
+ #############################################################################
221
+
222
+
223
+ def resolve_ffn_act_fn(
224
+ ffn_act_fn: dict) -> Callable[[torch.Tensor], torch.Tensor]:
225
+ """Resolve the activation function for the feed-forward network.
226
+
227
+ Args:
228
+ ffn_act_fn (dict): The configuration dictionary for the activation function.
229
+ The dict config must specify the 'name' of a torch.nn.functional activation
230
+ function. All other key-value pairs are bound to the function as a partial.
231
+
232
+ Returns:
233
+ Callable[[torch.Tensor], torch.Tensor]: The activation function.
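+
+ Example (a minimal sketch; the config values here are made up for illustration):
+
+ ```python
+ >>> act = resolve_ffn_act_fn({'name': 'gelu', 'approximate': 'tanh'})
+ >>> act(torch.zeros(2)).shape # extra keys ('approximate') are bound via partial
+ torch.Size([2])
+ ```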
234
+ """
235
+ config = deepcopy(ffn_act_fn)
236
+ name = config.pop('name')
237
+ if not hasattr(nn.functional, name):
238
+ raise ValueError(f'Unrecognised activation function name ({name}).')
239
+ act = getattr(nn.functional, name)
240
+ return partial(act, **config)
241
+
242
+
243
+ #############################################################################
244
+ # Copied from LlamaAttention
245
+ #############################################################################
246
+
247
+
248
+ def _get_unpad_data(attention_mask: torch.Tensor):
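+ # Helper for the varlen flash-attention path: returns the flat indices of all
+ # non-padding tokens, the cumulative sequence lengths (cu_seqlens), and the longest
+ # sequence length in the batch, i.e. the layout flash_attn_varlen_func expects.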
249
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
250
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
251
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
252
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
253
+ (1, 0))
254
+ return (
255
+ indices,
256
+ cu_seqlens,
257
+ max_seqlen_in_batch,
258
+ )
259
+
260
+
261
+ class DbrxAttention(nn.Module):
262
+ """Multi-head self attention."""
263
+
264
+ def __init__(self,
265
+ hidden_size: int,
266
+ num_heads: int,
267
+ max_position_embeddings: int,
268
+ attn_config: DbrxAttentionConfig,
269
+ block_idx: Optional[int] = None):
270
+ super().__init__()
271
+ self.hidden_size = hidden_size
272
+ self.num_heads = num_heads
273
+ self.head_dim = self.hidden_size // self.num_heads
274
+ self.max_position_embeddings = max_position_embeddings
275
+ self.block_idx = block_idx
276
+ self.config = attn_config
277
+ if block_idx is None:
278
+ logger.warning_once(
279
+ f'Instantiating {self.__class__.__name__} without passing a `block_idx` is not recommended and will '
280
+ +
281
+ 'lead to errors during the forward call if caching is used. Please make sure to provide a `block_idx` '
282
+ + 'when creating this class.')
283
+
284
+ self.attn_pdrop = attn_config.attn_pdrop
285
+ self.clip_qkv = attn_config.clip_qkv
286
+ self.num_key_value_heads = attn_config.kv_n_heads
287
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
288
+ self.rope_theta = attn_config.rope_theta
289
+
290
+ self.q_proj = nn.Linear(self.hidden_size,
291
+ self.hidden_size,
292
+ bias=False)
293
+ self.k_proj = nn.Linear(self.hidden_size,
294
+ self.num_key_value_heads * self.head_dim,
295
+ bias=False)
296
+ self.v_proj = nn.Linear(self.hidden_size,
297
+ self.num_key_value_heads * self.head_dim,
298
+ bias=False)
299
+ self.out_proj = nn.Linear(self.hidden_size,
300
+ self.hidden_size,
301
+ bias=False)
302
+ self.rotary_emb = DbrxRotaryEmbedding(
303
+ self.head_dim,
304
+ max_position_embeddings=self.max_position_embeddings,
305
+ base=self.rope_theta,
306
+ )
307
+
308
+ def forward(
309
+ self,
310
+ hidden_states: torch.Tensor,
311
+ position_ids: torch.LongTensor,
312
+ attention_mask: Optional[torch.Tensor] = None,
313
+ past_key_value: Optional[Cache] = None,
314
+ output_attentions: bool = False,
315
+ use_cache: bool = False,
316
+ cache_position: Optional[torch.LongTensor] = None,
317
+ **kwargs: Any,
318
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
319
+ bsz, q_len, _ = hidden_states.size()
320
+
321
+ query_states = self.q_proj(hidden_states)
322
+ key_states = self.k_proj(hidden_states)
323
+ value_states = self.v_proj(hidden_states)
324
+ if self.clip_qkv is not None:
325
+ query_states = query_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
326
+ key_states = key_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
327
+ value_states = value_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
328
+
329
+ query_states = query_states.view(bsz, q_len, self.num_heads,
330
+ self.head_dim).transpose(1, 2)
331
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
332
+ self.head_dim).transpose(1, 2)
333
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
334
+ self.head_dim).transpose(1, 2)
335
+
336
+ past_key_value = getattr(self, 'past_key_value', past_key_value)
337
+ cos, sin = self.rotary_emb(value_states, position_ids)
338
+ query_states, key_states = apply_rotary_pos_emb(query_states,
339
+ key_states, cos, sin)
340
+
341
+ if past_key_value is not None:
342
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
343
+ cache_kwargs = {
344
+ 'sin': sin,
345
+ 'cos': cos,
346
+ 'cache_position': cache_position
347
+ }
348
+ key_states, value_states = past_key_value.update(
349
+ key_states, value_states, self.block_idx, cache_kwargs)
350
+
351
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
352
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
353
+
354
+ attn_weights = torch.matmul(query_states, key_states.transpose(
355
+ 2, 3)) / math.sqrt(self.head_dim)
356
+
357
+ if attention_mask is not None: # no matter the length, we just slice it
358
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
359
+ attn_weights = attn_weights + causal_mask
360
+
361
+ # upcast attention to fp32
362
+ attn_weights = nn.functional.softmax(attn_weights,
363
+ dim=-1,
364
+ dtype=torch.float32).to(
365
+ query_states.dtype)
366
+ attn_weights = nn.functional.dropout(attn_weights,
367
+ p=self.attn_pdrop,
368
+ training=self.training)
369
+ attn_output = torch.matmul(attn_weights, value_states)
370
+
371
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
372
+ raise ValueError(
373
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
374
+ + f' {attn_output.size()}')
375
+
376
+ attn_output = attn_output.transpose(1, 2).contiguous()
377
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
378
+ attn_output = self.out_proj(attn_output)
379
+
380
+ if not output_attentions:
381
+ attn_weights = None
382
+
383
+ return attn_output, attn_weights, past_key_value
384
+
385
+
386
+ class DbrxFlashAttention2(DbrxAttention):
387
+ """Dbrx flash attention module.
388
+
389
+ This module inherits from `DbrxAttention`, as the weights of the module stay
390
+ untouched. The only required change is in the forward pass, where it
391
+ calls the public API of flash attention.
392
+ """
393
+
394
+ def __init__(self, *args: Any, **kwargs: Any):
395
+ if not is_flash_attn_2_available():
396
+ raise ImportError(
397
+ 'Flash Attention 2 is not available. Please install it with `pip install flash-attn`.'
398
+ )
399
+
400
+ super().__init__(*args, **kwargs)
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states: torch.Tensor,
405
+ attention_mask: Optional[torch.LongTensor] = None,
406
+ position_ids: Optional[torch.LongTensor] = None,
407
+ past_key_value: Optional[Cache] = None,
408
+ output_attentions: bool = False,
409
+ use_cache: bool = False,
410
+ cache_position: Optional[torch.LongTensor] = None,
411
+ **kwargs: Any,
412
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
413
+ Optional[Tuple[torch.Tensor]]]:
414
+ logger.info(
415
+ 'Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.'
416
+ )
417
+ output_attentions = False
418
+
419
+ bsz, q_len, _ = hidden_states.size()
420
+
421
+ query_states = self.q_proj(hidden_states)
422
+ key_states = self.k_proj(hidden_states)
423
+ value_states = self.v_proj(hidden_states)
424
+ if self.clip_qkv is not None:
425
+ query_states = query_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
426
+ key_states = key_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
427
+ value_states = value_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
428
+
429
+ # Flash attention requires the input to have the shape
430
+ # batch_size x seq_length x num_heads x head_dim,
431
+ # so the tensors are transposed back to that layout right before the flash-attn call
432
+ query_states = query_states.view(bsz, q_len, self.num_heads,
433
+ self.head_dim).transpose(1, 2)
434
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
435
+ self.head_dim).transpose(1, 2)
436
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
437
+ self.head_dim).transpose(1, 2)
438
+
439
+ cos, sin = self.rotary_emb(value_states, position_ids)
440
+ query_states, key_states = apply_rotary_pos_emb(query_states,
441
+ key_states, cos, sin)
442
+
443
+ past_key_value = getattr(self, 'past_key_value', past_key_value)
444
+
445
+ if past_key_value is not None:
446
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
447
+ cache_kwargs = {
448
+ 'sin': sin,
449
+ 'cos': cos,
450
+ 'cache_position': cache_position
451
+ }
452
+ key_states, value_states = past_key_value.update(
453
+ key_states, value_states, self.block_idx, cache_kwargs)
454
+
455
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
456
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
457
+ # to be able to avoid many of these transpose/reshape/view.
458
+ query_states = query_states.transpose(1, 2)
459
+ key_states = key_states.transpose(1, 2)
460
+ value_states = value_states.transpose(1, 2)
461
+
462
+ dropout_rate = self.attn_pdrop if self.training else 0.0
463
+
464
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
465
+ # therefore the input hidden states get silently cast to float32. Hence, we need
466
+ # to cast them back to the correct dtype just to be sure everything works as expected.
467
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
468
+ # in fp32. (LlamaRMSNorm handles it correctly)
469
+ input_dtype = query_states.dtype
470
+ if input_dtype == torch.float32:
471
+ if torch.is_autocast_enabled():
472
+ target_dtype = torch.get_autocast_gpu_dtype()
473
+ # Handle the case where the model is quantized
474
+ elif hasattr(self.config, '_pre_quantization_dtype'):
475
+ target_dtype = self.config._pre_quantization_dtype
476
+ else:
477
+ target_dtype = query_states.dtype
478
+
479
+ logger.warning_once(
480
+ f'The input hidden states seem to have been silently cast to float32; this might be '
481
+ +
482
+ f'related to the fact that you have upcast embedding or layer norm layers to '
483
+ + f'float32. We will cast the input back to {target_dtype}.')
484
+
485
+ query_states = query_states.to(target_dtype)
486
+ key_states = key_states.to(target_dtype)
487
+ value_states = value_states.to(target_dtype)
488
+
489
+ attn_output = self._flash_attention_forward(
490
+ query_states,
491
+ key_states,
492
+ value_states,
493
+ attention_mask,
494
+ q_len,
495
+ dropout=dropout_rate,
496
+ )
497
+
498
+ attn_output = attn_output.reshape(bsz, q_len,
499
+ self.hidden_size).contiguous()
500
+ attn_output = self.out_proj(attn_output)
501
+
502
+ if not output_attentions:
503
+ attn_weights = None
504
+
505
+ return attn_output, attn_weights, past_key_value # type: ignore
506
+
507
+ def _flash_attention_forward(
508
+ self,
509
+ query_states: torch.Tensor,
510
+ key_states: torch.Tensor,
511
+ value_states: torch.Tensor,
512
+ attention_mask: Union[torch.LongTensor, None],
513
+ query_length: int,
514
+ dropout: float = 0.0,
515
+ softmax_scale: Optional[float] = None,
516
+ ):
517
+ """Use FlashAttention, stripping padding tokens if necessary.
518
+
519
+ Args:
520
+ query_states (torch.Tensor): Input query states to be passed to Flash Attention API
521
+ key_states (torch.Tensor): Input key states to be passed to Flash Attention API
522
+ value_states (torch.Tensor): Input value states to be passed to Flash Attention API
523
+ attention_mask (torch.LongTensor | None): The padding mask - corresponds to a tensor of size
524
+ (batch_size, seq_len) where 0 stands for the position of padding tokens and 1
525
+ for the position of non-padding tokens.
526
+ query_length (int): The length of the query sequence
527
+ dropout (float): Attention dropout
528
+ softmax_scale (float, optional): The scaling of QK^T before applying softmax.
529
+ Defaults to 1 / sqrt(head_dim)
530
+ """
531
+ causal = True
532
+ # Contains at least one padding token in the sequence
533
+ if attention_mask is not None:
534
+ batch_size = query_states.shape[0]
535
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
536
+ query_states, key_states, value_states, attention_mask,
537
+ query_length)
538
+
539
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
540
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
541
+
542
+ attn_output_unpad = flash_attn_varlen_func(
543
+ query_states,
544
+ key_states,
545
+ value_states,
546
+ cu_seqlens_q=cu_seqlens_q,
547
+ cu_seqlens_k=cu_seqlens_k,
548
+ max_seqlen_q=max_seqlen_in_batch_q,
549
+ max_seqlen_k=max_seqlen_in_batch_k,
550
+ dropout_p=dropout,
551
+ softmax_scale=softmax_scale,
552
+ causal=causal,
553
+ )
554
+
555
+ attn_output = pad_input(
556
+ attn_output_unpad,
557
+ indices_q,
558
+ batch_size,
559
+ query_length,
560
+ )
561
+ else:
562
+ attn_output = flash_attn_func(
563
+ query_states,
564
+ key_states,
565
+ value_states,
566
+ dropout,
567
+ softmax_scale=softmax_scale,
568
+ causal=causal,
569
+ )
570
+
571
+ return attn_output
572
+
573
+ def _upad_input(self, query_layer: torch.Tensor, key_layer: torch.Tensor,
574
+ value_layer: torch.Tensor, attention_mask: torch.Tensor,
575
+ query_length: int):
576
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
577
+ attention_mask)
578
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
579
+
580
+ key_layer = index_first_axis(
581
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
582
+ head_dim), indices_k)
583
+ value_layer = index_first_axis(
584
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
585
+ head_dim), indices_k)
586
+ if query_length == kv_seq_len:
587
+ query_layer = index_first_axis(
588
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
589
+ head_dim), indices_k)
590
+ cu_seqlens_q = cu_seqlens_k
591
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
592
+ indices_q = indices_k
593
+ elif query_length == 1:
594
+ max_seqlen_in_batch_q = 1
595
+ cu_seqlens_q = torch.arange(
596
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
597
+ ) # There is a memcpy here, that is very bad.
598
+ indices_q = cu_seqlens_q[:-1]
599
+ query_layer = query_layer.squeeze(1)
600
+ else:
601
+ # The -q_len: slice assumes left padding.
602
+ attention_mask = attention_mask[:, -query_length:]
603
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
604
+ query_layer, attention_mask)
605
+
606
+ return (
607
+ query_layer,
608
+ key_layer,
609
+ value_layer,
610
+ indices_q,
611
+ (cu_seqlens_q, cu_seqlens_k),
612
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
613
+ )
614
+
615
+
616
+ DBRX_ATTENTION_CLASSES = {
617
+ 'eager': DbrxAttention,
618
+ 'flash_attention_2': DbrxFlashAttention2,
619
+ }
620
+
621
+
622
+ class DbrxNormAttentionNorm(nn.Module):
623
+
624
+ def __init__(
625
+ self,
626
+ hidden_size: int,
627
+ num_heads: int,
628
+ max_position_embeddings: int,
629
+ resid_pdrop: float,
630
+ attn_implementation: str,
631
+ attn_config: DbrxAttentionConfig,
632
+ block_idx: Optional[int] = None,
633
+ ):
634
+ super().__init__()
635
+ self.block_idx = block_idx
636
+ self.resid_pdrop = resid_pdrop
637
+ self.norm_1 = nn.LayerNorm(hidden_size, bias=False)
638
+ self.attn = DBRX_ATTENTION_CLASSES[attn_implementation](
639
+ hidden_size=hidden_size,
640
+ num_heads=num_heads,
641
+ max_position_embeddings=max_position_embeddings,
642
+ attn_config=attn_config,
643
+ block_idx=block_idx,
644
+ )
645
+ self.norm_2 = nn.LayerNorm(hidden_size, bias=False)
646
+
647
+ def forward(
648
+ self,
649
+ hidden_states: torch.Tensor,
650
+ position_ids: torch.LongTensor,
651
+ attention_mask: Optional[torch.Tensor] = None,
652
+ past_key_value: Optional[Cache] = None,
653
+ output_attentions: bool = False,
654
+ use_cache: bool = False,
655
+ cache_position: Optional[torch.LongTensor] = None,
656
+ **kwargs: Any,
657
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
658
+ Optional[Cache]]:
659
+
660
+ residual_states = hidden_states
661
+ hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype)
662
+
663
+ hidden_states, attn_weights, past_key_value = self.attn(
664
+ hidden_states=hidden_states,
665
+ attention_mask=attention_mask,
666
+ position_ids=position_ids,
667
+ past_key_value=past_key_value,
668
+ output_attentions=output_attentions,
669
+ use_cache=use_cache,
670
+ cache_position=cache_position,
671
+ **kwargs,
672
+ )
673
+
674
+ hidden_states = nn.functional.dropout(hidden_states,
675
+ p=self.resid_pdrop,
676
+ training=self.training)
677
+ hidden_states = hidden_states + residual_states
678
+
679
+ residual_states = hidden_states
680
+ hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype)
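+ # Both tensors are returned: `hidden_states` (post norm_2) feeds the FFN in DbrxBlock,
+ # while `residual_states` is carried along for the residual add around that FFN.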
681
+
682
+ return residual_states, hidden_states, attn_weights, past_key_value
683
+
684
+
685
+ class DbrxRouter(nn.Module):
686
+
687
+ def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int,
688
+ moe_jitter_eps: Optional[float],
689
+ moe_normalize_expert_weights: Optional[float],
690
+ uniform_expert_assignment: bool):
691
+ super().__init__()
692
+ self.hidden_size = hidden_size
693
+ self.moe_num_experts = moe_num_experts
694
+ self.moe_top_k = moe_top_k
695
+ self.moe_jitter_eps = moe_jitter_eps
696
+ self.moe_normalize_expert_weights = moe_normalize_expert_weights
697
+ self.uniform_expert_assignment = uniform_expert_assignment
698
+
699
+ self.layer = nn.Linear(self.hidden_size,
700
+ self.moe_num_experts,
701
+ bias=False)
702
+
703
+ def jitter(self, x: torch.Tensor) -> torch.Tensor:
704
+ if self.moe_jitter_eps is None:
705
+ raise RuntimeError('The router does not have moe_jitter_eps set.')
706
+ low = 1.0 - self.moe_jitter_eps
707
+ high = 1.0 + self.moe_jitter_eps
708
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
709
+ return low + noise * (high - low)
710
+
711
+ def forward(
712
+ self, x: torch.Tensor
713
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
714
+ if self.training and self.moe_jitter_eps is not None:
715
+ x = x * self.jitter(x)
716
+
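+ # Flatten tokens to [batch * seq_len, hidden], score every expert with the router's
+ # linear layer, and softmax in float32 for numerical stability before taking top-k.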
717
+ weights = self.layer(x.view(-1,
718
+ x.shape[-1])).softmax(dim=-1,
719
+ dtype=torch.float32)
720
+ top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
721
+
722
+ if self.moe_normalize_expert_weights:
723
+ top_weights = top_weights / torch.norm(
724
+ top_weights,
725
+ p=self.moe_normalize_expert_weights,
726
+ dim=-1,
727
+ keepdim=True)
728
+
729
+ if self.uniform_expert_assignment:
730
+ with torch.no_grad():
731
+ uniform_tensor = torch.arange(
732
+ 0,
733
+ top_experts.numel(),
734
+ device=top_experts.device,
735
+ dtype=top_experts.dtype) % self.moe_num_experts
736
+ top_experts = uniform_tensor.reshape(top_experts.shape)
737
+ # Note, weights and top_weights are not changed
738
+
739
+ weights = weights.to(x.dtype)
740
+ top_weights = top_weights.to(x.dtype)
741
+ return weights, top_weights, top_experts # type: ignore
742
+
743
+
744
+ class DbrxMLP(nn.Module):
745
+
746
+ def __init__(self, hidden_size: int, ffn_hidden_size: int, ffn_act_fn: dict):
747
+ super().__init__()
748
+
749
+ self.w1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
750
+ self.v1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
751
+ self.w2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
752
+ self.activation_fn = resolve_ffn_act_fn(ffn_act_fn)
753
+
754
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
755
+
756
+ return self.w2(self.activation_fn(self.w1(x)) * self.v1(x))
757
+
758
+
759
+ class DbrxExperts(nn.Module):
760
+
761
+ def __init__(self, hidden_size: int, ffn_hidden_size: int,
762
+ moe_num_experts: int, ffn_act_fn: dict):
763
+ super().__init__()
764
+ self.moe_num_experts = moe_num_experts
765
+ self.mlp = nn.ModuleList([DbrxMLP(hidden_size, ffn_hidden_size, ffn_act_fn) for _ in range(moe_num_experts)])
766
+
767
+ def forward(self, x: torch.Tensor, weights: torch.Tensor,
768
+ top_weights: torch.Tensor,
769
+ top_experts: torch.LongTensor) -> torch.Tensor:
770
+ bsz, q_len, hidden_size = x.shape
771
+ x = x.view(-1, hidden_size)
772
+ out = torch.zeros_like(x)
773
+
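+ # One-hot the selected experts and permute to [num_experts, top_k, num_tokens] so that,
+ # per expert, torch.where below recovers the (top-k slot, token index) pairs it owns;
+ # each expert processes only its tokens and the results are accumulated into `out`
+ # via index_add_.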
774
+ expert_mask = nn.functional.one_hot(
775
+ top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
776
+ for expert_idx in range(0, self.moe_num_experts):
777
+ topk_idx, token_idx = torch.where(expert_mask[expert_idx])
778
+ if token_idx.shape[0] == 0:
779
+ continue
780
+
781
+ expert_tokens = x[None, token_idx].reshape(-1, hidden_size)
782
+ expert_out = self.mlp[expert_idx](expert_tokens) * top_weights[token_idx, topk_idx, None]
783
+
784
+ out.index_add_(0, token_idx, expert_out)
785
+
786
+ out = out.reshape(bsz, q_len, hidden_size)
787
+ return out
788
+
789
+
790
+ class DbrxFFN(nn.Module):
791
+
792
+ def __init__(self, hidden_size: int, ffn_config: DbrxFFNConfig):
793
+ super().__init__()
794
+
795
+ self.router = DbrxRouter(
796
+ hidden_size,
797
+ moe_num_experts=ffn_config.moe_num_experts,
798
+ moe_top_k=ffn_config.moe_top_k,
799
+ moe_jitter_eps=ffn_config.moe_jitter_eps,
800
+ moe_normalize_expert_weights=ffn_config.
801
+ moe_normalize_expert_weights,
802
+ uniform_expert_assignment=ffn_config.uniform_expert_assignment,
803
+ )
804
+
805
+ self.experts = DbrxExperts(
806
+ hidden_size=hidden_size,
807
+ ffn_hidden_size=ffn_config.ffn_hidden_size,
808
+ moe_num_experts=ffn_config.moe_num_experts,
809
+ ffn_act_fn=ffn_config.ffn_act_fn,
810
+ )
811
+
812
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
813
+ weights, top_weights, top_experts = self.router(x)
814
+ out = self.experts(x, weights, top_weights, top_experts)
815
+ return out, weights
816
+
817
+
818
+ class DbrxBlock(nn.Module):
819
+
820
+ def __init__(self, config: DbrxConfig, block_idx: int):
821
+ super().__init__()
822
+ self.hidden_size = config.d_model
823
+ self.resid_pdrop = config.resid_pdrop
824
+ self.block_idx = block_idx
825
+ self.norm_attn_norm = DbrxNormAttentionNorm(
826
+ hidden_size=config.d_model,
827
+ num_heads=config.n_heads,
828
+ max_position_embeddings=config.max_seq_len,
829
+ resid_pdrop=config.resid_pdrop,
830
+ attn_implementation=config._attn_implementation,
831
+ attn_config=config.attn_config,
832
+ block_idx=block_idx,
833
+ )
834
+ self.ffn = DbrxFFN(hidden_size=config.d_model,
835
+ ffn_config=config.ffn_config)
836
+
837
+ def forward(
838
+ self,
839
+ hidden_states: torch.Tensor,
840
+ position_ids: torch.LongTensor,
841
+ attention_mask: Optional[torch.Tensor] = None,
842
+ past_key_value: Optional[Cache] = None,
843
+ output_attentions: Optional[bool] = False,
844
+ output_router_logits: Optional[bool] = False,
845
+ use_cache: Optional[bool] = False,
846
+ cache_position: Optional[torch.LongTensor] = None,
847
+ **kwargs: Any,
848
+ ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]],
849
+ Tuple[torch.Tensor, Optional[Cache]], Tuple[
850
+ torch.Tensor, Optional[torch.Tensor], Optional[Cache]],
851
+ Tuple[torch.Tensor, Optional[torch.Tensor],
852
+ Optional[torch.Tensor]], Tuple[
853
+ torch.Tensor, Optional[Cache], Optional[torch.Tensor]],
854
+ Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache],
855
+ Optional[torch.Tensor]],]:
856
+ """Forward function for DbrxBlock.
857
+
858
+ Args:
859
+ hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
860
+ position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)`
861
+ attention_mask (`torch.Tensor`, optional): attention mask of size (batch_size, sequence_length)
862
+ if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length)
863
+ if default attention is used.
864
+ past_key_value (`Tuple(torch.Tensor)`, optional): cached past key and value projection states
865
+ output_attentions (`bool`, optional): Whether or not to return the attentions tensors of all
866
+ attention layers. See `attentions` under returned tensors for more detail.
867
+ output_router_logits (`bool`, optional): Whether or not to return the router logits.
868
+ use_cache (`bool`, optional): If set to `True`, `past_key_values` key value states are
869
+ returned and can be used to speed up decoding (see `past_key_values`).
870
+ cache_position (`torch.LongTensor`, optional): position ids of the cache
871
+ """
872
+ if 'padding_mask' in kwargs:
873
+ warnings.warn(
874
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
875
+ )
876
+
877
+ # Norm + Attention + Norm
878
+ resid_states, hidden_states, self_attn_weights, present_key_value = self.norm_attn_norm(
879
+ hidden_states=hidden_states,
880
+ attention_mask=attention_mask,
881
+ position_ids=position_ids,
882
+ past_key_value=past_key_value,
883
+ output_attentions=output_attentions,
884
+ use_cache=use_cache,
885
+ cache_position=cache_position,
886
+ **kwargs,
887
+ )
888
+
889
+ # Fully Connected
890
+ hidden_states, router_logits = self.ffn(hidden_states)
891
+ hidden_states = nn.functional.dropout(hidden_states,
892
+ p=self.resid_pdrop,
893
+ training=self.training)
894
+ hidden_states = resid_states + hidden_states
895
+
896
+ outputs = (hidden_states,)
897
+
898
+ if output_attentions:
899
+ outputs += (self_attn_weights,)
900
+
901
+ if use_cache:
902
+ outputs += (present_key_value,)
903
+
904
+ if output_router_logits:
905
+ outputs += (router_logits,)
906
+
907
+ return outputs
908
+
909
+
910
+ class DbrxPreTrainedModel(PreTrainedModel):
911
+ config_class = DbrxConfig
912
+ base_model_prefix = 'transformer'
913
+ supports_gradient_checkpointing = True
914
+ _no_split_modules = ['DbrxBlock']
915
+ _skip_keys_device_placement = ['past_key_values']
916
+ _supports_flash_attn_2 = True
917
+ _supports_sdpa = False
918
+ _supports_cache_class = True
919
+
920
+ def _init_weights(self, module: nn.Module):
921
+ std = self.config.initializer_range
922
+ if isinstance(module, nn.Linear):
923
+ module.weight.data.normal_(mean=0.0, std=std)
924
+ if module.bias is not None:
925
+ module.bias.data.zero_()
926
+ elif isinstance(module, nn.Embedding):
927
+ module.weight.data.normal_(mean=0.0, std=std)
928
+ if module.padding_idx is not None:
929
+ module.weight.data[module.padding_idx].zero_()
930
+ elif isinstance(module, nn.LayerNorm):
931
+ module.weight.data.normal_(mean=0.0, std=std)
932
+ if module.bias is not None:
933
+ module.bias.data.zero_()
934
+
935
+ def _setup_cache(self, cache_cls: Any, max_batch_size: int,
936
+ max_cache_len: int): # TODO: how to set var type of class?
937
+ if self.config._attn_implementation == 'flash_attention_2' and cache_cls == StaticCache:
938
+ raise ValueError(
939
+ '`static` cache implementation is not compatible with ' +
940
+ '`attn_implementation==flash_attention_2`. Make sure to use ' +
941
+ '`sdpa` in the meantime and open an issue at https://github.com/huggingface/transformers.'
942
+ )
943
+
944
+ for block in self.transformer.blocks:
945
+ device = block.norm_attn_norm.norm_1.weight.device
946
+ if hasattr(self.config, '_pre_quantization_dtype'):
947
+ dtype = self.config._pre_quantization_dtype
948
+ else:
949
+ dtype = block.norm_attn_norm.attn.out_proj.weight.dtype
950
+ block.norm_attn_norm.attn.past_key_value = cache_cls(self.config,
951
+ max_batch_size,
952
+ max_cache_len,
953
+ device=device,
954
+ dtype=dtype)
955
+
956
+ def _reset_cache(self):
957
+ for block in self.transformer.blocks:
958
+ block.norm_attn_norm.attn.past_key_value = None
959
+
960
+
961
+ class DbrxModel(DbrxPreTrainedModel):
962
+ """Transformer decoder consisting of *config.num_hidden_layers*
963
+
964
+ [`DbrxBlock`] layers.
965
+
966
+ Args:
967
+ config: DbrxConfig
968
+ """
969
+
970
+ def __init__(self, config: DbrxConfig):
971
+ super().__init__(config)
972
+ self.padding_idx = config.pad_token_id
973
+ self.vocab_size = config.vocab_size
974
+ self.emb_pdrop = config.emb_pdrop
975
+
976
+ self.wte = nn.Embedding(config.vocab_size, config.d_model,
977
+ self.padding_idx)
978
+ self.blocks = nn.ModuleList([
979
+ DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)
980
+ ])
981
+ self.norm_f = nn.LayerNorm(config.d_model, bias=False)
982
+ self.gradient_checkpointing = False
983
+
984
+ # Initialize weights and apply final processing
985
+ self.post_init()
986
+
987
+ def get_input_embeddings(self) -> nn.Embedding:
988
+ return self.wte
989
+
990
+ def set_input_embeddings(self, value: nn.Embedding):
991
+ self.wte = value
992
+
993
+ def _autocast_input_embeddings(self,
994
+ inputs_embeds: torch.Tensor) -> torch.Tensor:
995
+ if inputs_embeds.device.type == 'cuda' and torch.is_autocast_enabled():
996
+ return inputs_embeds.to(dtype=torch.get_autocast_gpu_dtype())
997
+ elif inputs_embeds.device.type == 'cpu' and torch.is_autocast_cpu_enabled(
998
+ ):
999
+ return inputs_embeds.to(dtype=torch.get_autocast_cpu_dtype())
1000
+ else:
1001
+ return inputs_embeds
1002
+
1003
+ def forward(
1004
+ self,
1005
+ input_ids: Optional[torch.LongTensor] = None,
1006
+ attention_mask: Optional[torch.Tensor] = None,
1007
+ position_ids: Optional[torch.LongTensor] = None,
1008
+ past_key_values: Optional[Cache] = None,
1009
+ inputs_embeds: Optional[torch.Tensor] = None,
1010
+ use_cache: Optional[bool] = None,
1011
+ output_attentions: Optional[bool] = None,
1012
+ output_hidden_states: Optional[bool] = None,
1013
+ output_router_logits: Optional[bool] = None,
1014
+ return_dict: Optional[bool] = None,
1015
+ cache_position: Optional[torch.LongTensor] = None,
1016
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
1017
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1018
+ output_hidden_states = (output_hidden_states
1019
+ if output_hidden_states is not None else
1020
+ self.config.output_hidden_states)
1021
+ output_router_logits = (output_router_logits
1022
+ if output_router_logits is not None else
1023
+ self.config.output_router_logits)
1024
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1025
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1026
+
1027
+ if (input_ids is None) ^ (inputs_embeds is not None):
1028
+ raise ValueError(
1029
+ 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one'
1030
+ )
1031
+
1032
+ if self.gradient_checkpointing and self.training and use_cache:
1033
+ logger.warning_once(
1034
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
1035
+ )
1036
+ use_cache = False
1037
+
1038
+ if inputs_embeds is None:
1039
+ inputs_embeds = self.wte(input_ids)
1040
+
1041
+ inputs_embeds = self._autocast_input_embeddings(
1042
+ inputs_embeds) # type: ignore
1043
+ inputs_embeds = nn.functional.dropout(inputs_embeds,
1044
+ p=self.emb_pdrop,
1045
+ training=self.training)
1046
+
1047
+ past_seen_tokens = 0
1048
+ if use_cache: # kept for BC (cache positions)
1049
+ if not isinstance(past_key_values, StaticCache):
1050
+ past_key_values = DynamicCache.from_legacy_cache(
1051
+ past_key_values)
1052
+ past_seen_tokens = past_key_values.get_seq_length( # type: ignore
1053
+ )
1054
+
1055
+ if cache_position is None:
1056
+ if isinstance(past_key_values, StaticCache):
1057
+ raise ValueError(
1058
+ 'cache_position is a required argument when using StaticCache.'
1059
+ )
1060
+ cache_position = torch.arange( # type: ignore
1061
+ past_seen_tokens,
1062
+ past_seen_tokens + inputs_embeds.shape[1],
1063
+ device=inputs_embeds.device)
1064
+
1065
+ if position_ids is None:
1066
+ position_ids = cache_position.unsqueeze(0) # type: ignore
1067
+
1068
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
1069
+ cache_position) # type: ignore
1070
+
1071
+ # embed positions
1072
+ hidden_states = inputs_embeds
1073
+
1074
+ # decoder layers
1075
+ all_hidden_states = () if output_hidden_states else None
1076
+ all_self_attns = () if output_attentions else None
1077
+ all_router_logits = () if output_router_logits else None
1078
+ next_decoder_cache = None
1079
+
1080
+ for block in self.blocks:
1081
+ if output_hidden_states:
1082
+ all_hidden_states += (hidden_states,) # type: ignore
1083
+
1084
+ if self.gradient_checkpointing and self.training:
1085
+ block_outputs = self._gradient_checkpointing_func(
1086
+ block.__call__,
1087
+ hidden_states,
1088
+ attention_mask=causal_mask,
1089
+ position_ids=position_ids,
1090
+ past_key_values=past_key_values,
1091
+ output_attentions=output_attentions,
1092
+ output_router_logits=output_router_logits,
1093
+ use_cache=use_cache,
1094
+ cache_position=cache_position,
1095
+ )
1096
+ else:
1097
+ block_outputs = block(
1098
+ hidden_states,
1099
+ attention_mask=causal_mask,
1100
+ position_ids=position_ids,
1101
+ past_key_value=past_key_values,
1102
+ output_attentions=output_attentions,
1103
+ output_router_logits=output_router_logits,
1104
+ use_cache=use_cache,
1105
+ cache_position=cache_position,
1106
+ )
1107
+
1108
+ hidden_states = block_outputs[0]
1109
+
1110
+ if use_cache:
1111
+ next_decoder_cache = block_outputs[
1112
+ 2 if output_attentions else 1]
1113
+
1114
+ if output_attentions:
1115
+ all_self_attns += (block_outputs[1],) # type: ignore
1116
+
1117
+ if output_router_logits:
1118
+ all_router_logits += (block_outputs[-1],) # type: ignore
1119
+
1120
+ hidden_states = self.norm_f(hidden_states)
1121
+
1122
+ # add hidden states from the last decoder layer
1123
+ if output_hidden_states:
1124
+ all_hidden_states += (hidden_states,) # type: ignore
1125
+
1126
+ next_cache = None
1127
+ if use_cache:
1128
+ next_cache = (
1129
+ next_decoder_cache.to_legacy_cache() # type: ignore
1130
+ if isinstance(next_decoder_cache, Cache) else
1131
+ next_decoder_cache)
1132
+ if not return_dict:
1133
+ return tuple(v for v in [
1134
+ hidden_states, next_cache, all_hidden_states, all_self_attns,
1135
+ all_router_logits
1136
+ ] if v is not None)
1137
+ return MoeModelOutputWithPast(
1138
+ last_hidden_state=hidden_states,
1139
+ past_key_values=next_cache,
1140
+ hidden_states=all_hidden_states,
1141
+ attentions=all_self_attns,
1142
+ router_logits=all_router_logits,
1143
+ )
1144
+
1145
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1146
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1147
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1148
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1149
+ def _update_causal_mask(
1150
+ self, attention_mask: Optional[torch.Tensor],
1151
+ input_tensor: torch.Tensor,
1152
+ cache_position: torch.Tensor) -> Optional[torch.Tensor]:
1153
+ if self.config._attn_implementation == 'flash_attention_2':
1154
+ if attention_mask is not None and 0.0 in attention_mask:
1155
+ return attention_mask
1156
+ return None
1157
+
1158
+ dtype, device = input_tensor.dtype, input_tensor.device
1159
+ min_dtype = torch.finfo(dtype).min
1160
+ sequence_length = input_tensor.shape[1]
1161
+ if hasattr(self.blocks[0].norm_attn_norm.attn,
1162
+ 'past_key_value'): # static cache
1163
+ target_length = self.config.max_position_embeddings
1164
+ else: # dynamic cache
1165
+ target_length = (attention_mask.shape[-1] if isinstance(
1166
+ attention_mask, torch.Tensor) else cache_position[-1] + 1)
1167
+ target_length = int(target_length)
1168
+
1169
+ causal_mask = torch.full((sequence_length, target_length),
1170
+ fill_value=min_dtype,
1171
+ dtype=dtype,
1172
+ device=device)
1173
+ if sequence_length != 1:
1174
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1175
+ causal_mask *= torch.arange(
1176
+ target_length, device=device) > cache_position.reshape(-1, 1)
1177
+ causal_mask = causal_mask[None,
1178
+ None, :, :].expand(input_tensor.shape[0], 1,
1179
+ -1, -1)
1180
+ if attention_mask is not None:
1181
+ causal_mask = causal_mask.clone(
1182
+ ) # copy to contiguous memory for in-place edit
1183
+ if attention_mask.dim() == 2:
1184
+ mask_length = attention_mask.shape[-1]
1185
+ padding_mask = causal_mask[..., :mask_length].eq(
1186
+ 0.0) * attention_mask[:, None, None, :].eq(0.0)
1187
+ causal_mask[..., :mask_length] = causal_mask[
1188
+ ..., :mask_length].masked_fill(padding_mask, min_dtype)
1189
+ elif attention_mask.dim() == 4:
1190
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1191
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
1192
+ if attention_mask.shape[
1193
+ -2] < cache_position[0] + sequence_length:
1194
+ offset = cache_position[0]
1195
+ else:
1196
+ offset = 0
1197
+ mask_shape = attention_mask.shape
1198
+ mask_slice = (attention_mask.eq(0.0)).to(
1199
+ dtype=dtype) * min_dtype
1200
+ causal_mask[:mask_shape[0], :mask_shape[1],
1201
+ offset:mask_shape[2] +
1202
+ offset, :mask_shape[3]] = mask_slice
1203
+
1204
+ if (self.config._attn_implementation == 'sdpa' and
1205
+ attention_mask is not None and
1206
+ attention_mask.device.type == 'cuda'):
1207
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1208
+ is_tracing = (
1209
+ torch.jit.is_tracing() or
1210
+ isinstance(input_tensor, torch.fx.Proxy) or # type: ignore
1211
+ (hasattr(torch, '_dynamo') and torch._dynamo.is_compiling()))
1212
+ if not is_tracing and torch.any(attention_mask != 1):
1213
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1214
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1215
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1216
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1217
+ causal_mask, min_dtype)
1218
+
1219
+ return causal_mask
1220
+
1221
+
1222
+ class DbrxForCausalLM(DbrxPreTrainedModel):
1223
+
1224
+ def __init__(self, config: DbrxConfig):
1225
+ super().__init__(config)
1226
+ self.transformer = DbrxModel(config)
1227
+ self.vocab_size = config.vocab_size
1228
+ self.lm_head = nn.Linear(config.hidden_size,
1229
+ config.vocab_size,
1230
+ bias=False)
1231
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1232
+ self.num_experts = config.ffn_config.moe_num_experts
1233
+ self.num_experts_per_tok = config.ffn_config.moe_top_k
1234
+
1235
+ # Initialize weights and apply final processing
1236
+ self.post_init()
1237
+
1238
+ def get_input_embeddings(self) -> nn.Embedding:
1239
+ return self.transformer.get_input_embeddings()
1240
+
1241
+ def set_input_embeddings(self, value: nn.Embedding):
1242
+ self.transformer.set_input_embeddings(value)
1243
+
1244
+ def get_output_embeddings(self) -> nn.Linear:
1245
+ return self.lm_head
1246
+
1247
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
1248
+ self.lm_head = new_embeddings
1249
+
1250
+ def set_decoder(self, decoder: DbrxModel):
1251
+ self.transformer = decoder
1252
+
1253
+ def get_decoder(self) -> DbrxModel:
1254
+ return self.transformer
1255
+
1256
+ def forward(
1257
+ self,
1258
+ input_ids: Optional[torch.LongTensor] = None,
1259
+ attention_mask: Optional[torch.Tensor] = None,
1260
+ position_ids: Optional[torch.LongTensor] = None,
1261
+ past_key_values: Optional[Cache] = None,
1262
+ inputs_embeds: Optional[torch.Tensor] = None,
1263
+ labels: Optional[torch.LongTensor] = None,
1264
+ use_cache: Optional[bool] = None,
1265
+ output_attentions: Optional[bool] = None,
1266
+ output_hidden_states: Optional[bool] = None,
1267
+ output_router_logits: Optional[bool] = None,
1268
+ return_dict: Optional[bool] = None,
1269
+ cache_position: Optional[torch.LongTensor] = None,
1270
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1271
+ r"""Forward function for causal language modeling.
1272
+
1273
+ Example:
1274
+ ```python
1275
+ >>> from transformers import AutoTokenizer, DbrxForCausalLM
1276
+
1277
+ >>> model = DbrxForCausalLM.from_pretrained("databricks/dbrx")
1278
+ >>> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx")
1279
+
1280
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1281
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1282
+
1283
+ >>> # Generate
1284
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1285
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1286
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1287
+ ```
1288
+ """
1289
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1290
+ output_hidden_states = (output_hidden_states
1291
+ if output_hidden_states is not None else
1292
+ self.config.output_hidden_states)
1293
+ output_router_logits = (output_router_logits
1294
+ if output_router_logits is not None else
1295
+ self.config.output_router_logits)
1296
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1297
+
1298
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1299
+ outputs = self.transformer(
1300
+ input_ids=input_ids,
1301
+ attention_mask=attention_mask,
1302
+ position_ids=position_ids,
1303
+ past_key_values=past_key_values,
1304
+ inputs_embeds=inputs_embeds,
1305
+ use_cache=use_cache,
1306
+ output_attentions=output_attentions,
1307
+ output_hidden_states=output_hidden_states,
1308
+ output_router_logits=output_router_logits,
1309
+ return_dict=return_dict,
1310
+ cache_position=cache_position,
1311
+ )
1312
+
1313
+ hidden_states = outputs[0]
1314
+ logits = self.lm_head(hidden_states)
1315
+
1316
+ loss = None
1317
+ if labels is not None:
1318
+ # Shift so that tokens < n predict n
1319
+ shift_logits = logits[..., :-1, :].contiguous()
1320
+ shift_labels = labels[..., 1:].contiguous()
1321
+ # Flatten the tokens
1322
+ loss_fct = nn.CrossEntropyLoss()
1323
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1324
+ shift_labels = shift_labels.view(-1)
1325
+ # Enable model parallelism
1326
+ shift_labels = shift_labels.to(shift_logits.device)
1327
+ loss = loss_fct(shift_logits, shift_labels)
1328
+
1329
+ aux_loss = None
1330
+ if output_router_logits:
1331
+ aux_loss = load_balancing_loss_func(
1332
+ outputs.router_logits if return_dict else outputs[-1],
1333
+ self.num_experts,
1334
+ self.num_experts_per_tok,
1335
+ attention_mask,
1336
+ )
1337
+ if labels is not None and loss is not None:
1338
+ loss += self.router_aux_loss_coef * aux_loss.to(
1339
+ loss.device) # make sure it resides on the same device
1340
+
1341
+ if not return_dict:
1342
+ output = (logits,) + outputs[1:]
1343
+ return (loss,) + output if loss is not None else output
1344
+
1345
+ return MoeCausalLMOutputWithPast(
1346
+ loss=loss,
1347
+ aux_loss=aux_loss,
1348
+ logits=logits,
1349
+ past_key_values=outputs.past_key_values,
1350
+ hidden_states=outputs.hidden_states,
1351
+ attentions=outputs.attentions,
1352
+ router_logits=outputs.router_logits,
1353
+ )
1354
+
1355
+ def prepare_inputs_for_generation(
1356
+ self,
1357
+ input_ids: torch.Tensor,
1358
+ past_key_values: Optional[Cache] = None,
1359
+ attention_mask: Optional[torch.Tensor] = None,
1360
+ inputs_embeds: Optional[torch.Tensor] = None,
1361
+ **kwargs: Any) -> Dict[str, Any]:
1362
+ past_length = 0
1363
+ if past_key_values is not None:
1364
+ if isinstance(past_key_values, Cache):
1365
+ cache_length = past_key_values.get_seq_length()
1366
+ past_length = past_key_values.seen_tokens
1367
+ max_cache_length = past_key_values.get_max_length()
1368
+ else:
1369
+ cache_length = past_length = past_key_values[0][0].shape[2]
1370
+ max_cache_length = None
1371
+
1372
+ # Keep only the unprocessed tokens:
1373
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1374
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1375
+ # input)
1376
+ if attention_mask is not None and attention_mask.shape[
1377
+ 1] > input_ids.shape[1]:
1378
+ input_ids = input_ids[:,
1379
+ -(attention_mask.shape[1] - past_length):]
1380
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1381
+ # input_ids based on the past_length.
1382
+ elif past_length < input_ids.shape[1]:
1383
+ input_ids = input_ids[:, past_length:]
1384
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1385
+
1386
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1387
+ if (max_cache_length is not None and attention_mask is not None and
1388
+ cache_length + input_ids.shape[1] > max_cache_length):
1389
+ attention_mask = attention_mask[:, -max_cache_length:]
1390
+
1391
+ position_ids = kwargs.get('position_ids', None)
1392
+ if attention_mask is not None and position_ids is None:
1393
+ # create position_ids on the fly for batch generation
1394
+ position_ids = attention_mask.long().cumsum(-1) - 1
1395
+ position_ids.masked_fill_(attention_mask == 0, 1)
1396
+ if past_key_values:
1397
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1398
+
1399
+ if self.generation_config.cache_implementation == 'static':
1400
+ # generation with static cache
1401
+ cache_position = kwargs.get('cache_position', None)
1402
+ if cache_position is None:
1403
+ past_length = 0
1404
+ else:
1405
+ past_length = cache_position[-1] + 1
1406
+ input_ids = input_ids[:, past_length:]
1407
+ position_ids = position_ids[:,
1408
+ past_length:] if position_ids is not None else None
1409
+
1410
+ # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
1411
+ # same goes for position ids. Could also help with continued generation.
1412
+ input_length = position_ids.shape[
1413
+ -1] if position_ids is not None else input_ids.shape[-1]
1414
+ cache_position = torch.arange(past_length,
1415
+ past_length + input_length,
1416
+ device=input_ids.device)
1417
+ position_ids = position_ids.contiguous(
1418
+ ) if position_ids is not None else None
1419
+
1420
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1421
+ if inputs_embeds is not None and past_key_values is None:
1422
+ model_inputs = {'inputs_embeds': inputs_embeds}
1423
+ else:
1424
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1425
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1426
+ # TODO: use `next_tokens` directly instead.
1427
+ model_inputs = {'input_ids': input_ids.contiguous()}
1428
+
1429
+ model_inputs.update(
1430
+ { # type: ignore
1431
+ 'position_ids': position_ids,
1432
+ 'cache_position': cache_position,
1433
+ 'past_key_values': past_key_values,
1434
+ 'use_cache': kwargs.get('use_cache'),
1435
+ 'attention_mask': attention_mask,
1436
+ }
1437
+ )
1438
+ return model_inputs
1439
+
1440
+ @staticmethod
1441
+ def _reorder_cache(past_key_values: Cache, beam_idx: torch.LongTensor):
1442
+ reordered_past = ()
1443
+ for layer_past in past_key_values:
1444
+ reordered_past += (tuple(
1445
+ past_state.index_select(0, beam_idx.to(past_state.device))
1446
+ for past_state in layer_past),)
1447
+ return reordered_past