Qubitium committed
Commit e642316
1 Parent(s): cfa3210

Upload modeling_dbrx.py with huggingface_hub

Files changed (1)
  1. modeling_dbrx.py +1462 -0
modeling_dbrx.py ADDED
@@ -0,0 +1,1462 @@
1
+ """PyTorch Dbrx model."""
2
+
3
+ import math
4
+ import warnings
5
+ from copy import deepcopy
6
+ from functools import partial
7
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
14
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
15
+ from transformers.modeling_outputs import (MoeCausalLMOutputWithPast,
16
+ MoeModelOutputWithPast)
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import is_flash_attn_2_available, logging
19
+
20
+ from .configuration_dbrx import DbrxAttentionConfig, DbrxConfig, DbrxFFNConfig
21
+
22
+ if is_flash_attn_2_available():
23
+ try:
24
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
25
+ from flash_attn.bert_padding import pad_input # noqa
26
+ from flash_attn.bert_padding import index_first_axis, unpad_input
27
+ except Exception:  # tolerate a broken flash_attn install at import time
28
+ pass
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+ _CONFIG_FOR_DOC = 'DbrxConfig'
33
+
34
+ #############################################################################
35
+ # Copied from LlamaRotaryEmbedding
36
+ #############################################################################
37
+
38
+
39
+ class DbrxRotaryEmbedding(nn.Module):
40
+
41
+ def __init__(self,
42
+ dim: int,
43
+ max_position_embeddings: int = 2048,
44
+ base: float = 10000.0,
45
+ scaling_factor: float = 1.0):
46
+ super().__init__()
47
+ self.scaling_factor = scaling_factor
48
+ self.dim = dim
49
+ self.max_position_embeddings = max_position_embeddings
50
+ self.base = base
51
+ inv_freq = 1.0 / (self.base**(
52
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
53
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
54
+ # For BC we register cos and sin cached
55
+ self.max_seq_len_cached = max_position_embeddings
56
+
57
+ @torch.no_grad()
58
+ def forward(
59
+ self, x: torch.Tensor, position_ids: torch.LongTensor
60
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
61
+ # x: [bs, num_attention_heads, seq_len, head_size]
62
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
63
+ position_ids.shape[0], -1, 1)
64
+ position_ids_expanded = position_ids[:, None, :].float()
65
+ # Force float32 since bfloat16 loses precision on long contexts
66
+ # See https://github.com/huggingface/transformers/pull/29285
67
+ device_type = x.device.type
68
+ device_type = device_type if isinstance(
69
+ device_type, str) and device_type != 'mps' else 'cpu'
70
+ with torch.autocast(device_type=device_type, enabled=False):
71
+ freqs = (inv_freq_expanded.float()
72
+ @ position_ids_expanded.float()).transpose(1, 2)
73
+ emb = torch.cat((freqs, freqs), dim=-1)
74
+ cos = emb.cos()
75
+ sin = emb.sin()
76
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
77
+
78
+
79
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
80
+ """Rotates half the hidden dims of the input."""
81
+ x1 = x[..., :x.shape[-1] // 2]
82
+ x2 = x[..., x.shape[-1] // 2:]
83
+ return torch.cat((-x2, x1), dim=-1)
84
+
85
+
86
+ def apply_rotary_pos_emb(
87
+ q: torch.Tensor,
88
+ k: torch.Tensor,
89
+ cos: torch.Tensor,
90
+ sin: torch.Tensor,
91
+ unsqueeze_dim: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
92
+ """Applies Rotary Position Embedding to the query and key tensors.
93
+
94
+ Args:
95
+ q (`torch.Tensor`): The query tensor.
96
+ k (`torch.Tensor`): The key tensor.
97
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
98
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
99
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
100
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and
101
+ sin so that they can be properly broadcasted to the dimensions of q and k. For example, note
102
+ that cos and sin have the shape [batch_size, seq_len, head_dim]. Then, if q and
103
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
104
+ cos and sin broadcastable to the shapes of q and k. Similarly, if q and k have
105
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
106
+
107
+ Returns:
108
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
109
+ """
110
+ cos = cos.unsqueeze(unsqueeze_dim)
111
+ sin = sin.unsqueeze(unsqueeze_dim)
112
+ q_embed = (q * cos) + (rotate_half(q) * sin)
113
+ k_embed = (k * cos) + (rotate_half(k) * sin)
114
+ return q_embed, k_embed
115
+
116
+
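+ # Illustrative sketch (hypothetical helper, not used by the model): with the
+ # default unsqueeze_dim=1, cos/sin of shape [bsz, seq_len, head_dim] broadcast
+ # against q/k of shape [bsz, num_heads, seq_len, head_dim].
+ def _rope_shape_example():
+     bsz, num_heads, seq_len, head_dim = 2, 4, 8, 16
+     q = torch.randn(bsz, num_heads, seq_len, head_dim)
+     k = torch.randn(bsz, num_heads, seq_len, head_dim)
+     rope = DbrxRotaryEmbedding(head_dim)
+     position_ids = torch.arange(seq_len)[None, :].expand(bsz, -1)
+     cos, sin = rope(q, position_ids)  # each of shape [bsz, seq_len, head_dim]
+     q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
+     assert q_rot.shape == q.shape and k_rot.shape == k.shape
+
+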
117
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
118
+ """Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
119
+
120
+ The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
121
+ (batch, num_attention_heads, seqlen, head_dim)
122
+ """
123
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
124
+ if n_rep == 1:
125
+ return hidden_states
126
+ hidden_states = hidden_states[:, :,
127
+ None, :, :].expand(batch, num_key_value_heads,
128
+ n_rep, slen, head_dim)
129
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
130
+ head_dim)
131
+
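+ # Illustrative sketch (hypothetical helper, not used by the model): repeat_kv
+ # expands the grouped key/value heads so they line up with the query heads.
+ def _repeat_kv_example():
+     kv = torch.randn(2, 8, 16, 64)      # (batch, num_key_value_heads, seqlen, head_dim)
+     expanded = repeat_kv(kv, n_rep=4)   # -> (batch, 32, seqlen, head_dim)
+     reference = torch.repeat_interleave(kv, repeats=4, dim=1)
+     assert torch.equal(expanded, reference)
+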
132
+
133
+ #############################################################################
134
+
135
+ #############################################################################
136
+ # Modified from modeling_mixtral
137
+ #############################################################################
138
+
139
+
140
+ def load_balancing_loss_func(
141
+ gate_logits: torch.Tensor,
142
+ num_experts: int,
143
+ top_k: int,
144
+ attention_mask: Optional[torch.Tensor],
145
+ ) -> torch.Tensor:
146
+ r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
147
+
148
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
149
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
150
+ experts is too unbalanced.
151
+
152
+ Args:
153
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
154
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
155
+ shape [batch_size X sequence_length, num_experts].
156
+ num_experts (`int`):
157
+ Number of experts.
158
+ top_k (`int`):
159
+ The number of experts each token is routed to.
160
+ attention_mask (`torch.Tensor`, None):
161
+ The attention_mask used in forward function
162
+ shape [batch_size X sequence_length] if not None.
163
+
164
+ Returns:
165
+ The auxiliary loss.
166
+ """
167
+ if gate_logits is None or not isinstance(gate_logits, tuple):
168
+ return torch.tensor(0.0)
169
+
170
+ if isinstance(gate_logits, tuple):
171
+ compute_device = gate_logits[0].device
172
+ concatenated_gate_logits = torch.cat(
173
+ [layer_gate.to(compute_device) for layer_gate in gate_logits],
174
+ dim=0)
175
+
176
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits,
177
+ dim=-1)
178
+
179
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
180
+
181
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
182
+
183
+ if attention_mask is None:
184
+ # Compute the percentage of tokens routed to each expert
185
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
186
+
187
+ # Compute the average probability of routing to these experts
188
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
189
+ else:
190
+ batch_size, sequence_length = attention_mask.shape
191
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (
192
+ batch_size * sequence_length)
193
+
194
+ # Compute the mask that masks all padding tokens as 0, with the same shape as expert_mask
195
+ expert_attention_mask = (attention_mask[None, :, :, None, None].expand(
196
+ (num_hidden_layers, batch_size, sequence_length, top_k,
197
+ num_experts)).reshape(-1, top_k, num_experts).to(compute_device))
198
+
199
+ # Compute the percentage of tokens routed to each expert
200
+ tokens_per_expert = torch.sum(
201
+ expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
202
+ expert_attention_mask, dim=0)
203
+
204
+ # Compute the mask that masks all padding tokens as 0, with the same shape as tokens_per_expert
205
+ router_per_expert_attention_mask = (
206
+ attention_mask[None, :, :, None].expand(
207
+ (num_hidden_layers, batch_size, sequence_length,
208
+ num_experts)).reshape(-1, num_experts).to(compute_device))
209
+
210
+ # Compute the average probability of routing to these experts
211
+ router_prob_per_expert = torch.sum(
212
+ routing_weights * router_per_expert_attention_mask,
213
+ dim=0) / torch.sum(router_per_expert_attention_mask, dim=0)
214
+
215
+ overall_loss = torch.sum(tokens_per_expert *
216
+ router_prob_per_expert.unsqueeze(0))
217
+ return overall_loss * num_experts
218
+
219
+
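+ # Illustrative sketch (hypothetical helper, not used by the model): the aux loss
+ # takes one router-logit tensor per layer, each of shape
+ # [batch_size * sequence_length, num_experts], and returns a scalar that is
+ # roughly equal to top_k when routing is perfectly balanced.
+ def _load_balancing_loss_example():
+     num_experts, top_k, batch_size, seq_len = 4, 2, 2, 8
+     gate_logits = tuple(
+         torch.randn(batch_size * seq_len, num_experts) for _ in range(3))
+     return load_balancing_loss_func(gate_logits, num_experts, top_k,
+                                     attention_mask=None)
+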
220
+ #############################################################################
221
+
222
+
223
+ def resolve_ffn_act_fn(
224
+ ffn_act_fn: dict) -> Callable[[torch.Tensor], torch.Tensor]:
225
+ """Resolve the activation function for the feed-forward network.
226
+
227
+ Args:
228
+ ffn_act_fn (dict): The configuration dictionary for the activation function.
229
+ The dict config must specify the 'name' of a torch.nn.functional activation
230
+ function. All other key-value pairs are bound to the function as a partial.
231
+
232
+ Returns:
233
+ Callable[[torch.Tensor], torch.Tensor]: The activation function.
234
+ """
235
+ config = deepcopy(ffn_act_fn)
236
+ name = config.pop('name')
237
+ if not hasattr(nn.functional, name):
238
+ raise ValueError(f'Unrecognised activation function name ({name}).')
239
+ act = getattr(nn.functional, name)
240
+ return partial(act, **config)
241
+
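+ # Illustrative sketch (hypothetical helper; the config values are assumptions,
+ # not DBRX defaults): every key besides 'name' is bound to the
+ # torch.nn.functional callable as a keyword argument.
+ def _resolve_ffn_act_fn_example():
+     act = resolve_ffn_act_fn({'name': 'gelu', 'approximate': 'tanh'})
+     return act(torch.randn(4))  # equivalent to F.gelu(x, approximate='tanh')
+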
242
+
243
+ #############################################################################
244
+ # Copied from LlamaAttention
245
+ #############################################################################
246
+
247
+ def get_max_seqlen_in_batch(attention_mask):
248
+ max_num = torch.max(attention_mask)
249
+ # attention_mask: B x N
250
+ counts = []
251
+ for i in range(1, max_num + 1):
252
+ counts.append(
253
+ torch.sum(attention_mask == i, axis=-1)
254
+ ) # shape: B, count of positions in each data point masked with i
255
+ result = torch.stack(counts, axis=1)
256
+ result = result.flatten()
257
+ return result[result.nonzero()].squeeze(-1).to(dtype=torch.int32)
258
+
259
+
260
+ def _get_unpad_data(attention_mask):
261
+ seqlens_in_batch = get_max_seqlen_in_batch(
262
+ attention_mask
263
+ ) # attention_mask.sum(dim=-1, dtype=torch.int32)
264
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
265
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
266
+ cu_seqlens = F.pad(
267
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
268
+ )
269
+ return (
270
+ indices,
271
+ cu_seqlens,
272
+ max_seqlen_in_batch,
273
+ )
274
+
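+ # Illustrative sketch (hypothetical helper, not used by the model): for a
+ # right-padded batch, the helpers above yield the flattened token indices,
+ # cumulative sequence lengths and max length that flash_attn_varlen_func expects.
+ def _unpad_metadata_example():
+     attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
+     indices, cu_seqlens, max_len = _get_unpad_data(attention_mask)
+     # indices == [0, 1, 2, 4, 5]; cu_seqlens == [0, 3, 5]; max_len == 3
+     return indices, cu_seqlens, max_len
+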
275
+
276
+ class DbrxAttention(nn.Module):
277
+ """Multi-head self attention."""
278
+
279
+ def __init__(self,
280
+ hidden_size: int,
281
+ num_heads: int,
282
+ max_position_embeddings: int,
283
+ attn_config: DbrxAttentionConfig,
284
+ block_idx: Optional[int] = None):
285
+ super().__init__()
286
+ self.hidden_size = hidden_size
287
+ self.num_heads = num_heads
288
+ self.head_dim = self.hidden_size // self.num_heads
289
+ self.max_position_embeddings = max_position_embeddings
290
+ self.block_idx = block_idx
291
+ self.config = attn_config
292
+ if block_idx is None:
293
+ logger.warning_once(
294
+ f'Instantiating {self.__class__.__name__} without passing a `block_idx` is not recommended and will '
295
+ +
296
+ 'lead to errors during the forward call if caching is used. Please make sure to provide a `block_idx` '
297
+ + 'when creating this class.')
298
+
299
+ self.attn_pdrop = attn_config.attn_pdrop
300
+ self.clip_qkv = attn_config.clip_qkv
301
+ self.num_key_value_heads = attn_config.kv_n_heads
302
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
303
+ self.rope_theta = attn_config.rope_theta
304
+
305
+ self.q_proj = nn.Linear(self.hidden_size,
306
+ self.hidden_size,
307
+ bias=False)
308
+ self.k_proj = nn.Linear(self.hidden_size,
309
+ self.num_key_value_heads * self.head_dim,
310
+ bias=False)
311
+ self.v_proj = nn.Linear(self.hidden_size,
312
+ self.num_key_value_heads * self.head_dim,
313
+ bias=False)
314
+ self.out_proj = nn.Linear(self.hidden_size,
315
+ self.hidden_size,
316
+ bias=False)
317
+ self.rotary_emb = DbrxRotaryEmbedding(
318
+ self.head_dim,
319
+ max_position_embeddings=self.max_position_embeddings,
320
+ base=self.rope_theta,
321
+ )
322
+
323
+ def forward(
324
+ self,
325
+ hidden_states: torch.Tensor,
326
+ position_ids: torch.LongTensor,
327
+ attention_mask: Optional[torch.Tensor] = None,
328
+ past_key_value: Optional[Cache] = None,
329
+ output_attentions: bool = False,
330
+ use_cache: bool = False,
331
+ cache_position: Optional[torch.LongTensor] = None,
332
+ **kwargs: Any,
333
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
334
+ bsz, q_len, _ = hidden_states.size()
335
+
336
+ query_states = self.q_proj(hidden_states)
337
+ key_states = self.k_proj(hidden_states)
338
+ value_states = self.v_proj(hidden_states)
339
+ if self.clip_qkv is not None:
340
+ query_states = query_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
341
+ key_states = key_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
342
+ value_states = value_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
343
+
344
+ query_states = query_states.view(bsz, q_len, self.num_heads,
345
+ self.head_dim).transpose(1, 2)
346
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
347
+ self.head_dim).transpose(1, 2)
348
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
349
+ self.head_dim).transpose(1, 2)
350
+
351
+ past_key_value = getattr(self, 'past_key_value', past_key_value)
352
+ cos, sin = self.rotary_emb(value_states, position_ids)
353
+ query_states, key_states = apply_rotary_pos_emb(query_states,
354
+ key_states, cos, sin)
355
+
356
+ if past_key_value is not None:
357
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
358
+ cache_kwargs = {
359
+ 'sin': sin,
360
+ 'cos': cos,
361
+ 'cache_position': cache_position
362
+ }
363
+ key_states, value_states = past_key_value.update(
364
+ key_states, value_states, self.block_idx, cache_kwargs)
365
+
366
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
367
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
368
+
369
+ attn_weights = torch.matmul(query_states, key_states.transpose(
370
+ 2, 3)) / math.sqrt(self.head_dim)
371
+
372
+ if attention_mask is not None: # no matter the length, we just slice it
373
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
374
+ attn_weights = attn_weights + causal_mask
375
+
376
+ # upcast attention to fp32
377
+ attn_weights = nn.functional.softmax(attn_weights,
378
+ dim=-1,
379
+ dtype=torch.float32).to(
380
+ query_states.dtype)
381
+ attn_weights = nn.functional.dropout(attn_weights,
382
+ p=self.attn_pdrop,
383
+ training=self.training)
384
+ attn_output = torch.matmul(attn_weights, value_states)
385
+
386
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
387
+ raise ValueError(
388
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
389
+ + f' {attn_output.size()}')
390
+
391
+ attn_output = attn_output.transpose(1, 2).contiguous()
392
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
393
+ attn_output = self.out_proj(attn_output)
394
+
395
+ if not output_attentions:
396
+ attn_weights = None
397
+
398
+ return attn_output, attn_weights, past_key_value
399
+
400
+
401
+ class DbrxFlashAttention2(DbrxAttention):
402
+ """Dbrx flash attention module.
403
+
404
+ This module inherits from `DbrxAttention`, as the weights of the module stay
404
+ untouched. The only required change is in the forward pass, where it
405
+ calls the public API of flash attention.
407
+ """
408
+
409
+ def __init__(self, *args: Any, **kwargs: Any):
410
+ if not is_flash_attn_2_available():
411
+ raise ImportError(
412
+ 'Flash Attention 2 is not available. Please install it with `pip install flash-attn`.'
413
+ )
414
+
415
+ super().__init__(*args, **kwargs)
416
+
417
+ def forward(
418
+ self,
419
+ hidden_states: torch.Tensor,
420
+ attention_mask: Optional[torch.LongTensor] = None,
421
+ position_ids: Optional[torch.LongTensor] = None,
422
+ past_key_value: Optional[Cache] = None,
423
+ output_attentions: bool = False,
424
+ use_cache: bool = False,
425
+ cache_position: Optional[torch.LongTensor] = None,
426
+ **kwargs: Any,
427
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
428
+ Optional[Tuple[torch.Tensor]]]:
429
+ logger.info(
430
+ 'Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.'
431
+ )
432
+ output_attentions = False
433
+
434
+ bsz, q_len, _ = hidden_states.size()
435
+
436
+ query_states = self.q_proj(hidden_states)
437
+ key_states = self.k_proj(hidden_states)
438
+ value_states = self.v_proj(hidden_states)
439
+ if self.clip_qkv is not None:
440
+ query_states = query_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
441
+ key_states = key_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
442
+ value_states = value_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
443
+
444
+ # Flash attention requires the input to have the shape
445
+ # batch_size x seq_length x num_heads x head_dim
446
+ # therefore we just need to keep the original shape
447
+ query_states = query_states.view(bsz, q_len, self.num_heads,
448
+ self.head_dim).transpose(1, 2)
449
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
450
+ self.head_dim).transpose(1, 2)
451
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
452
+ self.head_dim).transpose(1, 2)
453
+
454
+ cos, sin = self.rotary_emb(value_states, position_ids)
455
+ query_states, key_states = apply_rotary_pos_emb(query_states,
456
+ key_states, cos, sin)
457
+
458
+ past_key_value = getattr(self, 'past_key_value', past_key_value)
459
+
460
+ if past_key_value is not None:
461
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
462
+ cache_kwargs = {
463
+ 'sin': sin,
464
+ 'cos': cos,
465
+ 'cache_position': cache_position
466
+ }
467
+ key_states, value_states = past_key_value.update(
468
+ key_states, value_states, self.block_idx, cache_kwargs)
469
+
470
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
471
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
472
+ # to be able to avoid many of these transpose/reshape/view.
473
+ query_states = query_states.transpose(1, 2)
474
+ key_states = key_states.transpose(1, 2)
475
+ value_states = value_states.transpose(1, 2)
476
+
477
+ dropout_rate = self.attn_pdrop if self.training else 0.0
478
+
479
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
480
+ # so the input hidden states may get silently cast to float32. Hence, we need to
481
+ # cast them back to the correct dtype just to be sure everything works as expected.
482
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
483
+ # to fp32. (LlamaRMSNorm handles it correctly.)
484
+ input_dtype = query_states.dtype
485
+ if input_dtype == torch.float32:
486
+ if torch.is_autocast_enabled():
487
+ target_dtype = torch.get_autocast_gpu_dtype()
488
+ # Handle the case where the model is quantized
489
+ elif hasattr(self.config, '_pre_quantization_dtype'):
490
+ target_dtype = self.config._pre_quantization_dtype
491
+ else:
492
+ target_dtype = query_states.dtype
493
+
494
+ logger.warning_once(
495
+ f'The input hidden states seem to be silently cast to float32; this might be '
496
+ +
497
+ f'related to the fact you have upcast embedding or layer norm layers to '
498
+ + f'float32. We will cast the input back to {target_dtype}.')
499
+
500
+ query_states = query_states.to(target_dtype)
501
+ key_states = key_states.to(target_dtype)
502
+ value_states = value_states.to(target_dtype)
503
+
504
+ attn_output = self._flash_attention_forward(
505
+ query_states,
506
+ key_states,
507
+ value_states,
508
+ attention_mask,
509
+ q_len,
510
+ dropout=dropout_rate,
511
+ )
512
+
513
+ attn_output = attn_output.reshape(bsz, q_len,
514
+ self.hidden_size).contiguous()
515
+ attn_output = self.out_proj(attn_output)
516
+
517
+ if not output_attentions:
518
+ attn_weights = None
519
+
520
+ return attn_output, attn_weights, past_key_value # type: ignore
521
+
522
+ def _flash_attention_forward(
523
+ self,
524
+ query_states: torch.Tensor,
525
+ key_states: torch.Tensor,
526
+ value_states: torch.Tensor,
527
+ attention_mask: Union[torch.LongTensor, None],
528
+ query_length: int,
529
+ dropout: float = 0.0,
530
+ softmax_scale: Optional[float] = None,
531
+ ):
532
+ """Use FlashAttention, stripping padding tokens if necessary.
533
+
534
+ Args:
535
+ query_states (torch.Tensor): Input query states to be passed to Flash Attention API
536
+ key_states (torch.Tensor): Input key states to be passed to Flash Attention API
537
+ value_states (torch.Tensor): Input value states to be passed to Flash Attention API
538
+ attention_mask (torch.LongTensor | None): The padding mask - corresponds to a tensor of size
539
+ (batch_size, seq_len) where 0 stands for the position of padding tokens and 1
540
+ for the position of non-padding tokens.
541
+ query_length (int): The length of the query sequence
542
+ dropout (float): Attention dropout
543
+ softmax_scale (float, optional): The scaling of QK^T before applying softmax.
544
+ Defaults to 1 / sqrt(head_dim)
545
+ """
546
+ causal = True
547
+ # Contains at least one padding token in the sequence
548
+ if attention_mask is not None:
549
+ batch_size = query_states.shape[0]
550
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
551
+ query_states, key_states, value_states, attention_mask,
552
+ query_length)
553
+
554
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
555
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
556
+
557
+ attn_output_unpad = flash_attn_varlen_func(
558
+ query_states,
559
+ key_states,
560
+ value_states,
561
+ cu_seqlens_q=cu_seqlens_q,
562
+ cu_seqlens_k=cu_seqlens_k,
563
+ max_seqlen_q=max_seqlen_in_batch_q,
564
+ max_seqlen_k=max_seqlen_in_batch_k,
565
+ dropout_p=dropout,
566
+ softmax_scale=softmax_scale,
567
+ causal=causal,
568
+ )
569
+
570
+ attn_output = pad_input(
571
+ attn_output_unpad,
572
+ indices_q,
573
+ batch_size,
574
+ query_length,
575
+ )
576
+ else:
577
+ attn_output = flash_attn_func(
578
+ query_states,
579
+ key_states,
580
+ value_states,
581
+ dropout,
582
+ softmax_scale=softmax_scale,
583
+ causal=causal,
584
+ )
585
+
586
+ return attn_output
587
+
588
+ def _upad_input(self, query_layer: torch.Tensor, key_layer: torch.Tensor,
589
+ value_layer: torch.Tensor, attention_mask: torch.Tensor,
590
+ query_length: int):
591
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
592
+ attention_mask)
593
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
594
+
595
+ key_layer = index_first_axis(
596
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
597
+ head_dim), indices_k)
598
+ value_layer = index_first_axis(
599
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
600
+ head_dim), indices_k)
601
+ if query_length == kv_seq_len:
602
+ query_layer = index_first_axis(
603
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
604
+ head_dim), indices_k)
605
+ cu_seqlens_q = cu_seqlens_k
606
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
607
+ indices_q = indices_k
608
+ elif query_length == 1:
609
+ max_seqlen_in_batch_q = 1
610
+ cu_seqlens_q = torch.arange(
611
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
612
+ ) # There is a memcpy here, that is very bad.
613
+ indices_q = cu_seqlens_q[:-1]
614
+ query_layer = query_layer.squeeze(1)
615
+ else:
616
+ # The -q_len: slice assumes left padding.
617
+ attention_mask = attention_mask[:, -query_length:]
618
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
619
+ query_layer, attention_mask)
620
+
621
+ return (
622
+ query_layer,
623
+ key_layer,
624
+ value_layer,
625
+ indices_q,
626
+ (cu_seqlens_q, cu_seqlens_k),
627
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
628
+ )
629
+
630
+
631
+ DBRX_ATTENTION_CLASSES = {
632
+ 'eager': DbrxAttention,
633
+ 'flash_attention_2': DbrxFlashAttention2,
634
+ }
635
+
636
+
637
+ class DbrxNormAttentionNorm(nn.Module):
638
+
639
+ def __init__(
640
+ self,
641
+ hidden_size: int,
642
+ num_heads: int,
643
+ max_position_embeddings: int,
644
+ resid_pdrop: float,
645
+ attn_implementation: str,
646
+ attn_config: DbrxAttentionConfig,
647
+ block_idx: Optional[int] = None,
648
+ ):
649
+ super().__init__()
650
+ self.block_idx = block_idx
651
+ self.resid_pdrop = resid_pdrop
652
+ self.norm_1 = nn.LayerNorm(hidden_size, bias=False)
653
+ self.attn = DBRX_ATTENTION_CLASSES[attn_implementation](
654
+ hidden_size=hidden_size,
655
+ num_heads=num_heads,
656
+ max_position_embeddings=max_position_embeddings,
657
+ attn_config=attn_config,
658
+ block_idx=block_idx,
659
+ )
660
+ self.norm_2 = nn.LayerNorm(hidden_size, bias=False)
661
+
662
+ def forward(
663
+ self,
664
+ hidden_states: torch.Tensor,
665
+ position_ids: torch.LongTensor,
666
+ attention_mask: Optional[torch.Tensor] = None,
667
+ past_key_value: Optional[Cache] = None,
668
+ output_attentions: bool = False,
669
+ use_cache: bool = False,
670
+ cache_position: Optional[torch.LongTensor] = None,
671
+ **kwargs: Any,
672
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
673
+ Optional[Cache]]:
674
+
675
+ residual_states = hidden_states
676
+ hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype)
677
+
678
+ hidden_states, attn_weights, past_key_value = self.attn(
679
+ hidden_states=hidden_states,
680
+ attention_mask=attention_mask,
681
+ position_ids=position_ids,
682
+ past_key_value=past_key_value,
683
+ output_attentions=output_attentions,
684
+ use_cache=use_cache,
685
+ cache_position=cache_position,
686
+ **kwargs,
687
+ )
688
+
689
+ hidden_states = nn.functional.dropout(hidden_states,
690
+ p=self.resid_pdrop,
691
+ training=self.training)
692
+ hidden_states = hidden_states + residual_states
693
+
694
+ residual_states = hidden_states
695
+ hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype)
696
+
697
+ return residual_states, hidden_states, attn_weights, past_key_value
698
+
699
+
700
+ class DbrxRouter(nn.Module):
701
+
702
+ def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int,
703
+ moe_jitter_eps: Optional[float],
704
+ moe_normalize_expert_weights: Optional[float],
705
+ uniform_expert_assignment: bool):
706
+ super().__init__()
707
+ self.hidden_size = hidden_size
708
+ self.moe_num_experts = moe_num_experts
709
+ self.moe_top_k = moe_top_k
710
+ self.moe_jitter_eps = moe_jitter_eps
711
+ self.moe_normalize_expert_weights = moe_normalize_expert_weights
712
+ self.uniform_expert_assignment = uniform_expert_assignment
713
+
714
+ self.layer = nn.Linear(self.hidden_size,
715
+ self.moe_num_experts,
716
+ bias=False)
717
+
718
+ def jitter(self, x: torch.Tensor) -> torch.Tensor:
719
+ if self.moe_jitter_eps is None:
720
+ raise RuntimeError('The router does not have moe_jitter_eps set.')
721
+ low = 1.0 - self.moe_jitter_eps
722
+ high = 1.0 + self.moe_jitter_eps
723
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
724
+ return low + noise * (high - low)
725
+
726
+ def forward(
727
+ self, x: torch.Tensor
728
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
729
+ if self.training and self.moe_jitter_eps is not None:
730
+ x = x * self.jitter(x)
731
+
732
+ weights = self.layer(x.view(-1,
733
+ x.shape[-1])).softmax(dim=-1,
734
+ dtype=torch.float32)
735
+ top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
736
+
737
+ if self.moe_normalize_expert_weights:
738
+ top_weights = top_weights / torch.norm(
739
+ top_weights,
740
+ p=self.moe_normalize_expert_weights,
741
+ dim=-1,
742
+ keepdim=True)
743
+
744
+ if self.uniform_expert_assignment:
745
+ with torch.no_grad():
746
+ uniform_tensor = torch.arange(
747
+ 0,
748
+ top_experts.numel(),
749
+ device=top_experts.device,
750
+ dtype=top_experts.dtype) % self.moe_num_experts
751
+ top_experts = uniform_tensor.reshape(top_experts.shape)
752
+ # Note, weights and top_weights are not changed
753
+
754
+ weights = weights.to(x.dtype)
755
+ top_weights = top_weights.to(x.dtype)
756
+ return weights, top_weights, top_experts # type: ignore
757
+
758
+
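+ # Illustrative sketch (hypothetical helper, not used by the model): the router
+ # flattens tokens, softmaxes over moe_num_experts, and keeps the moe_top_k
+ # probabilities and expert indices per token.
+ def _router_example():
+     router = DbrxRouter(hidden_size=8, moe_num_experts=4, moe_top_k=2,
+                         moe_jitter_eps=None, moe_normalize_expert_weights=1.0,
+                         uniform_expert_assignment=False)
+     x = torch.randn(2, 5, 8)  # (batch, seq_len, hidden_size)
+     weights, top_weights, top_experts = router(x)
+     # weights: (10, 4); top_weights, top_experts: (10, 2). With p=1
+     # normalization, top_weights sum to 1 for each token.
+     return top_experts
+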
759
+ class DbrxMLP(nn.Module):
760
+
761
+ def __init__(self, hidden_size: int, ffn_hidden_size: int, ffn_act_fn: dict):
762
+ super().__init__()
763
+
764
+ self.w1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
765
+ self.v1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
766
+ self.w2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
767
+ self.activation_fn = resolve_ffn_act_fn(ffn_act_fn)
768
+
769
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
770
+
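+ # GLU-style feed-forward: activation_fn(w1(x)) gates v1(x) elementwise,
+ # then w2 projects back to hidden_size.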
771
+ return self.w2(self.activation_fn(self.w1(x)) * self.v1(x))
772
+
773
+
774
+ class DbrxExperts(nn.Module):
775
+
776
+ def __init__(self, hidden_size: int, ffn_hidden_size: int,
777
+ moe_num_experts: int, ffn_act_fn: dict):
778
+ super().__init__()
779
+ self.moe_num_experts = moe_num_experts
780
+ self.mlp = nn.ModuleList([DbrxMLP(hidden_size, ffn_hidden_size, ffn_act_fn) for _ in range(moe_num_experts)])
781
+
782
+ def forward(self, x: torch.Tensor, weights: torch.Tensor,
783
+ top_weights: torch.Tensor,
784
+ top_experts: torch.LongTensor) -> torch.Tensor:
785
+ bsz, q_len, hidden_size = x.shape
786
+ x = x.view(-1, hidden_size)
787
+ out = torch.zeros_like(x)
788
+
789
+ expert_mask = nn.functional.one_hot(
790
+ top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
791
+ for expert_idx in range(0, self.moe_num_experts):
792
+ topk_idx, token_idx = torch.where(expert_mask[expert_idx])
793
+ if token_idx.shape[0] == 0:
794
+ continue
795
+
796
+ expert_tokens = x[None, token_idx].reshape(-1, hidden_size)
797
+ expert_out = self.mlp[expert_idx](expert_tokens) * top_weights[token_idx, topk_idx, None]
798
+
799
+ out.index_add_(0, token_idx, expert_out)
800
+
801
+ out = out.reshape(bsz, q_len, hidden_size)
802
+ return out
803
+
804
+
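+ # Illustrative sketch (hypothetical helper, not used by the model): each expert
+ # MLP runs only on the tokens routed to it, and the weighted outputs are
+ # scattered back into place with index_add_.
+ def _experts_example():
+     hidden_size, ffn_hidden_size, num_experts = 8, 16, 4
+     experts = DbrxExperts(hidden_size, ffn_hidden_size, num_experts,
+                           ffn_act_fn={'name': 'silu'})
+     router = DbrxRouter(hidden_size, num_experts, moe_top_k=2,
+                         moe_jitter_eps=None, moe_normalize_expert_weights=1.0,
+                         uniform_expert_assignment=False)
+     x = torch.randn(2, 5, hidden_size)
+     weights, top_weights, top_experts = router(x)
+     return experts(x, weights, top_weights, top_experts)  # same shape as x
+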
805
+ class DbrxFFN(nn.Module):
806
+
807
+ def __init__(self, hidden_size: int, ffn_config: DbrxFFNConfig):
808
+ super().__init__()
809
+
810
+ self.router = DbrxRouter(
811
+ hidden_size,
812
+ moe_num_experts=ffn_config.moe_num_experts,
813
+ moe_top_k=ffn_config.moe_top_k,
814
+ moe_jitter_eps=ffn_config.moe_jitter_eps,
815
+ moe_normalize_expert_weights=ffn_config.
816
+ moe_normalize_expert_weights,
817
+ uniform_expert_assignment=ffn_config.uniform_expert_assignment,
818
+ )
819
+
820
+ self.experts = DbrxExperts(
821
+ hidden_size=hidden_size,
822
+ ffn_hidden_size=ffn_config.ffn_hidden_size,
823
+ moe_num_experts=ffn_config.moe_num_experts,
824
+ ffn_act_fn=ffn_config.ffn_act_fn,
825
+ )
826
+
827
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
828
+ weights, top_weights, top_experts = self.router(x)
829
+ out = self.experts(x, weights, top_weights, top_experts)
830
+ return out, weights
831
+
832
+
833
+ class DbrxBlock(nn.Module):
834
+
835
+ def __init__(self, config: DbrxConfig, block_idx: int):
836
+ super().__init__()
837
+ self.hidden_size = config.d_model
838
+ self.resid_pdrop = config.resid_pdrop
839
+ self.block_idx = block_idx
840
+ self.norm_attn_norm = DbrxNormAttentionNorm(
841
+ hidden_size=config.d_model,
842
+ num_heads=config.n_heads,
843
+ max_position_embeddings=config.max_seq_len,
844
+ resid_pdrop=config.resid_pdrop,
845
+ attn_implementation=config._attn_implementation,
846
+ attn_config=config.attn_config,
847
+ block_idx=block_idx,
848
+ )
849
+ self.ffn = DbrxFFN(hidden_size=config.d_model,
850
+ ffn_config=config.ffn_config)
851
+
852
+ def forward(
853
+ self,
854
+ hidden_states: torch.Tensor,
855
+ position_ids: torch.LongTensor,
856
+ attention_mask: Optional[torch.Tensor] = None,
857
+ past_key_value: Optional[Cache] = None,
858
+ output_attentions: Optional[bool] = False,
859
+ output_router_logits: Optional[bool] = False,
860
+ use_cache: Optional[bool] = False,
861
+ cache_position: Optional[torch.LongTensor] = None,
862
+ **kwargs: Any,
863
+ ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]],
864
+ Tuple[torch.Tensor, Optional[Cache]], Tuple[
865
+ torch.Tensor, Optional[torch.Tensor], Optional[Cache]],
866
+ Tuple[torch.Tensor, Optional[torch.Tensor],
867
+ Optional[torch.Tensor]], Tuple[
868
+ torch.Tensor, Optional[Cache], Optional[torch.Tensor]],
869
+ Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache],
870
+ Optional[torch.Tensor]],]:
871
+ """Forward function for DbrxBlock.
872
+
873
+ Args:
874
+ hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
875
+ position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)`
876
+ attention_mask (`torch.Tensor`, optional): attention mask of size (batch_size, sequence_length)
877
+ if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length)
878
+ if default attention is used.
879
+ past_key_value (`Tuple(torch.Tensor)`, optional): cached past key and value projection states
880
+ output_attentions (`bool`, optional): Whether or not to return the attentions tensors of all
881
+ attention layers. See `attentions` under returned tensors for more detail.
882
+ output_router_logits (`bool`, optional): Whether or not to return the router logits.
883
+ use_cache (`bool`, optional): If set to `True`, `past_key_values` key value states are
884
+ returned and can be used to speed up decoding (see `past_key_values`).
885
+ cache_position (`torch.LongTensor`, optional): position ids of the cache
886
+ """
887
+ if 'padding_mask' in kwargs:
888
+ warnings.warn(
889
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
890
+ )
891
+
892
+ # Norm + Attention + Norm
893
+ resid_states, hidden_states, self_attn_weights, present_key_value = self.norm_attn_norm(
894
+ hidden_states=hidden_states,
895
+ attention_mask=attention_mask,
896
+ position_ids=position_ids,
897
+ past_key_value=past_key_value,
898
+ output_attentions=output_attentions,
899
+ use_cache=use_cache,
900
+ cache_position=cache_position,
901
+ **kwargs,
902
+ )
903
+
904
+ # Fully Connected
905
+ hidden_states, router_logits = self.ffn(hidden_states)
906
+ hidden_states = nn.functional.dropout(hidden_states,
907
+ p=self.resid_pdrop,
908
+ training=self.training)
909
+ hidden_states = resid_states + hidden_states
910
+
911
+ outputs = (hidden_states,)
912
+
913
+ if output_attentions:
914
+ outputs += (self_attn_weights,)
915
+
916
+ if use_cache:
917
+ outputs += (present_key_value,)
918
+
919
+ if output_router_logits:
920
+ outputs += (router_logits,)
921
+
922
+ return outputs
923
+
924
+
925
+ class DbrxPreTrainedModel(PreTrainedModel):
926
+ config_class = DbrxConfig
927
+ base_model_prefix = 'transformer'
928
+ supports_gradient_checkpointing = True
929
+ _no_split_modules = ['DbrxBlock']
930
+ _skip_keys_device_placement = ['past_key_values']
931
+ _supports_flash_attn_2 = True
932
+ _supports_sdpa = False
933
+ _supports_cache_class = True
934
+
935
+ def _init_weights(self, module: nn.Module):
936
+ std = self.config.initializer_range
937
+ if isinstance(module, nn.Linear):
938
+ module.weight.data.normal_(mean=0.0, std=std)
939
+ if module.bias is not None:
940
+ module.bias.data.zero_()
941
+ elif isinstance(module, nn.Embedding):
942
+ module.weight.data.normal_(mean=0.0, std=std)
943
+ if module.padding_idx is not None:
944
+ module.weight.data[module.padding_idx].zero_()
945
+ elif isinstance(module, nn.LayerNorm):
946
+ module.weight.data.normal_(mean=0.0, std=std)
947
+ if module.bias is not None:
948
+ module.bias.data.zero_()
949
+
950
+ def _setup_cache(self, cache_cls: Any, max_batch_size: int,
951
+ max_cache_len: int): # TODO: how to set var type of class?
952
+ if self.config._attn_implementation == 'flash_attention_2' and cache_cls == StaticCache:
953
+ raise ValueError(
954
+ '`static` cache implementation is not compatible with ' +
955
+ '`attn_implementation==flash_attention_2`. Make sure to use ' +
956
+ '`sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers.'
957
+ )
958
+
959
+ for block in self.transformer.blocks:
960
+ device = block.norm_attn_norm.norm_1.weight.device
961
+ if hasattr(self.config, '_pre_quantization_dtype'):
962
+ dtype = self.config._pre_quantization_dtype
963
+ else:
964
+ dtype = block.norm_attn_norm.attn.out_proj.weight.dtype
965
+ block.norm_attn_norm.attn.past_key_value = cache_cls(self.config,
966
+ max_batch_size,
967
+ max_cache_len,
968
+ device=device,
969
+ dtype=dtype)
970
+
971
+ def _reset_cache(self):
972
+ for block in self.transformer.blocks:
973
+ block.norm_attn_norm.attn.past_key_value = None
974
+
975
+
976
+ class DbrxModel(DbrxPreTrainedModel):
977
+ """Transformer decoder consisting of *config.num_hidden_layers*
978
+
979
+ [`DbrxBlock`] layers.
980
+
981
+ Args:
982
+ config: DbrxConfig
983
+ """
984
+
985
+ def __init__(self, config: DbrxConfig):
986
+ super().__init__(config)
987
+ self.padding_idx = config.pad_token_id
988
+ self.vocab_size = config.vocab_size
989
+ self.emb_pdrop = config.emb_pdrop
990
+
991
+ self.wte = nn.Embedding(config.vocab_size, config.d_model,
992
+ self.padding_idx)
993
+ self.blocks = nn.ModuleList([
994
+ DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)
995
+ ])
996
+ self.norm_f = nn.LayerNorm(config.d_model, bias=False)
997
+ self.gradient_checkpointing = False
998
+
999
+ # Initialize weights and apply final processing
1000
+ self.post_init()
1001
+
1002
+ def get_input_embeddings(self) -> nn.Embedding:
1003
+ return self.wte
1004
+
1005
+ def set_input_embeddings(self, value: nn.Embedding):
1006
+ self.wte = value
1007
+
1008
+ def _autocast_input_embeddings(self,
1009
+ inputs_embeds: torch.Tensor) -> torch.Tensor:
1010
+ if inputs_embeds.device.type == 'cuda' and torch.is_autocast_enabled():
1011
+ return inputs_embeds.to(dtype=torch.get_autocast_gpu_dtype())
1012
+ elif inputs_embeds.device.type == 'cpu' and torch.is_autocast_cpu_enabled(
1013
+ ):
1014
+ return inputs_embeds.to(dtype=torch.get_autocast_cpu_dtype())
1015
+ else:
1016
+ return inputs_embeds
1017
+
1018
+ def forward(
1019
+ self,
1020
+ input_ids: Optional[torch.LongTensor] = None,
1021
+ attention_mask: Optional[torch.Tensor] = None,
1022
+ position_ids: Optional[torch.LongTensor] = None,
1023
+ past_key_values: Optional[Cache] = None,
1024
+ inputs_embeds: Optional[torch.Tensor] = None,
1025
+ use_cache: Optional[bool] = None,
1026
+ output_attentions: Optional[bool] = None,
1027
+ output_hidden_states: Optional[bool] = None,
1028
+ output_router_logits: Optional[bool] = None,
1029
+ return_dict: Optional[bool] = None,
1030
+ cache_position: Optional[torch.LongTensor] = None,
1031
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
1032
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1033
+ output_hidden_states = (output_hidden_states
1034
+ if output_hidden_states is not None else
1035
+ self.config.output_hidden_states)
1036
+ output_router_logits = (output_router_logits
1037
+ if output_router_logits is not None else
1038
+ self.config.output_router_logits)
1039
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1040
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1041
+
1042
+ if (input_ids is None) ^ (inputs_embeds is not None):
1043
+ raise ValueError(
1044
+ 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one'
1045
+ )
1046
+
1047
+ if self.gradient_checkpointing and self.training and use_cache:
1048
+ logger.warning_once(
1049
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
1050
+ )
1051
+ use_cache = False
1052
+
1053
+ if inputs_embeds is None:
1054
+ inputs_embeds = self.wte(input_ids)
1055
+
1056
+ inputs_embeds = self._autocast_input_embeddings(
1057
+ inputs_embeds) # type: ignore
1058
+ inputs_embeds = nn.functional.dropout(inputs_embeds,
1059
+ p=self.emb_pdrop,
1060
+ training=self.training)
1061
+
1062
+ past_seen_tokens = 0
1063
+ if use_cache: # kept for BC (cache positions)
1064
+ if not isinstance(past_key_values, StaticCache):
1065
+ past_key_values = DynamicCache.from_legacy_cache(
1066
+ past_key_values)
1067
+ past_seen_tokens = past_key_values.get_seq_length( # type: ignore
1068
+ )
1069
+
1070
+ if cache_position is None:
1071
+ if isinstance(past_key_values, StaticCache):
1072
+ raise ValueError(
1073
+ 'cache_position is a required argument when using StaticCache.'
1074
+ )
1075
+ cache_position = torch.arange( # type: ignore
1076
+ past_seen_tokens,
1077
+ past_seen_tokens + inputs_embeds.shape[1],
1078
+ device=inputs_embeds.device)
1079
+
1080
+ if position_ids is None:
1081
+ position_ids = cache_position.unsqueeze(0) # type: ignore
1082
+
1083
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
1084
+ cache_position) # type: ignore
1085
+
1086
+ # embed positions
1087
+ hidden_states = inputs_embeds
1088
+
1089
+ # decoder layers
1090
+ all_hidden_states = () if output_hidden_states else None
1091
+ all_self_attns = () if output_attentions else None
1092
+ all_router_logits = () if output_router_logits else None
1093
+ next_decoder_cache = None
1094
+
1095
+ for block in self.blocks:
1096
+ if output_hidden_states:
1097
+ all_hidden_states += (hidden_states,) # type: ignore
1098
+
1099
+ if self.gradient_checkpointing and self.training:
1100
+ block_outputs = self._gradient_checkpointing_func(
1101
+ block.__call__,
1102
+ hidden_states,
1103
+ attention_mask=causal_mask,
1104
+ position_ids=position_ids,
1105
+ past_key_value=past_key_values,
1106
+ output_attentions=output_attentions,
1107
+ output_router_logits=output_router_logits,
1108
+ use_cache=use_cache,
1109
+ cache_position=cache_position,
1110
+ )
1111
+ else:
1112
+ block_outputs = block(
1113
+ hidden_states,
1114
+ attention_mask=causal_mask,
1115
+ position_ids=position_ids,
1116
+ past_key_value=past_key_values,
1117
+ output_attentions=output_attentions,
1118
+ output_router_logits=output_router_logits,
1119
+ use_cache=use_cache,
1120
+ cache_position=cache_position,
1121
+ )
1122
+
1123
+ hidden_states = block_outputs[0]
1124
+
1125
+ if use_cache:
1126
+ next_decoder_cache = block_outputs[
1127
+ 2 if output_attentions else 1]
1128
+
1129
+ if output_attentions:
1130
+ all_self_attns += (block_outputs[1],) # type: ignore
1131
+
1132
+ if output_router_logits:
1133
+ all_router_logits += (block_outputs[-1],) # type: ignore
1134
+
1135
+ hidden_states = self.norm_f(hidden_states)
1136
+
1137
+ # add hidden states from the last decoder layer
1138
+ if output_hidden_states:
1139
+ all_hidden_states += (hidden_states,) # type: ignore
1140
+
1141
+ next_cache = None
1142
+ if use_cache:
1143
+ next_cache = (
1144
+ next_decoder_cache.to_legacy_cache() # type: ignore
1145
+ if isinstance(next_decoder_cache, Cache) else
1146
+ next_decoder_cache)
1147
+ if not return_dict:
1148
+ return tuple(v for v in [
1149
+ hidden_states, next_cache, all_hidden_states, all_self_attns,
1150
+ all_router_logits
1151
+ ] if v is not None)
1152
+ return MoeModelOutputWithPast(
1153
+ last_hidden_state=hidden_states,
1154
+ past_key_values=next_cache,
1155
+ hidden_states=all_hidden_states,
1156
+ attentions=all_self_attns,
1157
+ router_logits=all_router_logits,
1158
+ )
1159
+
1160
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1161
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1162
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1163
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1164
+ def _update_causal_mask(
1165
+ self, attention_mask: Optional[torch.Tensor],
1166
+ input_tensor: torch.Tensor,
1167
+ cache_position: torch.Tensor) -> Optional[torch.Tensor]:
1168
+ if self.config._attn_implementation == 'flash_attention_2':
1169
+ if attention_mask is not None and 0.0 in attention_mask:
1170
+ return attention_mask
1171
+ return None
1172
+
1173
+ dtype, device = input_tensor.dtype, input_tensor.device
1174
+ min_dtype = torch.finfo(dtype).min
1175
+ sequence_length = input_tensor.shape[1]
1176
+ if hasattr(self.blocks[0].norm_attn_norm.attn,
1177
+ 'past_key_value'): # static cache
1178
+ target_length = self.config.max_position_embeddings
1179
+ else: # dynamic cache
1180
+ target_length = (attention_mask.shape[-1] if isinstance(
1181
+ attention_mask, torch.Tensor) else cache_position[-1] + 1)
1182
+ target_length = int(target_length)
1183
+
1184
+ causal_mask = torch.full((sequence_length, target_length),
1185
+ fill_value=min_dtype,
1186
+ dtype=dtype,
1187
+ device=device)
1188
+ if sequence_length != 1:
1189
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1190
+ causal_mask *= torch.arange(
1191
+ target_length, device=device) > cache_position.reshape(-1, 1)
1192
+ causal_mask = causal_mask[None,
1193
+ None, :, :].expand(input_tensor.shape[0], 1,
1194
+ -1, -1)
1195
+ if attention_mask is not None:
1196
+ causal_mask = causal_mask.clone(
1197
+ ) # copy to contiguous memory for in-place edit
1198
+ if attention_mask.dim() == 2:
1199
+ mask_length = attention_mask.shape[-1]
1200
+ padding_mask = causal_mask[..., :mask_length].eq(
1201
+ 0.0) * attention_mask[:, None, None, :].eq(0.0)
1202
+ causal_mask[..., :mask_length] = causal_mask[
1203
+ ..., :mask_length].masked_fill(padding_mask, min_dtype)
1204
+ elif attention_mask.dim() == 4:
1205
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1206
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
1207
+ if attention_mask.shape[
1208
+ -2] < cache_position[0] + sequence_length:
1209
+ offset = cache_position[0]
1210
+ else:
1211
+ offset = 0
1212
+ mask_shape = attention_mask.shape
1213
+ mask_slice = (attention_mask.eq(0.0)).to(
1214
+ dtype=dtype) * min_dtype
1215
+ causal_mask[:mask_shape[0], :mask_shape[1],
1216
+ offset:mask_shape[2] +
1217
+ offset, :mask_shape[3]] = mask_slice
1218
+
1219
+ if (self.config._attn_implementation == 'sdpa' and
1220
+ attention_mask is not None and
1221
+ attention_mask.device.type == 'cuda'):
1222
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1223
+ is_tracing = (
1224
+ torch.jit.is_tracing() or
1225
+ isinstance(input_tensor, torch.fx.Proxy) or # type: ignore
1226
+ (hasattr(torch, '_dynamo') and torch._dynamo.is_compiling()))
1227
+ if not is_tracing and torch.any(attention_mask != 1):
1228
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1229
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1230
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1231
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1232
+ causal_mask, min_dtype)
1233
+
1234
+ return causal_mask
1235
+
1236
+
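+ # Illustrative sketch (standalone, mirroring the mask construction above): the
+ # causal mask holds 0 where attention is allowed and the dtype minimum where it
+ # is blocked, and is added to the attention scores.
+ def _causal_mask_example(seq_len: int = 4, dtype: torch.dtype = torch.float32):
+     min_dtype = torch.finfo(dtype).min
+     mask = torch.full((seq_len, seq_len), min_dtype, dtype=dtype)
+     mask = torch.triu(mask, diagonal=1)  # block strictly-future positions
+     return mask[None, None, :, :]        # (1, 1, seq_len, seq_len)
+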
1237
+ class DbrxForCausalLM(DbrxPreTrainedModel):
1238
+
1239
+ def __init__(self, config: DbrxConfig):
1240
+ super().__init__(config)
1241
+ self.transformer = DbrxModel(config)
1242
+ self.vocab_size = config.vocab_size
1243
+ self.lm_head = nn.Linear(config.hidden_size,
1244
+ config.vocab_size,
1245
+ bias=False)
1246
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1247
+ self.num_experts = config.ffn_config.moe_num_experts
1248
+ self.num_experts_per_tok = config.ffn_config.moe_top_k
1249
+
1250
+ # Initialize weights and apply final processing
1251
+ self.post_init()
1252
+
1253
+ def get_input_embeddings(self) -> nn.Embedding:
1254
+ return self.transformer.get_input_embeddings()
1255
+
1256
+ def set_input_embeddings(self, value: nn.Embedding):
1257
+ self.transformer.set_input_embeddings(value)
1258
+
1259
+ def get_output_embeddings(self) -> nn.Linear:
1260
+ return self.lm_head
1261
+
1262
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
1263
+ self.lm_head = new_embeddings
1264
+
1265
+ def set_decoder(self, decoder: DbrxModel):
1266
+ self.transformer = decoder
1267
+
1268
+ def get_decoder(self) -> DbrxModel:
1269
+ return self.transformer
1270
+
1271
+ def forward(
1272
+ self,
1273
+ input_ids: Optional[torch.LongTensor] = None,
1274
+ attention_mask: Optional[torch.Tensor] = None,
1275
+ position_ids: Optional[torch.LongTensor] = None,
1276
+ past_key_values: Optional[Cache] = None,
1277
+ inputs_embeds: Optional[torch.Tensor] = None,
1278
+ labels: Optional[torch.LongTensor] = None,
1279
+ use_cache: Optional[bool] = None,
1280
+ output_attentions: Optional[bool] = None,
1281
+ output_hidden_states: Optional[bool] = None,
1282
+ output_router_logits: Optional[bool] = None,
1283
+ return_dict: Optional[bool] = None,
1284
+ cache_position: Optional[torch.LongTensor] = None,
1285
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1286
+ r"""Forward function for causal language modeling.
1287
+
1288
+ Example:
1289
+ ```python
1290
+ >>> from transformers import AutoTokenizer, DbrxForCausalLM
1291
+
1292
+ >>> model = DbrxForCausalLM.from_pretrained("databricks/dbrx")
1293
+ >>> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx")
1294
+
1295
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1296
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1297
+
1298
+ >>> # Generate
1299
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1300
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1301
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1302
+ ```
1303
+ """
1304
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1305
+ output_hidden_states = (output_hidden_states
1306
+ if output_hidden_states is not None else
1307
+ self.config.output_hidden_states)
1308
+ output_router_logits = (output_router_logits
1309
+ if output_router_logits is not None else
1310
+ self.config.output_router_logits)
1311
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1312
+
1313
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1314
+ outputs = self.transformer(
1315
+ input_ids=input_ids,
1316
+ attention_mask=attention_mask,
1317
+ position_ids=position_ids,
1318
+ past_key_values=past_key_values,
1319
+ inputs_embeds=inputs_embeds,
1320
+ use_cache=use_cache,
1321
+ output_attentions=output_attentions,
1322
+ output_hidden_states=output_hidden_states,
1323
+ output_router_logits=output_router_logits,
1324
+ return_dict=return_dict,
1325
+ cache_position=cache_position,
1326
+ )
1327
+
1328
+ hidden_states = outputs[0]
1329
+ logits = self.lm_head(hidden_states)
1330
+
1331
+ loss = None
1332
+ if labels is not None:
1333
+ # Shift so that tokens < n predict n
1334
+ shift_logits = logits[..., :-1, :].contiguous()
1335
+ shift_labels = labels[..., 1:].contiguous()
1336
+ # Flatten the tokens
1337
+ loss_fct = nn.CrossEntropyLoss()
1338
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1339
+ shift_labels = shift_labels.view(-1)
1340
+ # Enable model parallelism
1341
+ shift_labels = shift_labels.to(shift_logits.device)
1342
+ loss = loss_fct(shift_logits, shift_labels)
1343
+
1344
+ aux_loss = None
1345
+ if output_router_logits:
1346
+ aux_loss = load_balancing_loss_func(
1347
+ outputs.router_logits if return_dict else outputs[-1],
1348
+ self.num_experts,
1349
+ self.num_experts_per_tok,
1350
+ attention_mask,
1351
+ )
1352
+ if labels is not None and loss is not None:
1353
+ loss += self.router_aux_loss_coef * aux_loss.to(
1354
+ loss.device) # make sure to reside in the same device
1355
+
1356
+ if not return_dict:
1357
+ output = (logits,) + outputs[1:]
1358
+ return (loss,) + output if loss is not None else output
1359
+
1360
+ return MoeCausalLMOutputWithPast(
1361
+ loss=loss,
1362
+ aux_loss=aux_loss,
1363
+ logits=logits,
1364
+ past_key_values=outputs.past_key_values,
1365
+ hidden_states=outputs.hidden_states,
1366
+ attentions=outputs.attentions,
1367
+ router_logits=outputs.router_logits,
1368
+ )
1369
+
1370
+ def prepare_inputs_for_generation(
1371
+ self,
1372
+ input_ids: torch.Tensor,
1373
+ past_key_values: Optional[Cache] = None,
1374
+ attention_mask: Optional[torch.Tensor] = None,
1375
+ inputs_embeds: Optional[torch.Tensor] = None,
1376
+ **kwargs: Any) -> Dict[str, Any]:
1377
+ past_length = 0
1378
+ if past_key_values is not None:
1379
+ if isinstance(past_key_values, Cache):
1380
+ cache_length = past_key_values.get_seq_length()
1381
+ past_length = past_key_values.seen_tokens
1382
+ max_cache_length = past_key_values.get_max_length()
1383
+ else:
1384
+ cache_length = past_length = past_key_values[0][0].shape[2]
1385
+ max_cache_length = None
1386
+
1387
+ # Keep only the unprocessed tokens:
1388
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1389
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1390
+ # input)
1391
+ if attention_mask is not None and attention_mask.shape[
1392
+ 1] > input_ids.shape[1]:
1393
+ input_ids = input_ids[:,
1394
+ -(attention_mask.shape[1] - past_length):]
1395
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1396
+ # input_ids based on the past_length.
1397
+ elif past_length < input_ids.shape[1]:
1398
+ input_ids = input_ids[:, past_length:]
1399
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1400
+
1401
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1402
+ if (max_cache_length is not None and attention_mask is not None and
1403
+ cache_length + input_ids.shape[1] > max_cache_length):
1404
+ attention_mask = attention_mask[:, -max_cache_length:]
1405
+
1406
+ position_ids = kwargs.get('position_ids', None)
1407
+ if attention_mask is not None and position_ids is None:
1408
+ # create position_ids on the fly for batch generation
1409
+ position_ids = attention_mask.long().cumsum(-1) - 1
1410
+ position_ids.masked_fill_(attention_mask == 0, 1)
1411
+ if past_key_values:
1412
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1413
+
1414
+ if self.generation_config.cache_implementation == 'static':
1415
+ # generation with static cache
1416
+ cache_position = kwargs.get('cache_position', None)
1417
+ if cache_position is None:
1418
+ past_length = 0
1419
+ else:
1420
+ past_length = cache_position[-1] + 1
1421
+ input_ids = input_ids[:, past_length:]
1422
+ position_ids = position_ids[:,
1423
+ past_length:] if position_ids is not None else None
1424
+
1425
+ # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
1426
+ # same goes for position ids. Could also help with continued generation.
1427
+ input_length = position_ids.shape[
1428
+ -1] if position_ids is not None else input_ids.shape[-1]
1429
+ cache_position = torch.arange(past_length,
1430
+ past_length + input_length,
1431
+ device=input_ids.device)
1432
+ position_ids = position_ids.contiguous(
1433
+ ) if position_ids is not None else None
1434
+
1435
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1436
+ if inputs_embeds is not None and past_key_values is None:
1437
+ model_inputs = {'inputs_embeds': inputs_embeds}
1438
+ else:
1439
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1440
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1441
+ # TODO: use `next_tokens` directly instead.
1442
+ model_inputs = {'input_ids': input_ids.contiguous()}
1443
+
1444
+ model_inputs.update(
1445
+ { # type: ignore
1446
+ 'position_ids': position_ids,
1447
+ 'cache_position': cache_position,
1448
+ 'past_key_values': past_key_values,
1449
+ 'use_cache': kwargs.get('use_cache'),
1450
+ 'attention_mask': attention_mask,
1451
+ }
1452
+ )
1453
+ return model_inputs
1454
+
1455
+ @staticmethod
1456
+ def _reorder_cache(past_key_values: Cache, beam_idx: torch.LongTensor):
1457
+ reordered_past = ()
1458
+ for layer_past in past_key_values:
1459
+ reordered_past += (tuple(
1460
+ past_state.index_select(0, beam_idx.to(past_state.device))
1461
+ for past_state in layer_past),)
1462
+ return reordered_past