SinclairSchneider committed on
Commit
7de5909
1 Parent(s): 8fa65ea

Delete .ipynb_checkpoints/modeling_dbrx-checkpoint.py

.ipynb_checkpoints/modeling_dbrx-checkpoint.py DELETED
@@ -1,1455 +0,0 @@
1
- """PyTorch Dbrx model."""
2
-
3
- import math
4
- import warnings
5
- from copy import deepcopy
6
- from functools import partial
7
- from typing import Any, Callable, Dict, Optional, Tuple, Union
8
-
9
- import torch
10
- import torch.nn.functional as F
11
- import torch.utils.checkpoint
12
- from torch import nn
13
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
14
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
15
- from transformers.modeling_outputs import (MoeCausalLMOutputWithPast,
16
- MoeModelOutputWithPast)
17
- from transformers.modeling_utils import PreTrainedModel
18
- from transformers.utils import is_flash_attn_2_available, logging
19
-
20
- from .configuration_dbrx import DbrxAttentionConfig, DbrxConfig, DbrxFFNConfig
21
-
22
- if is_flash_attn_2_available():
23
- try:
24
- from flash_attn import flash_attn_func, flash_attn_varlen_func
25
- from flash_attn.bert_padding import pad_input # noqa
26
- from flash_attn.bert_padding import index_first_axis, unpad_input
27
- except ImportError:
28
- pass
29
-
30
- logger = logging.get_logger(__name__)
31
-
32
- _CONFIG_FOR_DOC = 'DbrxConfig'
33
-
34
- #############################################################################
35
- # Copied from LLaMaRotaryEmbedding
36
- #############################################################################
37
-
38
-
39
- class DbrxRotaryEmbedding(nn.Module):
40
-
41
- def __init__(self,
42
- dim: int,
43
- max_position_embeddings: int = 2048,
44
- base: float = 10000.0,
45
- scaling_factor: float = 1.0):
46
- super().__init__()
47
- self.scaling_factor = scaling_factor
48
- self.dim = dim
49
- self.max_position_embeddings = max_position_embeddings
50
- self.base = base
51
- inv_freq = 1.0 / (self.base**(
52
- torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
53
- self.register_buffer('inv_freq', inv_freq, persistent=False)
54
- # For BC we register cos and sin cached
55
- self.max_seq_len_cached = max_position_embeddings
56
-
57
- @torch.no_grad()
58
- def forward(
59
- self, x: torch.Tensor, position_ids: torch.LongTensor
60
- ) -> Tuple[torch.Tensor, torch.Tensor]:
61
- # x: [bs, num_attention_heads, seq_len, head_size]
62
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
63
- position_ids.shape[0], -1, 1)
64
- position_ids_expanded = position_ids[:, None, :].float()
65
- # Force float32 since bfloat16 loses precision on long contexts
66
- # See https://github.com/huggingface/transformers/pull/29285
67
- device_type = x.device.type
68
- device_type = device_type if isinstance(
69
- device_type, str) and device_type != 'mps' else 'cpu'
70
- with torch.autocast(device_type=device_type, enabled=False):
71
- freqs = (inv_freq_expanded.float()
72
- @ position_ids_expanded.float()).transpose(1, 2)
73
- emb = torch.cat((freqs, freqs), dim=-1)
74
- cos = emb.cos()
75
- sin = emb.sin()
76
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
77
-
78
-
79
- def rotate_half(x: torch.Tensor) -> torch.Tensor:
80
- """Rotates half the hidden dims of the input."""
81
- x1 = x[..., :x.shape[-1] // 2]
82
- x2 = x[..., x.shape[-1] // 2:]
83
- return torch.cat((-x2, x1), dim=-1)
84
-
85
-
86
- def apply_rotary_pos_emb(
87
- q: torch.Tensor,
88
- k: torch.Tensor,
89
- cos: torch.Tensor,
90
- sin: torch.Tensor,
91
- unsqueeze_dim: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
92
- """Applies Rotary Position Embedding to the query and key tensors.
93
-
94
- Args:
95
- q (`torch.Tensor`): The query tensor.
96
- k (`torch.Tensor`): The key tensor.
97
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
98
- sin (`torch.Tensor`): The sine part of the rotary embedding.
99
- unsqueeze_dim (`int`, *optional*, defaults to 1):
100
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and
101
- sin so that they can be properly broadcasted to the dimensions of q and k. For example, note
102
- that cos and sin have the shape [batch_size, seq_len, head_dim]. Then, if q and
103
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
104
- cos and sin broadcastable to the shapes of q and k. Similarly, if q and k have
105
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
106
-
107
- Returns:
108
- `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
109
- """
110
- cos = cos.unsqueeze(unsqueeze_dim)
111
- sin = sin.unsqueeze(unsqueeze_dim)
112
- q_embed = (q * cos) + (rotate_half(q) * sin)
113
- k_embed = (k * cos) + (rotate_half(k) * sin)
114
- return q_embed, k_embed
115
-
116
-
117
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
118
- """Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
119
-
120
- The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
121
- (batch, num_attention_heads, seqlen, head_dim)
122
- """
123
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
124
- if n_rep == 1:
125
- return hidden_states
126
- hidden_states = hidden_states[:, :,
127
- None, :, :].expand(batch, num_key_value_heads,
128
- n_rep, slen, head_dim)
129
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
130
- head_dim)
131
-
132
-
133
- #############################################################################
134
-
135
- #############################################################################
136
- # Modified from modeling_mixtral
137
- #############################################################################
138
-
139
-
140
- def load_balancing_loss_func(
141
- gate_logits: torch.Tensor,
142
- num_experts: int,
143
- top_k: int,
144
- attention_mask: Optional[torch.Tensor],
145
- ) -> torch.Tensor:
146
- r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
147
-
148
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
149
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
150
- experts is too unbalanced.
151
-
152
- Args:
153
- gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
154
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
155
- shape [batch_size X sequence_length, num_experts].
156
- num_experts (`int`):
157
- Number of experts.
158
- top_k (`int`):
159
- The number of experts each token is routed to.
160
- attention_mask (`torch.Tensor`, None):
161
- The attention_mask used in the forward function,
162
- of shape [batch_size X sequence_length] if not None.
163
-
164
- Returns:
165
- The auxiliary loss.
166
- """
167
- if gate_logits is None or not isinstance(gate_logits, tuple):
168
- return torch.tensor(0.0)
169
-
170
- if isinstance(gate_logits, tuple):
171
- compute_device = gate_logits[0].device
172
- concatenated_gate_logits = torch.cat(
173
- [layer_gate.to(compute_device) for layer_gate in gate_logits],
174
- dim=0)
175
-
176
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits,
177
- dim=-1)
178
-
179
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
180
-
181
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
182
-
183
- if attention_mask is None:
184
- # Compute the percentage of tokens routed to each expert
185
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
186
-
187
- # Compute the average probability of routing to these experts
188
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
189
- else:
190
- batch_size, sequence_length = attention_mask.shape
191
- num_hidden_layers = concatenated_gate_logits.shape[0] // (
192
- batch_size * sequence_length)
193
-
194
- # Compute the mask that masks all padding tokens as 0 with the same shape as expert_mask
195
- expert_attention_mask = (attention_mask[None, :, :, None, None].expand(
196
- (num_hidden_layers, batch_size, sequence_length, top_k,
197
- num_experts)).reshape(-1, top_k, num_experts).to(compute_device))
198
-
199
- # Compute the percentage of tokens routed to each expert
200
- tokens_per_expert = torch.sum(
201
- expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
202
- expert_attention_mask, dim=0)
203
-
204
- # Compute the mask that masks all padding tokens as 0 with the same shape as tokens_per_expert
205
- router_per_expert_attention_mask = (
206
- attention_mask[None, :, :, None].expand(
207
- (num_hidden_layers, batch_size, sequence_length,
208
- num_experts)).reshape(-1, num_experts).to(compute_device))
209
-
210
- # Compute the average probability of routing to these experts
211
- router_prob_per_expert = torch.sum(
212
- routing_weights * router_per_expert_attention_mask,
213
- dim=0) / torch.sum(router_per_expert_attention_mask, dim=0)
214
-
215
- overall_loss = torch.sum(tokens_per_expert *
216
- router_prob_per_expert.unsqueeze(0))
217
- return overall_loss * num_experts
218
-
219
-
220
- #############################################################################
221
-
222
-
223
- def resolve_ffn_act_fn(
224
- ffn_act_fn: dict) -> Callable[[torch.Tensor], torch.Tensor]:
225
- """Resolve the activation function for the feed-forward network.
226
-
227
- Args:
228
- ffn_act_fn (dict): The configuration dictionary for the activation function.
229
- The dict config must specify the 'name' of a torch.nn.functional activation
230
- function. All other key-value pairs are bound to the function as a partial.
231
-
232
- Returns:
233
- Callable[[torch.Tensor], torch.Tensor]: The activation function.
234
- """
235
- config = deepcopy(ffn_act_fn)
236
- name = config.pop('name')
237
- if not hasattr(nn.functional, name):
238
- raise ValueError(f'Unrecognised activation function name ({name}).')
239
- act = getattr(nn.functional, name)
240
- return partial(act, **config)
241
-
242
-
243
- #############################################################################
244
- # Copied from LLaMaAttention
245
- #############################################################################
246
-
247
-
248
- def _get_unpad_data(attention_mask: torch.Tensor):
249
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
250
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
251
- max_seqlen_in_batch = seqlens_in_batch.max().item()
252
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
253
- (1, 0))
254
- return (
255
- indices,
256
- cu_seqlens,
257
- max_seqlen_in_batch,
258
- )
259
-
260
-
261
- class DbrxAttention(nn.Module):
262
- """Multi-head self attention."""
263
-
264
- def __init__(self,
265
- hidden_size: int,
266
- num_heads: int,
267
- max_position_embeddings: int,
268
- attn_config: DbrxAttentionConfig,
269
- block_idx: Optional[int] = None):
270
- super().__init__()
271
- self.hidden_size = hidden_size
272
- self.num_heads = num_heads
273
- self.head_dim = self.hidden_size // self.num_heads
274
- self.max_position_embeddings = max_position_embeddings
275
- self.block_idx = block_idx
276
- self.config = attn_config
277
- if block_idx is None:
278
- logger.warning_once(
279
- f'Instantiating {self.__class__.__name__} without passing a `block_idx` is not recommended and will '
280
- +
281
- 'lead to errors during the forward call if caching is used. Please make sure to provide a `block_idx` '
282
- + 'when creating this class.')
283
-
284
- self.attn_pdrop = attn_config.attn_pdrop
285
- self.clip_qkv = attn_config.clip_qkv
286
- self.num_key_value_heads = attn_config.kv_n_heads
287
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
288
- self.rope_theta = attn_config.rope_theta
289
-
290
- self.Wqkv = nn.Linear(self.hidden_size,
291
- self.hidden_size +
292
- 2 * self.num_key_value_heads * self.head_dim,
293
- bias=False)
294
- self.out_proj = nn.Linear(self.hidden_size,
295
- self.hidden_size,
296
- bias=False)
297
- self.rotary_emb = DbrxRotaryEmbedding(
298
- self.head_dim,
299
- max_position_embeddings=self.max_position_embeddings,
300
- base=self.rope_theta,
301
- )
302
-
303
- def forward(
304
- self,
305
- hidden_states: torch.Tensor,
306
- position_ids: torch.LongTensor,
307
- attention_mask: Optional[torch.Tensor] = None,
308
- past_key_value: Optional[Cache] = None,
309
- output_attentions: bool = False,
310
- use_cache: bool = False,
311
- cache_position: Optional[torch.LongTensor] = None,
312
- **kwargs: Any,
313
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
314
- bsz, q_len, _ = hidden_states.size()
315
-
316
- qkv_states = self.Wqkv(hidden_states)
317
- if self.clip_qkv is not None:
318
- qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
319
-
320
- query_states, key_states, value_states = qkv_states.split(
321
- [
322
- self.hidden_size,
323
- self.num_key_value_heads * self.head_dim,
324
- self.num_key_value_heads * self.head_dim,
325
- ],
326
- dim=2,
327
- )
328
-
329
- query_states = query_states.view(bsz, q_len, self.num_heads,
330
- self.head_dim).transpose(1, 2)
331
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
332
- self.head_dim).transpose(1, 2)
333
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
334
- self.head_dim).transpose(1, 2)
335
-
336
- past_key_value = getattr(self, 'past_key_value', past_key_value)
337
- cos, sin = self.rotary_emb(value_states, position_ids)
338
- query_states, key_states = apply_rotary_pos_emb(query_states,
339
- key_states, cos, sin)
340
-
341
- if past_key_value is not None:
342
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
343
- cache_kwargs = {
344
- 'sin': sin,
345
- 'cos': cos,
346
- 'cache_position': cache_position
347
- }
348
- key_states, value_states = past_key_value.update(
349
- key_states, value_states, self.block_idx, cache_kwargs)
350
-
351
- key_states = repeat_kv(key_states, self.num_key_value_groups)
352
- value_states = repeat_kv(value_states, self.num_key_value_groups)
353
-
354
- attn_weights = torch.matmul(query_states, key_states.transpose(
355
- 2, 3)) / math.sqrt(self.head_dim)
356
-
357
- if attention_mask is not None: # no matter the length, we just slice it
358
- causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
359
- attn_weights = attn_weights + causal_mask
360
-
361
- # upcast attention to fp32
362
- attn_weights = nn.functional.softmax(attn_weights,
363
- dim=-1,
364
- dtype=torch.float32).to(
365
- query_states.dtype)
366
- attn_weights = nn.functional.dropout(attn_weights,
367
- p=self.attn_pdrop,
368
- training=self.training)
369
- attn_output = torch.matmul(attn_weights, value_states)
370
-
371
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
372
- raise ValueError(
373
- f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
374
- + f' {attn_output.size()}')
375
-
376
- attn_output = attn_output.transpose(1, 2).contiguous()
377
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
378
- attn_output = self.out_proj(attn_output)
379
-
380
- if not output_attentions:
381
- attn_weights = None
382
-
383
- return attn_output, attn_weights, past_key_value
384
-
385
-
386
- class DbrxFlashAttention2(DbrxAttention):
387
- """Dbrx flash attention module.
388
-
389
- This module inherits from `DbrxAttention`, as the weights of the module stay
390
- untouched. The only required change would be on the forward pass where it
391
- calls the public API of flash attention.
392
- """
393
-
394
- def __init__(self, *args: Any, **kwargs: Any):
395
- if not is_flash_attn_2_available():
396
- raise ImportError(
397
- 'Flash Attention 2 is not available. Please install it with `pip install flash-attn`.'
398
- )
399
-
400
- super().__init__(*args, **kwargs)
401
-
402
- def forward(
403
- self,
404
- hidden_states: torch.Tensor,
405
- attention_mask: Optional[torch.LongTensor] = None,
406
- position_ids: Optional[torch.LongTensor] = None,
407
- past_key_value: Optional[Cache] = None,
408
- output_attentions: bool = False,
409
- use_cache: bool = False,
410
- cache_position: Optional[torch.LongTensor] = None,
411
- **kwargs: Any,
412
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
413
- Optional[Tuple[torch.Tensor]]]:
414
- logger.info(
415
- 'Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.'
416
- )
417
- output_attentions = False
418
-
419
- bsz, q_len, _ = hidden_states.size()
420
-
421
- qkv_states = self.Wqkv(hidden_states)
422
- if self.clip_qkv is not None:
423
- qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
424
-
425
- query_states, key_states, value_states = qkv_states.split(
426
- [
427
- self.hidden_size,
428
- self.num_key_value_heads * self.head_dim,
429
- self.num_key_value_heads * self.head_dim,
430
- ],
431
- dim=2,
432
- )
433
-
434
- # Flash attention requires the input to have the shape
435
- # batch_size x seq_length x num_heads x head_dim
436
- # therefore we just need to keep the original shape
437
- query_states = query_states.view(bsz, q_len, self.num_heads,
438
- self.head_dim).transpose(1, 2)
439
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
440
- self.head_dim).transpose(1, 2)
441
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
442
- self.head_dim).transpose(1, 2)
443
-
444
- cos, sin = self.rotary_emb(value_states, position_ids)
445
- query_states, key_states = apply_rotary_pos_emb(query_states,
446
- key_states, cos, sin)
447
-
448
- past_key_value = getattr(self, 'past_key_value', past_key_value)
449
-
450
- if past_key_value is not None:
451
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
452
- cache_kwargs = {
453
- 'sin': sin,
454
- 'cos': cos,
455
- 'cache_position': cache_position
456
- }
457
- key_states, value_states = past_key_value.update(
458
- key_states, value_states, self.block_idx, cache_kwargs)
459
-
460
- # TODO: These transposes are quite inefficient, but Flash Attention requires the layout
461
- # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
462
- # to be able to avoid many of these transpose/reshape/view.
463
- query_states = query_states.transpose(1, 2)
464
- key_states = key_states.transpose(1, 2)
465
- value_states = value_states.transpose(1, 2)
466
-
467
- dropout_rate = self.attn_pdrop if self.training else 0.0
468
-
469
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
470
- # therefore the input hidden states get silently cast to float32. Hence, we need to
471
- # cast them back to the correct dtype just to be sure everything works as expected.
472
- # This might slow down training & inference, so it is recommended not to cast the LayerNorms
473
- # in fp32. (LlamaRMSNorm handles it correctly)
474
- input_dtype = query_states.dtype
475
- if input_dtype == torch.float32:
476
- if torch.is_autocast_enabled():
477
- target_dtype = torch.get_autocast_gpu_dtype()
478
- # Handle the case where the model is quantized
479
- elif hasattr(self.config, '_pre_quantization_dtype'):
480
- target_dtype = self.config._pre_quantization_dtype
481
- else:
482
- target_dtype = query_states.dtype
483
-
484
- logger.warning_once(
485
- f'The input hidden states seem to be silently cast to float32; this might be '
486
- +
487
- f'related to the fact you have upcast embedding or layer norm layers in '
488
- + f'float32. We will cast the input back to {target_dtype}.')
489
-
490
- query_states = query_states.to(target_dtype)
491
- key_states = key_states.to(target_dtype)
492
- value_states = value_states.to(target_dtype)
493
-
494
- attn_output = self._flash_attention_forward(
495
- query_states,
496
- key_states,
497
- value_states,
498
- attention_mask,
499
- q_len,
500
- dropout=dropout_rate,
501
- )
502
-
503
- attn_output = attn_output.reshape(bsz, q_len,
504
- self.hidden_size).contiguous()
505
- attn_output = self.out_proj(attn_output)
506
-
507
- if not output_attentions:
508
- attn_weights = None
509
-
510
- return attn_output, attn_weights, past_key_value # type: ignore
511
-
512
- def _flash_attention_forward(
513
- self,
514
- query_states: torch.Tensor,
515
- key_states: torch.Tensor,
516
- value_states: torch.Tensor,
517
- attention_mask: Union[torch.LongTensor, None],
518
- query_length: int,
519
- dropout: float = 0.0,
520
- softmax_scale: Optional[float] = None,
521
- ):
522
- """Use FlashAttention, stripping padding tokens if necessary.
523
-
524
- Args:
525
- query_states (torch.Tensor): Input query states to be passed to Flash Attention API
526
- key_states (torch.Tensor): Input key states to be passed to Flash Attention API
527
- value_states (torch.Tensor): Input value states to be passed to Flash Attention API
528
- attention_mask (torch.LongTensor | None): The padding mask - corresponds to a tensor of size
529
- (batch_size, seq_len) where 0 stands for the position of padding tokens and 1
530
- for the position of non-padding tokens.
531
- query_length (int): The length of the query sequence
532
- dropout (float): Attention dropout
533
- softmax_scale (float, optional): The scaling of QK^T before applying softmax.
534
- Defaults to 1 / sqrt(head_dim)
535
- """
536
- causal = True
537
- # Contains at least one padding token in the sequence
538
- if attention_mask is not None:
539
- batch_size = query_states.shape[0]
540
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
541
- query_states, key_states, value_states, attention_mask,
542
- query_length)
543
-
544
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
545
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
546
-
547
- attn_output_unpad = flash_attn_varlen_func(
548
- query_states,
549
- key_states,
550
- value_states,
551
- cu_seqlens_q=cu_seqlens_q,
552
- cu_seqlens_k=cu_seqlens_k,
553
- max_seqlen_q=max_seqlen_in_batch_q,
554
- max_seqlen_k=max_seqlen_in_batch_k,
555
- dropout_p=dropout,
556
- softmax_scale=softmax_scale,
557
- causal=causal,
558
- )
559
-
560
- attn_output = pad_input(
561
- attn_output_unpad,
562
- indices_q,
563
- batch_size,
564
- query_length,
565
- )
566
- else:
567
- attn_output = flash_attn_func(
568
- query_states,
569
- key_states,
570
- value_states,
571
- dropout,
572
- softmax_scale=softmax_scale,
573
- causal=causal,
574
- )
575
-
576
- return attn_output
577
-
578
- def _upad_input(self, query_layer: torch.Tensor, key_layer: torch.Tensor,
579
- value_layer: torch.Tensor, attention_mask: torch.Tensor,
580
- query_length: int):
581
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
582
- attention_mask)
583
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
584
-
585
- key_layer = index_first_axis(
586
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
587
- head_dim), indices_k)
588
- value_layer = index_first_axis(
589
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
590
- head_dim), indices_k)
591
- if query_length == kv_seq_len:
592
- query_layer = index_first_axis(
593
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
594
- head_dim), indices_k)
595
- cu_seqlens_q = cu_seqlens_k
596
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
597
- indices_q = indices_k
598
- elif query_length == 1:
599
- max_seqlen_in_batch_q = 1
600
- cu_seqlens_q = torch.arange(
601
- batch_size + 1, dtype=torch.int32, device=query_layer.device
602
- ) # There is a memcpy here, that is very bad.
603
- indices_q = cu_seqlens_q[:-1]
604
- query_layer = query_layer.squeeze(1)
605
- else:
606
- # The -q_len: slice assumes left padding.
607
- attention_mask = attention_mask[:, -query_length:]
608
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
609
- query_layer, attention_mask)
610
-
611
- return (
612
- query_layer,
613
- key_layer,
614
- value_layer,
615
- indices_q,
616
- (cu_seqlens_q, cu_seqlens_k),
617
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
618
- )
619
-
620
-
621
- DBRX_ATTENTION_CLASSES = {
622
- 'eager': DbrxAttention,
623
- 'flash_attention_2': DbrxFlashAttention2,
624
- }
625
-
626
-
627
- class DbrxNormAttentionNorm(nn.Module):
628
-
629
- def __init__(
630
- self,
631
- hidden_size: int,
632
- num_heads: int,
633
- max_position_embeddings: int,
634
- resid_pdrop: float,
635
- attn_implementation: str,
636
- attn_config: DbrxAttentionConfig,
637
- block_idx: Optional[int] = None,
638
- ):
639
- super().__init__()
640
- self.block_idx = block_idx
641
- self.resid_pdrop = resid_pdrop
642
- self.norm_1 = nn.LayerNorm(hidden_size, bias=False)
643
- self.attn = DBRX_ATTENTION_CLASSES[attn_implementation](
644
- hidden_size=hidden_size,
645
- num_heads=num_heads,
646
- max_position_embeddings=max_position_embeddings,
647
- attn_config=attn_config,
648
- block_idx=block_idx,
649
- )
650
- self.norm_2 = nn.LayerNorm(hidden_size, bias=False)
651
-
652
- def forward(
653
- self,
654
- hidden_states: torch.Tensor,
655
- position_ids: torch.LongTensor,
656
- attention_mask: Optional[torch.Tensor] = None,
657
- past_key_value: Optional[Cache] = None,
658
- output_attentions: bool = False,
659
- use_cache: bool = False,
660
- cache_position: Optional[torch.LongTensor] = None,
661
- **kwargs: Any,
662
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
663
- Optional[Cache]]:
664
-
665
- residual_states = hidden_states
666
- hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype)
667
-
668
- hidden_states, attn_weights, past_key_value = self.attn(
669
- hidden_states=hidden_states,
670
- attention_mask=attention_mask,
671
- position_ids=position_ids,
672
- past_key_value=past_key_value,
673
- output_attentions=output_attentions,
674
- use_cache=use_cache,
675
- cache_position=cache_position,
676
- **kwargs,
677
- )
678
-
679
- hidden_states = nn.functional.dropout(hidden_states,
680
- p=self.resid_pdrop,
681
- training=self.training)
682
- hidden_states = hidden_states + residual_states
683
-
684
- residual_states = hidden_states
685
- hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype)
686
-
687
- return residual_states, hidden_states, attn_weights, past_key_value
688
-
689
-
690
- class DbrxRouter(nn.Module):
691
-
692
- def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int,
693
- moe_jitter_eps: Optional[float],
694
- moe_normalize_expert_weights: Optional[float],
695
- uniform_expert_assignment: bool):
696
- super().__init__()
697
- self.hidden_size = hidden_size
698
- self.moe_num_experts = moe_num_experts
699
- self.moe_top_k = moe_top_k
700
- self.moe_jitter_eps = moe_jitter_eps
701
- self.moe_normalize_expert_weights = moe_normalize_expert_weights
702
- self.uniform_expert_assignment = uniform_expert_assignment
703
-
704
- self.layer = nn.Linear(self.hidden_size,
705
- self.moe_num_experts,
706
- bias=False)
707
-
708
- def jitter(self, x: torch.Tensor) -> torch.Tensor:
709
- if self.moe_jitter_eps is None:
710
- raise RuntimeError('The router does not have moe_jitter_eps set.')
711
- low = 1.0 - self.moe_jitter_eps
712
- high = 1.0 + self.moe_jitter_eps
713
- noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
714
- return low + noise * (high - low)
715
-
716
- def forward(
717
- self, x: torch.Tensor
718
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
719
- if self.training and self.moe_jitter_eps is not None:
720
- x = x * self.jitter(x)
721
-
722
- weights = self.layer(x.view(-1,
723
- x.shape[-1])).softmax(dim=-1,
724
- dtype=torch.float32)
725
- top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
726
-
727
- if self.moe_normalize_expert_weights:
728
- top_weights = top_weights / torch.norm(
729
- top_weights,
730
- p=self.moe_normalize_expert_weights,
731
- dim=-1,
732
- keepdim=True)
733
-
734
- if self.uniform_expert_assignment:
735
- with torch.no_grad():
736
- uniform_tensor = torch.arange(
737
- 0,
738
- top_experts.numel(),
739
- device=top_experts.device,
740
- dtype=top_experts.dtype) % self.moe_num_experts
741
- top_experts = uniform_tensor.reshape(top_experts.shape)
742
- # Note, weights and top_weights are not changed
743
-
744
- weights = weights.to(x.dtype)
745
- top_weights = top_weights.to(x.dtype)
746
- return weights, top_weights, top_experts # type: ignore
747
-
748
-
749
- class DbrxMLP(nn.Module):
750
-
751
- def __init__(self, hidden_size: int, ffn_hidden_size: int, ffn_act_fn: dict):
752
- super().__init__()
753
-
754
- self.w1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
755
- self.v1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
756
- self.w2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
757
- self.activation_fn = resolve_ffn_act_fn(ffn_act_fn)
758
-
759
- def forward(self, x: torch.Tensor) -> torch.Tensor:
760
-
761
- return self.w2(self.activation_fn(self.w1(x)) * self.v1(x))
762
-
763
-
764
- class DbrxExperts(nn.Module):
765
-
766
- def __init__(self, hidden_size: int, ffn_hidden_size: int,
767
- moe_num_experts: int, ffn_act_fn: dict):
768
- super().__init__()
769
- self.moe_num_experts = moe_num_experts
770
- self.mlp = nn.ModuleList([DbrxMLP(hidden_size, ffn_hidden_size, ffn_act_fn) for _ in range(moe_num_experts)])
771
-
772
- def forward(self, x: torch.Tensor, weights: torch.Tensor,
773
- top_weights: torch.Tensor,
774
- top_experts: torch.LongTensor) -> torch.Tensor:
775
- bsz, q_len, hidden_size = x.shape
776
- x = x.view(-1, hidden_size)
777
- out = torch.zeros_like(x)
778
-
779
- expert_mask = nn.functional.one_hot(
780
- top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
781
- for expert_idx in range(0, self.moe_num_experts):
782
- topk_idx, token_idx = torch.where(expert_mask[expert_idx])
783
- if token_idx.shape[0] == 0:
784
- continue
785
-
786
- token_list = token_idx.tolist()
787
- topk_list = topk_idx.tolist()
788
-
789
- expert_tokens = x[None, token_list].reshape(-1, hidden_size)
790
- expert_out = self.mlp[expert_idx](expert_tokens) * top_weights[token_list, topk_list, None]
791
-
792
- out.index_add_(0, token_idx, expert_out)
793
-
794
- out = out.reshape(bsz, q_len, hidden_size)
795
- return out
796
-
797
-
798
- class DbrxFFN(nn.Module):
799
-
800
- def __init__(self, hidden_size: int, ffn_config: DbrxFFNConfig):
801
- super().__init__()
802
-
803
- self.router = DbrxRouter(
804
- hidden_size,
805
- moe_num_experts=ffn_config.moe_num_experts,
806
- moe_top_k=ffn_config.moe_top_k,
807
- moe_jitter_eps=ffn_config.moe_jitter_eps,
808
- moe_normalize_expert_weights=ffn_config.
809
- moe_normalize_expert_weights,
810
- uniform_expert_assignment=ffn_config.uniform_expert_assignment,
811
- )
812
-
813
- self.experts = DbrxExperts(
814
- hidden_size=hidden_size,
815
- ffn_hidden_size=ffn_config.ffn_hidden_size,
816
- moe_num_experts=ffn_config.moe_num_experts,
817
- ffn_act_fn=ffn_config.ffn_act_fn,
818
- )
819
-
820
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
821
- weights, top_weights, top_experts = self.router(x)
822
- out = self.experts(x, weights, top_weights, top_experts)
823
- return out, weights
824
-
825
-
826
- class DbrxBlock(nn.Module):
827
-
828
- def __init__(self, config: DbrxConfig, block_idx: int):
829
- super().__init__()
830
- self.hidden_size = config.d_model
831
- self.resid_pdrop = config.resid_pdrop
832
- self.block_idx = block_idx
833
- self.norm_attn_norm = DbrxNormAttentionNorm(
834
- hidden_size=config.d_model,
835
- num_heads=config.n_heads,
836
- max_position_embeddings=config.max_seq_len,
837
- resid_pdrop=config.resid_pdrop,
838
- attn_implementation=config._attn_implementation,
839
- attn_config=config.attn_config,
840
- block_idx=block_idx,
841
- )
842
- self.ffn = DbrxFFN(hidden_size=config.d_model,
843
- ffn_config=config.ffn_config)
844
-
845
- def forward(
846
- self,
847
- hidden_states: torch.Tensor,
848
- position_ids: torch.LongTensor,
849
- attention_mask: Optional[torch.Tensor] = None,
850
- past_key_value: Optional[Cache] = None,
851
- output_attentions: Optional[bool] = False,
852
- output_router_logits: Optional[bool] = False,
853
- use_cache: Optional[bool] = False,
854
- cache_position: Optional[torch.LongTensor] = None,
855
- **kwargs: Any,
856
- ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]],
857
- Tuple[torch.Tensor, Optional[Cache]], Tuple[
858
- torch.Tensor, Optional[torch.Tensor], Optional[Cache]],
859
- Tuple[torch.Tensor, Optional[torch.Tensor],
860
- Optional[torch.Tensor]], Tuple[
861
- torch.Tensor, Optional[Cache], Optional[torch.Tensor]],
862
- Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache],
863
- Optional[torch.Tensor]],]:
864
- """Forward function for DbrxBlock.
865
-
866
- Args:
867
- hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
868
- position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)`
869
- attention_mask (`torch.Tensor`, optional): attention mask of size (batch_size, sequence_length)
870
- if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length)
871
- if default attention is used.
872
- past_key_value (`Tuple(torch.Tensor)`, optional): cached past key and value projection states
873
- output_attentions (`bool`, optional): Whether or not to return the attentions tensors of all
874
- attention layers. See `attentions` under returned tensors for more detail.
875
- output_router_logits (`bool`, optional): Whether or not to return the router logits.
876
- use_cache (`bool`, optional): If set to `True`, `past_key_values` key value states are
877
- returned and can be used to speed up decoding (see `past_key_values`).
878
- cache_position (`torch.LongTensor`, optional): position ids of the cache
879
- """
880
- if 'padding_mask' in kwargs:
881
- warnings.warn(
882
- 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
883
- )
884
-
885
- # Norm + Attention + Norm
886
- resid_states, hidden_states, self_attn_weights, present_key_value = self.norm_attn_norm(
887
- hidden_states=hidden_states,
888
- attention_mask=attention_mask,
889
- position_ids=position_ids,
890
- past_key_value=past_key_value,
891
- output_attentions=output_attentions,
892
- use_cache=use_cache,
893
- cache_position=cache_position,
894
- **kwargs,
895
- )
896
-
897
- # Fully Connected
898
- hidden_states, router_logits = self.ffn(hidden_states)
899
- hidden_states = nn.functional.dropout(hidden_states,
900
- p=self.resid_pdrop,
901
- training=self.training)
902
- hidden_states = resid_states + hidden_states
903
-
904
- outputs = (hidden_states,)
905
-
906
- if output_attentions:
907
- outputs += (self_attn_weights,)
908
-
909
- if use_cache:
910
- outputs += (present_key_value,)
911
-
912
- if output_router_logits:
913
- outputs += (router_logits,)
914
-
915
- return outputs
916
-
917
-
918
- class DbrxPreTrainedModel(PreTrainedModel):
919
- config_class = DbrxConfig
920
- base_model_prefix = 'transformer'
921
- supports_gradient_checkpointing = True
922
- _no_split_modules = ['DbrxBlock']
923
- _skip_keys_device_placement = ['past_key_values']
924
- _supports_flash_attn_2 = True
925
- _supports_sdpa = False
926
- _supports_cache_class = True
927
-
928
- def _init_weights(self, module: nn.Module):
929
- std = self.config.initializer_range
930
- if isinstance(module, nn.Linear):
931
- module.weight.data.normal_(mean=0.0, std=std)
932
- if module.bias is not None:
933
- module.bias.data.zero_()
934
- elif isinstance(module, nn.Embedding):
935
- module.weight.data.normal_(mean=0.0, std=std)
936
- if module.padding_idx is not None:
937
- module.weight.data[module.padding_idx].zero_()
938
- elif isinstance(module, nn.LayerNorm):
939
- module.weight.data.normal_(mean=0.0, std=std)
940
- if module.bias is not None:
941
- module.bias.data.zero_()
942
-
943
- def _setup_cache(self, cache_cls: Any, max_batch_size: int,
944
- max_cache_len: int): # TODO: how to set var type of class?
945
- if self.config._attn_implementation == 'flash_attention_2' and cache_cls == StaticCache:
946
- raise ValueError(
947
- '`static` cache implementation is not compatible with ' +
948
- '`attn_implementation==flash_attention_2`. Make sure to use ' +
949
- '`sdpa` in the meantime and open an issue at https://github.com/huggingface/transformers.'
950
- )
951
-
952
- for block in self.transformer.blocks:
953
- device = block.norm_attn_norm.norm_1.weight.device
954
- if hasattr(self.config, '_pre_quantization_dtype'):
955
- dtype = self.config._pre_quantization_dtype
956
- else:
957
- dtype = block.norm_attn_norm.attn.out_proj.weight.dtype
958
- block.norm_attn_norm.attn.past_key_value = cache_cls(self.config,
959
- max_batch_size,
960
- max_cache_len,
961
- device=device,
962
- dtype=dtype)
963
-
964
- def _reset_cache(self):
965
- for block in self.transformer.blocks:
966
- block.norm_attn_norm.attn.past_key_value = None
967
-
968
-
969
- class DbrxModel(DbrxPreTrainedModel):
970
- """Transformer decoder consisting of *config.num_hidden_layers*
971
-
972
- [`DbrxBlock`] layers.
973
-
974
- Args:
975
- config: DbrxConfig
976
- """
977
-
978
- def __init__(self, config: DbrxConfig):
979
- super().__init__(config)
980
- self.padding_idx = config.pad_token_id
981
- self.vocab_size = config.vocab_size
982
- self.emb_pdrop = config.emb_pdrop
983
-
984
- self.wte = nn.Embedding(config.vocab_size, config.d_model,
985
- self.padding_idx)
986
- self.blocks = nn.ModuleList([
987
- DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)
988
- ])
989
- self.norm_f = nn.LayerNorm(config.d_model, bias=False)
990
- self.gradient_checkpointing = False
991
-
992
- # Initialize weights and apply final processing
993
- self.post_init()
994
-
995
- def get_input_embeddings(self) -> nn.Embedding:
996
- return self.wte
997
-
998
- def set_input_embeddings(self, value: nn.Embedding):
999
- self.wte = value
1000
-
1001
- def _autocast_input_embeddings(self,
1002
- inputs_embeds: torch.Tensor) -> torch.Tensor:
1003
- if inputs_embeds.device.type == 'cuda' and torch.is_autocast_enabled():
1004
- return inputs_embeds.to(dtype=torch.get_autocast_gpu_dtype())
1005
- elif inputs_embeds.device.type == 'cpu' and torch.is_autocast_cpu_enabled(
1006
- ):
1007
- return inputs_embeds.to(dtype=torch.get_autocast_cpu_dtype())
1008
- else:
1009
- return inputs_embeds
1010
-
1011
- def forward(
1012
- self,
1013
- input_ids: Optional[torch.LongTensor] = None,
1014
- attention_mask: Optional[torch.Tensor] = None,
1015
- position_ids: Optional[torch.LongTensor] = None,
1016
- past_key_values: Optional[Cache] = None,
1017
- inputs_embeds: Optional[torch.Tensor] = None,
1018
- use_cache: Optional[bool] = None,
1019
- output_attentions: Optional[bool] = None,
1020
- output_hidden_states: Optional[bool] = None,
1021
- output_router_logits: Optional[bool] = None,
1022
- return_dict: Optional[bool] = None,
1023
- cache_position: Optional[torch.LongTensor] = None,
1024
- ) -> Union[Tuple, MoeModelOutputWithPast]:
1025
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1026
- output_hidden_states = (output_hidden_states
1027
- if output_hidden_states is not None else
1028
- self.config.output_hidden_states)
1029
- output_router_logits = (output_router_logits
1030
- if output_router_logits is not None else
1031
- self.config.output_router_logits)
1032
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1033
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1034
-
1035
- if (input_ids is None) ^ (inputs_embeds is not None):
1036
- raise ValueError(
1037
- 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one'
1038
- )
1039
-
1040
- if self.gradient_checkpointing and self.training and use_cache:
1041
- logger.warning_once(
1042
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
1043
- )
1044
- use_cache = False
1045
-
1046
- if inputs_embeds is None:
1047
- inputs_embeds = self.wte(input_ids)
1048
-
1049
- inputs_embeds = self._autocast_input_embeddings(
1050
- inputs_embeds) # type: ignore
1051
- inputs_embeds = nn.functional.dropout(inputs_embeds,
1052
- p=self.emb_pdrop,
1053
- training=self.training)
1054
-
1055
- past_seen_tokens = 0
1056
- if use_cache: # kept for BC (cache positions)
1057
- if not isinstance(past_key_values, StaticCache):
1058
- past_key_values = DynamicCache.from_legacy_cache(
1059
- past_key_values)
1060
- past_seen_tokens = past_key_values.get_seq_length( # type: ignore
1061
- )
1062
-
1063
- if cache_position is None:
1064
- if isinstance(past_key_values, StaticCache):
1065
- raise ValueError(
1066
- 'cache_position is a required argument when using StaticCache.'
1067
- )
1068
- cache_position = torch.arange( # type: ignore
1069
- past_seen_tokens,
1070
- past_seen_tokens + inputs_embeds.shape[1],
1071
- device=inputs_embeds.device)
1072
-
1073
- if position_ids is None:
1074
- position_ids = cache_position.unsqueeze(0) # type: ignore
1075
-
1076
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
1077
- cache_position) # type: ignore
1078
-
1079
- # embed positions
1080
- hidden_states = inputs_embeds
1081
-
1082
- # decoder layers
1083
- all_hidden_states = () if output_hidden_states else None
1084
- all_self_attns = () if output_attentions else None
1085
- all_router_logits = () if output_router_logits else None
1086
- next_decoder_cache = None
1087
-
1088
- for block in self.blocks:
1089
- if output_hidden_states:
1090
- all_hidden_states += (hidden_states,) # type: ignore
1091
-
1092
- if self.gradient_checkpointing and self.training:
1093
- block_outputs = self._gradient_checkpointing_func(
1094
- block.__call__,
1095
- hidden_states,
1096
- attention_mask=causal_mask,
1097
- position_ids=position_ids,
1098
- past_key_values=past_key_values,
1099
- output_attentions=output_attentions,
1100
- output_router_logits=output_router_logits,
1101
- use_cache=use_cache,
1102
- cache_position=cache_position,
1103
- )
1104
- else:
1105
- block_outputs = block(
1106
- hidden_states,
1107
- attention_mask=causal_mask,
1108
- position_ids=position_ids,
1109
- past_key_value=past_key_values,
1110
- output_attentions=output_attentions,
1111
- output_router_logits=output_router_logits,
1112
- use_cache=use_cache,
1113
- cache_position=cache_position,
1114
- )
1115
-
1116
- hidden_states = block_outputs[0]
1117
-
1118
- if use_cache:
1119
- next_decoder_cache = block_outputs[
1120
- 2 if output_attentions else 1]
1121
-
1122
- if output_attentions:
1123
- all_self_attns += (block_outputs[1],) # type: ignore
1124
-
1125
- if output_router_logits:
1126
- all_router_logits += (block_outputs[-1],) # type: ignore
1127
-
1128
- hidden_states = self.norm_f(hidden_states)
1129
-
1130
- # add hidden states from the last decoder layer
1131
- if output_hidden_states:
1132
- all_hidden_states += (hidden_states,) # type: ignore
1133
-
1134
- next_cache = None
1135
- if use_cache:
1136
- next_cache = (
1137
- next_decoder_cache.to_legacy_cache() # type: ignore
1138
- if isinstance(next_decoder_cache, Cache) else
1139
- next_decoder_cache)
1140
- if not return_dict:
1141
- return tuple(v for v in [
1142
- hidden_states, next_cache, all_hidden_states, all_self_attns,
1143
- all_router_logits
1144
- ] if v is not None)
1145
- return MoeModelOutputWithPast(
1146
- last_hidden_state=hidden_states,
1147
- past_key_values=next_cache,
1148
- hidden_states=all_hidden_states,
1149
- attentions=all_self_attns,
1150
- router_logits=all_router_logits,
1151
- )
1152
-
1153
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1154
- # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1155
- # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1156
- # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1157
- def _update_causal_mask(
1158
- self, attention_mask: Optional[torch.Tensor],
1159
- input_tensor: torch.Tensor,
1160
- cache_position: torch.Tensor) -> Optional[torch.Tensor]:
1161
- if self.config._attn_implementation == 'flash_attention_2':
1162
- if attention_mask is not None and 0.0 in attention_mask:
1163
- return attention_mask
1164
- return None
1165
-
1166
- dtype, device = input_tensor.dtype, input_tensor.device
1167
- min_dtype = torch.finfo(dtype).min
1168
- sequence_length = input_tensor.shape[1]
1169
- if hasattr(self.blocks[0].norm_attn_norm.attn,
1170
- 'past_key_value'): # static cache
1171
- target_length = self.config.max_position_embeddings
1172
- else: # dynamic cache
1173
- target_length = (attention_mask.shape[-1] if isinstance(
1174
- attention_mask, torch.Tensor) else cache_position[-1] + 1)
1175
- target_length = int(target_length)
1176
-
1177
- causal_mask = torch.full((sequence_length, target_length),
1178
- fill_value=min_dtype,
1179
- dtype=dtype,
1180
- device=device)
1181
- if sequence_length != 1:
1182
- causal_mask = torch.triu(causal_mask, diagonal=1)
1183
- causal_mask *= torch.arange(
1184
- target_length, device=device) > cache_position.reshape(-1, 1)
1185
- causal_mask = causal_mask[None,
1186
- None, :, :].expand(input_tensor.shape[0], 1,
1187
- -1, -1)
1188
- if attention_mask is not None:
1189
- causal_mask = causal_mask.clone(
1190
- ) # copy to contiguous memory for in-place edit
1191
- if attention_mask.dim() == 2:
1192
- mask_length = attention_mask.shape[-1]
1193
- padding_mask = causal_mask[..., :mask_length].eq(
1194
- 0.0) * attention_mask[:, None, None, :].eq(0.0)
1195
- causal_mask[..., :mask_length] = causal_mask[
1196
- ..., :mask_length].masked_fill(padding_mask, min_dtype)
1197
- elif attention_mask.dim() == 4:
1198
- # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1199
- # cache. In that case, the 4D attention mask attends to the newest tokens only.
1200
- if attention_mask.shape[
1201
- -2] < cache_position[0] + sequence_length:
1202
- offset = cache_position[0]
1203
- else:
1204
- offset = 0
1205
- mask_shape = attention_mask.shape
1206
- mask_slice = (attention_mask.eq(0.0)).to(
1207
- dtype=dtype) * min_dtype
1208
- causal_mask[:mask_shape[0], :mask_shape[1],
1209
- offset:mask_shape[2] +
1210
- offset, :mask_shape[3]] = mask_slice
1211
-
1212
- if (self.config._attn_implementation == 'sdpa' and
1213
- attention_mask is not None and
1214
- attention_mask.device.type == 'cuda'):
1215
- # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1216
- is_tracing = (
1217
- torch.jit.is_tracing() or
1218
- isinstance(input_tensor, torch.fx.Proxy) or # type: ignore
1219
- (hasattr(torch, '_dynamo') and torch._dynamo.is_compiling()))
1220
- if not is_tracing and torch.any(attention_mask != 1):
1221
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1222
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1223
- # Details: https://github.com/pytorch/pytorch/issues/110213
1224
- causal_mask = AttentionMaskConverter._unmask_unattended(
1225
- causal_mask, min_dtype)
1226
-
1227
- return causal_mask
1228
-
1229
-
1230
- class DbrxForCausalLM(DbrxPreTrainedModel):
1231
-
1232
- def __init__(self, config: DbrxConfig):
1233
- super().__init__(config)
1234
- self.transformer = DbrxModel(config)
1235
- self.vocab_size = config.vocab_size
1236
- self.lm_head = nn.Linear(config.hidden_size,
1237
- config.vocab_size,
1238
- bias=False)
1239
- self.router_aux_loss_coef = config.router_aux_loss_coef
1240
- self.num_experts = config.ffn_config.moe_num_experts
1241
- self.num_experts_per_tok = config.ffn_config.moe_top_k
1242
-
1243
- # Initialize weights and apply final processing
1244
- self.post_init()
1245
-
1246
- def get_input_embeddings(self) -> nn.Embedding:
1247
- return self.transformer.get_input_embeddings()
1248
-
1249
- def set_input_embeddings(self, value: nn.Embedding):
1250
- self.transformer.set_input_embeddings(value)
1251
-
1252
- def get_output_embeddings(self) -> nn.Linear:
1253
- return self.lm_head
1254
-
1255
- def set_output_embeddings(self, new_embeddings: nn.Linear):
1256
- self.lm_head = new_embeddings
1257
-
1258
- def set_decoder(self, decoder: DbrxModel):
1259
- self.transformer = decoder
1260
-
1261
- def get_decoder(self) -> DbrxModel:
1262
- return self.transformer
1263
-
1264
- def forward(
1265
- self,
1266
- input_ids: Optional[torch.LongTensor] = None,
1267
- attention_mask: Optional[torch.Tensor] = None,
1268
- position_ids: Optional[torch.LongTensor] = None,
1269
- past_key_values: Optional[Cache] = None,
1270
- inputs_embeds: Optional[torch.Tensor] = None,
1271
- labels: Optional[torch.LongTensor] = None,
1272
- use_cache: Optional[bool] = None,
1273
- output_attentions: Optional[bool] = None,
1274
- output_hidden_states: Optional[bool] = None,
1275
- output_router_logits: Optional[bool] = None,
1276
- return_dict: Optional[bool] = None,
1277
- cache_position: Optional[torch.LongTensor] = None,
1278
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1279
- r"""Forward function for causal language modeling.
1280
-
1281
- Example:
1282
- ```python
1283
- >>> from transformers import AutoTokenizer, DbrxForCausalLM
1284
-
1285
- >>> model = DbrxForCausalLM.from_pretrained("databricks/dbrx")
1286
- >>> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx")
1287
-
1288
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1289
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1290
-
1291
- >>> # Generate
1292
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1293
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1294
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1295
- ```
1296
- """
1297
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1298
- output_hidden_states = (output_hidden_states
1299
- if output_hidden_states is not None else
1300
- self.config.output_hidden_states)
1301
- output_router_logits = (output_router_logits
1302
- if output_router_logits is not None else
1303
- self.config.output_router_logits)
1304
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1305
-
1306
- # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1307
- outputs = self.transformer(
1308
- input_ids=input_ids,
1309
- attention_mask=attention_mask,
1310
- position_ids=position_ids,
1311
- past_key_values=past_key_values,
1312
- inputs_embeds=inputs_embeds,
1313
- use_cache=use_cache,
1314
- output_attentions=output_attentions,
1315
- output_hidden_states=output_hidden_states,
1316
- output_router_logits=output_router_logits,
1317
- return_dict=return_dict,
1318
- cache_position=cache_position,
1319
- )
1320
-
1321
- hidden_states = outputs[0]
1322
- logits = self.lm_head(hidden_states)
1323
-
1324
- loss = None
1325
- if labels is not None:
1326
- # Shift so that tokens < n predict n
1327
- shift_logits = logits[..., :-1, :].contiguous()
1328
- shift_labels = labels[..., 1:].contiguous()
1329
- # Flatten the tokens
1330
- loss_fct = nn.CrossEntropyLoss()
1331
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1332
- shift_labels = shift_labels.view(-1)
1333
- # Enable model parallelism
1334
- shift_labels = shift_labels.to(shift_logits.device)
1335
- loss = loss_fct(shift_logits, shift_labels)
1336
-
1337
- aux_loss = None
1338
- if output_router_logits:
1339
- aux_loss = load_balancing_loss_func(
1340
- outputs.router_logits if return_dict else outputs[-1],
1341
- self.num_experts,
1342
- self.num_experts_per_tok,
1343
- attention_mask,
1344
- )
1345
- if labels is not None and loss is not None:
1346
- loss += self.router_aux_loss_coef * aux_loss.to(
1347
- loss.device) # make sure to reside in the same device
1348
-
1349
- if not return_dict:
1350
- output = (logits,) + outputs[1:]
1351
- return (loss,) + output if loss is not None else output
1352
-
1353
- return MoeCausalLMOutputWithPast(
1354
- loss=loss,
1355
- aux_loss=aux_loss,
1356
- logits=logits,
1357
- past_key_values=outputs.past_key_values,
1358
- hidden_states=outputs.hidden_states,
1359
- attentions=outputs.attentions,
1360
- router_logits=outputs.router_logits,
1361
- )
1362
-
1363
- def prepare_inputs_for_generation(
1364
- self,
1365
- input_ids: torch.Tensor,
1366
- past_key_values: Optional[Cache] = None,
1367
- attention_mask: Optional[torch.Tensor] = None,
1368
- inputs_embeds: Optional[torch.Tensor] = None,
1369
- **kwargs: Any) -> Dict[str, Any]:
1370
- past_length = 0
1371
- if past_key_values is not None:
1372
- if isinstance(past_key_values, Cache):
1373
- cache_length = past_key_values.get_seq_length()
1374
- past_length = past_key_values.seen_tokens
1375
- max_cache_length = past_key_values.get_max_length()
1376
- else:
1377
- cache_length = past_length = past_key_values[0][0].shape[2]
1378
- max_cache_length = None
1379
-
1380
- # Keep only the unprocessed tokens:
1381
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1382
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1383
- # input)
1384
- if attention_mask is not None and attention_mask.shape[
1385
- 1] > input_ids.shape[1]:
1386
- input_ids = input_ids[:,
1387
- -(attention_mask.shape[1] - past_length):]
1388
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1389
- # input_ids based on the past_length.
1390
- elif past_length < input_ids.shape[1]:
1391
- input_ids = input_ids[:, past_length:]
1392
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1393
-
1394
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1395
- if (max_cache_length is not None and attention_mask is not None and
1396
- cache_length + input_ids.shape[1] > max_cache_length):
1397
- attention_mask = attention_mask[:, -max_cache_length:]
1398
-
1399
- position_ids = kwargs.get('position_ids', None)
1400
- if attention_mask is not None and position_ids is None:
1401
- # create position_ids on the fly for batch generation
1402
- position_ids = attention_mask.long().cumsum(-1) - 1
1403
- position_ids.masked_fill_(attention_mask == 0, 1)
1404
- if past_key_values:
1405
- position_ids = position_ids[:, -input_ids.shape[1]:]
1406
-
1407
- if self.generation_config.cache_implementation == 'static':
1408
- # generation with static cache
1409
- cache_position = kwargs.get('cache_position', None)
1410
- if cache_position is None:
1411
- past_length = 0
1412
- else:
1413
- past_length = cache_position[-1] + 1
1414
- input_ids = input_ids[:, past_length:]
1415
- position_ids = position_ids[:,
1416
- past_length:] if position_ids is not None else None
1417
-
1418
- # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
1419
- # same goes for position ids. Could also help with continued generation.
1420
- input_length = position_ids.shape[
1421
- -1] if position_ids is not None else input_ids.shape[-1]
1422
- cache_position = torch.arange(past_length,
1423
- past_length + input_length,
1424
- device=input_ids.device)
1425
- position_ids = position_ids.contiguous(
1426
- ) if position_ids is not None else None
1427
-
1428
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1429
- if inputs_embeds is not None and past_key_values is None:
1430
- model_inputs = {'inputs_embeds': inputs_embeds}
1431
- else:
1432
- # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1433
- # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1434
- # TODO: use `next_tokens` directly instead.
1435
- model_inputs = {'input_ids': input_ids.contiguous()}
1436
-
1437
- model_inputs.update(
1438
- { # type: ignore
1439
- 'position_ids': position_ids,
1440
- 'cache_position': cache_position,
1441
- 'past_key_values': past_key_values,
1442
- 'use_cache': kwargs.get('use_cache'),
1443
- 'attention_mask': attention_mask,
1444
- }
1445
- )
1446
- return model_inputs
1447
-
1448
- @staticmethod
1449
- def _reorder_cache(past_key_values: Cache, beam_idx: torch.LongTensor):
1450
- reordered_past = ()
1451
- for layer_past in past_key_values:
1452
- reordered_past += (tuple(
1453
- past_state.index_select(0, beam_idx.to(past_state.device))
1454
- for past_state in layer_past),)
1455
- return reordered_past
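
For reference, the rotary-embedding and grouped-query helpers in the deleted file are self-contained and easy to exercise on their own. The sketch below is a standalone illustration, not part of this commit: it reproduces `rotate_half`, `apply_rotary_pos_emb`, and `repeat_kv` from the file above and checks, on made-up tensor sizes, that query and key shapes line up after rotation and KV-head repetition.

```python
# Standalone sketch; assumes only a working `torch` install. The batch size,
# head counts, sequence length, and head_dim below are illustrative.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Broadcast cos/sin over the head dimension and rotate q and k.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, kv_heads, seq, head_dim) -> (batch, kv_heads * n_rep, seq, head_dim)
    # so grouped-query KV heads line up with the query heads.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


if __name__ == '__main__':
    bsz, n_heads, kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64
    q = torch.randn(bsz, n_heads, seq_len, head_dim)
    k = torch.randn(bsz, kv_heads, seq_len, head_dim)

    # Build cos/sin the same way DbrxRotaryEmbedding does, in float32.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    position_ids = torch.arange(seq_len)[None, :].float()           # (1, seq_len)
    freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)                         # (1, seq_len, head_dim)
    cos, sin = emb.cos(), emb.sin()

    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    k_full = repeat_kv(k_rot, n_heads // kv_heads)
    print(q_rot.shape, k_full.shape)  # torch.Size([2, 8, 16, 64]) torch.Size([2, 8, 16, 64])
```

The 10000.0 base and the float32 cos/sin computation mirror what `DbrxRotaryEmbedding.forward` does before casting back to the input dtype.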