Crystalcareai committed on
Commit
66c1565
1 Parent(s): d66d4ca

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +622 -787
modeling_gemmoe.py CHANGED
@@ -26,14 +26,11 @@ from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
 
28
  from transformers.activations import ACT2FN
29
- from transformers.cache_utils import Cache, DynamicCache
30
  from transformers.modeling_attn_mask_utils import (
31
- AttentionMaskConverter,
32
- _prepare_4d_attention_mask,
33
  _prepare_4d_causal_attention_mask,
34
- _prepare_4d_causal_attention_mask_for_sdpa,
35
  )
36
- from transformers.modeling_outputs import SequenceClassifierOutputWithPast, BaseModelOutputWithPast, CausalLMOutputWithPast
37
  from transformers.modeling_utils import PreTrainedModel
38
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
39
  from transformers.utils import (
@@ -63,6 +60,7 @@ if is_torch_fx_available():
63
 
64
  _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
65
 
 
66
  logger = logging.get_logger(__name__)
67
 
68
  _CONFIG_FOR_DOC = "GemmoeConfig"
@@ -158,42 +156,8 @@ def _get_unpad_data(attention_mask):
158
  max_seqlen_in_batch,
159
  )
160
 
161
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
162
- warnings.warn(
163
- "Calling `transformers.models.Gemmoe.modeling_Gemmoe._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
164
- )
165
- return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
166
-
167
- def _make_causal_mask(
168
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
169
- ):
170
- warnings.warn(
171
- "Calling `transformers.models.Gemmoe.modeling_Gemmoe._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.Gemmoe.modeling_Gemmoe.AttentionMaskConverter._make_causal_mask"
172
- )
173
- return AttentionMaskConverter._make_causal_mask(
174
- input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
175
- )
176
-
177
 
178
 
179
- class GemmoeRMSNorm(nn.Module):
180
- def __init__(self, hidden_size, eps=1e-6):
181
- """
182
- GemmoeRMSNorm is equivalent to T5LayerNorm
183
- """
184
- super().__init__()
185
- self.weight = nn.Parameter(torch.ones(hidden_size))
186
- self.variance_epsilon = eps
187
-
188
- def forward(self, hidden_states):
189
- input_dtype = hidden_states.dtype
190
- hidden_states = hidden_states.to(torch.float32)
191
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
192
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
193
- return self.weight * hidden_states.to(input_dtype)
194
-
195
- ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
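For reference, the duplicate class removed above computes a T5-style RMS normalization; a minimal standalone sketch of that computation (simplified, not the exact module):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Accumulate in float32 for stability, then cast back to the input dtype.
    input_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)  # mean square over the hidden dimension
    x = x * torch.rsqrt(variance + eps)         # divide by the root-mean-square
    return weight * x.to(input_dtype)           # learned per-channel scale
```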
196
-
197
  class GemmoeRMSNorm(nn.Module):
198
  def __init__(self, dim: int, eps: float = 1e-6):
199
  super().__init__()
@@ -242,249 +206,22 @@ class GemmoeRotaryEmbedding(nn.Module):
242
  self.sin_cached[:seq_len],
243
  )
244
 
245
-
246
- class GemmoeLinearScalingRotaryEmbedding(GemmoeRotaryEmbedding):
247
- """GemmoeRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
248
-
249
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
250
- self.scaling_factor = scaling_factor
251
- super().__init__(dim, max_position_embeddings, base, device)
252
-
253
- def _set_cos_sin_cache(self, seq_len, device, dtype):
254
- self.max_seq_len_cached = seq_len
255
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
256
- t = t / self.scaling_factor
257
-
258
- freqs = torch.outer(t, self.inv_freq)
259
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
260
- emb = torch.cat((freqs, freqs), dim=-1)
261
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
262
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
263
-
264
- class GemmoeDynamicNTKScalingRotaryEmbedding(GemmoeRotaryEmbedding):
265
- """GemmoeRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
266
-
267
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
268
- self.scaling_factor = scaling_factor
269
- super().__init__(dim, max_position_embeddings, base, device)
270
-
271
- def _set_cos_sin_cache(self, seq_len, device, dtype):
272
- self.max_seq_len_cached = seq_len
273
-
274
- if seq_len > self.max_position_embeddings:
275
- base = self.base * (
276
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
277
- ) ** (self.dim / (self.dim - 2))
278
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
279
- self.register_buffer("inv_freq", inv_freq, persistent=False)
280
-
281
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
282
-
283
- freqs = torch.outer(t, self.inv_freq)
284
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
285
- emb = torch.cat((freqs, freqs), dim=-1)
286
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
287
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
288
-
289
  def rotate_half(x):
290
  """Rotates half the hidden dims of the input."""
291
  x1 = x[..., : x.shape[-1] // 2]
292
  x2 = x[..., x.shape[-1] // 2 :]
293
  return torch.cat((-x2, x1), dim=-1)
294
 
295
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
296
- """Applies Rotary Position Embedding to the query and key tensors.
297
-
298
- Args:
299
- q (`torch.Tensor`): The query tensor.
300
- k (`torch.Tensor`): The key tensor.
301
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
302
- sin (`torch.Tensor`): The sine part of the rotary embedding.
303
- position_ids (`torch.Tensor`):
304
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
305
- used to pass offsetted position ids when working with a KV-cache.
306
- unsqueeze_dim (`int`, *optional*, defaults to 1):
307
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
308
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
309
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
310
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
311
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
312
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
313
- Returns:
314
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
315
- """
316
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
317
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
318
  q_embed = (q * cos) + (rotate_half(q) * sin)
319
  k_embed = (k * cos) + (rotate_half(k) * sin)
320
  return q_embed, k_embed
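The removed docstring above explains how `cos[position_ids]` and `sin[position_ids]` are gathered and unsqueezed so they broadcast against `q` and `k`; a minimal end-to-end sketch of that application, assuming `q`/`k` of shape `[batch, heads, seq, head_dim]` and cached tables of shape `[max_seq, head_dim]`:

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos_cached, sin_cached, position_ids, unsqueeze_dim=1):
    # Gather the rows for each position, then add a heads axis so the tables broadcast.
    cos = cos_cached[position_ids].unsqueeze(unsqueeze_dim)  # [batch, 1, seq, head_dim]
    sin = sin_cached[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```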
321
 
322
- class GemmoeMLP(nn.Module):
323
- def __init__(self, config, hidden_size = None, intermediate_size = None):
324
- super().__init__()
325
- self.config = config
326
- self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
327
- self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
328
-
329
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
330
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
331
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
332
- self.act_fn = ACT2FN[config.hidden_act]
333
-
334
- def forward(self, x):
335
- if self.config.pretraining_tp > 1:
336
- slice = self.intermediate_size // self.config.pretraining_tp
337
- gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
338
- up_proj_slices = self.up_proj.weight.split(slice, dim=0)
339
- down_proj_slices = self.down_proj.weight.split(slice, dim=1)
340
-
341
- gate_proj = torch.cat(
342
- [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
343
- )
344
- up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
345
-
346
- intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
347
- down_proj = [
348
- F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
349
- ]
350
- down_proj = sum(down_proj)
351
- else:
352
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
353
-
354
- return down_proj
355
-
356
- class MoEGate(nn.Module):
357
- def __init__(self, config):
358
- super().__init__()
359
- self.config = config
360
- self.top_k = config.num_experts_per_tok
361
- self.n_routed_experts = config.n_routed_experts
362
-
363
- self.scoring_func = config.scoring_func
364
- self.alpha = config.aux_loss_alpha
365
- self.seq_aux = config.seq_aux
366
-
367
- # topk selection algorithm
368
- self.norm_topk_prob = config.norm_topk_prob
369
- self.gating_dim = config.hidden_size
370
- self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
371
- self.reset_parameters()
372
-
373
- def reset_parameters(self) -> None:
374
- import torch.nn.init as init
375
- init.kaiming_uniform_(self.weight, a=math.sqrt(5))
376
-
377
- def forward(self, hidden_states):
378
- bsz, seq_len, h = hidden_states.shape
379
- ### compute gating score
380
- hidden_states = hidden_states.view(-1, h)
381
- logits = F.linear(hidden_states, self.weight, None)
382
- if self.scoring_func == 'softmax':
383
- scores = logits.softmax(dim=-1)
384
- else:
385
- raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
386
-
387
- ### select top-k experts
388
- topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
389
-
390
- ### norm gate to sum 1
391
- if self.top_k > 1 and self.norm_topk_prob:
392
- denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
393
- topk_weight = topk_weight / denominator
394
-
395
- ### expert-level computation auxiliary loss
396
- if self.training and self.alpha > 0.0:
397
- scores_for_aux = scores
398
- aux_topk = self.top_k
399
- # always compute aux loss based on the naive greedy topk method
400
- topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
401
- if self.seq_aux:
402
- scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
403
- ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
404
- ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
405
- aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
406
- else:
407
- mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
408
- ce = mask_ce.float().mean(0)
409
- Pi = scores_for_aux.mean(0)
410
- fi = ce * self.n_routed_experts
411
- aux_loss = (Pi * fi).sum() * self.alpha
412
- else:
413
- aux_loss = None
414
- return topk_idx, topk_weight, aux_loss
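A compressed sketch of the routing performed by the gate above: softmax scoring followed by top-k selection and optional renormalization (the auxiliary-loss branch is omitted and names are simplified):

```python
import torch
import torch.nn.functional as F

def route_tokens(hidden_states, gate_weight, top_k, norm_topk_prob=True):
    # hidden_states: [batch, seq, hidden]; gate_weight: [n_routed_experts, hidden]
    flat = hidden_states.view(-1, hidden_states.shape[-1])
    scores = F.linear(flat, gate_weight).softmax(dim=-1)  # [tokens, n_experts]
    topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)
    if top_k > 1 and norm_topk_prob:
        topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
    return topk_idx, topk_weight
```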
415
-
416
- class AddAuxiliaryLoss(torch.autograd.Function):
417
- """
418
- The trick function of adding auxiliary (aux) loss,
419
- which includes the gradient of the aux loss during backpropagation.
420
- """
421
- @staticmethod
422
- def forward(ctx, x, loss):
423
- assert loss.numel() == 1
424
- ctx.dtype = loss.dtype
425
- ctx.required_aux_loss = loss.requires_grad
426
- return x
427
-
428
- @staticmethod
429
- def backward(ctx, grad_output):
430
- grad_loss = None
431
- if ctx.required_aux_loss:
432
- grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
433
- return grad_output, grad_loss
434
-
435
- class GemMoE(nn.Module):
436
- """
437
- A mixed expert module containing shared experts.
438
- """
439
- def __init__(self, config):
440
- super().__init__()
441
- self.config = config
442
- self.num_experts_per_tok = config.num_experts_per_tok
443
- self.experts = nn.ModuleList([GemmoeMLP(config, intermediate_size = config.moe_intermediate_size) for i in range(config.n_routed_experts)])
444
- self.gate = MoEGate(config)
445
- if config.n_shared_experts is not None:
446
- intermediate_size = config.moe_intermediate_size * config.n_shared_experts
447
- self.shared_experts = GemmoeMLP(config=config, intermediate_size = intermediate_size)
448
-
449
- def forward(self, hidden_states):
450
- identity = hidden_states
451
- orig_shape = hidden_states.shape
452
- topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
453
- hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
454
- flat_topk_idx = topk_idx.view(-1)
455
- if self.training:
456
- hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
457
- y = torch.empty_like(hidden_states)
458
- for i, expert in enumerate(self.experts):
459
- y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
460
- y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
461
- y = y.view(*orig_shape)
462
- y = AddAuxiliaryLoss.apply(y, aux_loss)
463
- else:
464
- y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
465
- if self.config.n_shared_experts is not None:
466
- y = y + self.shared_experts(identity)
467
- return y
468
-
469
- @torch.no_grad()
470
- def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
471
- expert_cache = torch.zeros_like(x)
472
- idxs = flat_expert_indices.argsort()
473
- tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
474
- token_idxs = idxs // self.num_experts_per_tok
475
- for i, end_idx in enumerate(tokens_per_expert):
476
- start_idx = 0 if i == 0 else tokens_per_expert[i-1]
477
- if start_idx == end_idx:
478
- continue
479
- expert = self.experts[i]
480
- exp_token_idx = token_idxs[start_idx:end_idx]
481
- expert_tokens = x[exp_token_idx]
482
- expert_out = expert(expert_tokens)
483
- expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
484
- expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
485
- return expert_cache
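The `moe_infer` path above sorts the flattened expert assignments so each expert processes its tokens as one contiguous slice; a small worked illustration of the index arithmetic, with hypothetical routing for three tokens and top-2 experts per token:

```python
import torch

# token 0 -> experts (1, 0), token 1 -> (0, 2), token 2 -> (1, 2)
flat_expert_indices = torch.tensor([1, 0, 0, 2, 1, 2])
num_experts_per_tok = 2

idxs = flat_expert_indices.argsort()                          # slots grouped by expert id
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)  # running slice boundaries
token_idxs = idxs // num_experts_per_tok                      # map each slot back to its token

print(tokens_per_expert.tolist())  # [2, 4, 6]: expert 0 gets slots [0:2], expert 1 [2:4], expert 2 [4:6]
print(token_idxs.tolist())         # e.g. [0, 1, 0, 2, 1, 2]: the source token behind each sorted slot
```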
486
-
487
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
488
  """
489
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
490
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
@@ -494,10 +231,15 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
494
  return hidden_states
495
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
496
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
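The docstring above claims equivalence with `torch.repeat_interleave` along the head dimension; a quick self-contained check of that claim with hypothetical shapes:

```python
import torch

batch, kv_heads, seq, head_dim, n_rep = 2, 4, 8, 16, 3
kv = torch.randn(batch, kv_heads, seq, head_dim)

expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq, head_dim)
expanded = expanded.reshape(batch, kv_heads * n_rep, seq, head_dim)

assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1))
```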
497
-
498
 
499
  class GemmoeAttention(nn.Module):
500
- """Multi-headed attention from 'Attention Is All You Need' paper"""
501
 
502
  def __init__(self, config: GemmoeConfig, layer_idx: Optional[int] = None):
503
  super().__init__()
@@ -505,62 +247,34 @@ class GemmoeAttention(nn.Module):
505
  self.layer_idx = layer_idx
506
  if layer_idx is None:
507
  logger.warning_once(
508
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
509
- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
510
  "when creating this class."
511
  )
512
-
513
  self.attention_dropout = config.attention_dropout
514
  self.hidden_size = config.hidden_size
515
  self.num_heads = config.num_attention_heads
516
- self.head_dim = self.hidden_size // self.num_heads
517
  self.num_key_value_heads = config.num_key_value_heads
518
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
519
  self.max_position_embeddings = config.max_position_embeddings
520
  self.rope_theta = config.rope_theta
521
  self.is_causal = True
522
 
523
- if (self.head_dim * self.num_heads) != self.hidden_size:
524
  raise ValueError(
525
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
526
  f" and `num_heads`: {self.num_heads})."
527
  )
528
-
529
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
530
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
531
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
532
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
533
- self._init_rope()
534
-
535
- def _init_rope(self):
536
- if self.config.rope_scaling is None:
537
- self.rotary_emb = GemmoeRotaryEmbedding(
538
- self.head_dim,
539
- max_position_embeddings=self.max_position_embeddings,
540
- base=self.rope_theta,
541
- )
542
- else:
543
- scaling_type = self.config.rope_scaling["type"]
544
- scaling_factor = self.config.rope_scaling["factor"]
545
- if scaling_type == "linear":
546
- self.rotary_emb = GemmoeLinearScalingRotaryEmbedding(
547
- self.head_dim,
548
- max_position_embeddings=self.max_position_embeddings,
549
- scaling_factor=scaling_factor,
550
- base=self.rope_theta,
551
- )
552
- elif scaling_type == "dynamic":
553
- self.rotary_emb = GemmoeDynamicNTKScalingRotaryEmbedding(
554
- self.head_dim,
555
- max_position_embeddings=self.max_position_embeddings,
556
- scaling_factor=scaling_factor,
557
- base=self.rope_theta,
558
- )
559
- else:
560
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
561
-
562
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
563
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
564
 
565
  def forward(
566
  self,
@@ -570,78 +284,64 @@ class GemmoeAttention(nn.Module):
570
  past_key_value: Optional[Cache] = None,
571
  output_attentions: bool = False,
572
  use_cache: bool = False,
 
573
  **kwargs,
574
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
575
- if "padding_mask" in kwargs:
576
- warnings.warn(
577
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
578
- )
579
-
580
- bsz, q_len, _ = hidden_states.size()
581
-
582
- if self.config.pretraining_tp > 1:
583
- key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
584
- query_slices = self.q_proj.weight.split(
585
- (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
586
- )
587
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
588
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
589
-
590
- query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
591
- query_states = torch.cat(query_states, dim=-1)
592
 
593
- key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
594
- key_states = torch.cat(key_states, dim=-1)
595
 
596
- value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
597
- value_states = torch.cat(value_states, dim=-1)
598
 
599
- else:
600
- query_states = self.q_proj(hidden_states)
601
- key_states = self.k_proj(hidden_states)
602
- value_states = self.v_proj(hidden_states)
603
 
604
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
605
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
606
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
607
 
608
- kv_seq_len = key_states.shape[-2]
609
- if past_key_value is not None:
610
- if self.layer_idx is None:
611
- raise ValueError(
612
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
613
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
614
- "with a layer index."
615
- )
616
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
617
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
618
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
619
 
620
  if past_key_value is not None:
621
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
 
622
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
623
 
624
- key_states = repeat_kv(key_states, self.num_key_value_groups)
625
- value_states = repeat_kv(value_states, self.num_key_value_groups)
626
 
627
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
628
 
629
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
630
- raise ValueError(
631
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
632
- f" {attn_weights.size()}"
633
- )
634
-
635
- if attention_mask is not None:
636
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
637
- raise ValueError(
638
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
639
- )
640
- attn_weights = attn_weights + attention_mask
641
 
642
  # upcast attention to fp32
643
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
644
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
645
  attn_output = torch.matmul(attn_weights, value_states)
646
 
647
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -651,15 +351,9 @@ class GemmoeAttention(nn.Module):
651
  )
652
 
653
  attn_output = attn_output.transpose(1, 2).contiguous()
 
654
 
655
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
656
-
657
- if self.config.pretraining_tp > 1:
658
- attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
659
- o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
660
- attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
661
- else:
662
- attn_output = self.o_proj(attn_output)
663
 
664
  if not output_attentions:
665
  attn_weights = None
@@ -672,13 +366,9 @@ class GemmoeFlashAttention2(GemmoeAttention):
672
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
673
  flash attention and deal with padding tokens in case the input contains any of them.
674
  """
675
-
676
  def __init__(self, *args, **kwargs):
677
  super().__init__(*args, **kwargs)
678
-
679
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
680
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
681
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
682
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
683
 
684
  def forward(
@@ -689,17 +379,9 @@ class GemmoeFlashAttention2(GemmoeAttention):
689
  past_key_value: Optional[Cache] = None,
690
  output_attentions: bool = False,
691
  use_cache: bool = False,
 
692
  **kwargs,
693
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
694
- # GemmoeFlashAttention2 attention does not support output_attentions
695
- if "padding_mask" in kwargs:
696
- warnings.warn(
697
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
698
- )
699
-
700
- # overwrite attention_mask with padding_mask
701
- attention_mask = kwargs.pop("padding_mask")
702
-
703
  output_attentions = False
704
 
705
  bsz, q_len, _ = hidden_states.size()
@@ -715,14 +397,13 @@ class GemmoeFlashAttention2(GemmoeAttention):
715
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
716
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
717
 
718
- kv_seq_len = key_states.shape[-2]
719
- if past_key_value is not None:
720
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
721
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
722
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
723
 
 
724
  if past_key_value is not None:
725
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
 
726
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
727
 
728
  # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
@@ -738,14 +419,13 @@ class GemmoeFlashAttention2(GemmoeAttention):
738
  # cast them back in the correct dtype just to be sure everything works as expected.
739
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
740
  # in fp32. (GemmoeRMSNorm handles it correctly)
741
-
742
  input_dtype = query_states.dtype
743
  if input_dtype == torch.float32:
 
 
744
  # Handle the case where the model is quantized
745
- if hasattr(self.config, "_pre_quantization_dtype"):
746
  target_dtype = self.config._pre_quantization_dtype
747
- elif torch.is_autocast_enabled():
748
- target_dtype = torch.get_autocast_gpu_dtype()
749
  else:
750
  target_dtype = self.q_proj.weight.dtype
751
 
@@ -754,7 +434,6 @@ class GemmoeFlashAttention2(GemmoeAttention):
754
  f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
755
  f" {target_dtype}."
756
  )
757
-
758
  query_states = query_states.to(target_dtype)
759
  key_states = key_states.to(target_dtype)
760
  value_states = value_states.to(target_dtype)
@@ -763,7 +442,7 @@ class GemmoeFlashAttention2(GemmoeAttention):
763
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
764
  )
765
 
766
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
767
  attn_output = self.o_proj(attn_output)
768
 
769
  if not output_attentions:
@@ -805,7 +484,6 @@ class GemmoeFlashAttention2(GemmoeAttention):
805
  query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
806
  query_states, key_states, value_states, attention_mask, query_length
807
  )
808
-
809
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
810
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
811
 
@@ -821,7 +499,6 @@ class GemmoeFlashAttention2(GemmoeAttention):
821
  softmax_scale=softmax_scale,
822
  causal=causal,
823
  )
824
-
825
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
826
  else:
827
  attn_output = flash_attn_func(
@@ -832,14 +509,15 @@ class GemmoeFlashAttention2(GemmoeAttention):
832
 
833
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
834
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
835
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
836
 
 
837
  key_layer = index_first_axis(
838
  key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
839
  )
840
  value_layer = index_first_axis(
841
  value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
842
  )
 
843
  if query_length == kv_seq_len:
844
  query_layer = index_first_axis(
845
  query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
@@ -871,11 +549,21 @@ class GemmoeFlashAttention2(GemmoeAttention):
871
  class GemmoeSdpaAttention(GemmoeAttention):
872
  """
873
  Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
874
- `GemmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
875
  SDPA API.
876
  """
877
 
878
- # Adapted from GemmoeAttention.forward
879
  def forward(
880
  self,
881
  hidden_states: torch.Tensor,
@@ -884,13 +572,15 @@ class GemmoeSdpaAttention(GemmoeAttention):
884
  past_key_value: Optional[Cache] = None,
885
  output_attentions: bool = False,
886
  use_cache: bool = False,
 
887
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
888
  if output_attentions:
889
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
890
- logger.warning_once(
891
- "GemmoeModel is using GemmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
892
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
893
- )
 
894
  return super().forward(
895
  hidden_states=hidden_states,
896
  attention_mask=attention_mask,
@@ -898,8 +588,9 @@ class GemmoeSdpaAttention(GemmoeAttention):
898
  past_key_value=past_key_value,
899
  output_attentions=output_attentions,
900
  use_cache=use_cache,
 
901
  )
902
-
903
  bsz, q_len, _ = hidden_states.size()
904
 
905
  query_states = self.q_proj(hidden_states)
@@ -910,46 +601,48 @@ class GemmoeSdpaAttention(GemmoeAttention):
910
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
911
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
912
 
913
- kv_seq_len = key_states.shape[-2]
914
- if past_key_value is not None:
915
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
916
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
917
-
918
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
919
 
 
920
  if past_key_value is not None:
921
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
 
922
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
923
 
924
- key_states = repeat_kv(key_states, self.num_key_value_groups)
925
- value_states = repeat_kv(value_states, self.num_key_value_groups)
926
 
927
- if attention_mask is not None:
928
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
929
- raise ValueError(
930
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
931
- )
 
 
 
932
 
933
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
934
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
935
- if query_states.device.type == "cuda" and attention_mask is not None:
936
  query_states = query_states.contiguous()
937
  key_states = key_states.contiguous()
938
  value_states = value_states.contiguous()
939
 
 
 
 
 
940
  attn_output = torch.nn.functional.scaled_dot_product_attention(
941
  query_states,
942
  key_states,
943
  value_states,
944
- attn_mask=attention_mask,
945
  dropout_p=self.attention_dropout if self.training else 0.0,
946
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
947
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
948
  )
949
 
950
  attn_output = attn_output.transpose(1, 2).contiguous()
951
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
952
-
953
  attn_output = self.o_proj(attn_output)
954
 
955
  return attn_output, None, past_key_value
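A minimal sketch of the SDPA call pattern used in this class (hypothetical shapes): with an explicit 4D additive mask `is_causal` stays `False`, and with no mask causality is only enforced when the query length is greater than 1, matching the removed comment about single-token decoding:

```python
import torch
import torch.nn.functional as F

bsz, heads, q_len, kv_len, head_dim = 1, 8, 5, 5, 64
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, kv_len, head_dim)
v = torch.randn(bsz, heads, kv_len, head_dim)
attention_mask = None  # or an additive float mask of shape [bsz, 1, q_len, kv_len]

attn_output = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=attention_mask is None and q_len > 1,
)
```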
@@ -960,23 +653,101 @@ GEMMOE_ATTENTION_CLASSES = {
960
  "sdpa": GemmoeSdpaAttention,
961
  }
962
 
963
- class GemmoeBlockSparseTop2MLP(nn.Module):
964
- def __init__(self, config: GemmoeConfig):
965
  super().__init__()
966
- self.ffn_dim = config.intermediate_size
967
- self.hidden_dim = config.hidden_size
968
 
969
- self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
970
- self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
971
- self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
 
972
 
973
- self.act_fn = approx_gelu
974
 
975
- def forward(self, hidden_states):
976
- current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
977
- current_hidden_states = self.w2(current_hidden_states)
978
- return current_hidden_states
979
 
 
980
  class GemmoeSparseMoeBlock(nn.Module):
981
  def __init__(self, config):
982
  super().__init__()
@@ -985,20 +756,15 @@ class GemmoeSparseMoeBlock(nn.Module):
985
  self.num_experts = config.num_local_experts
986
  self.top_k = 2
987
 
988
- # gating
989
- self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
990
 
991
- self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
992
 
993
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
994
  batch_size, sequence_length, hidden_dim = hidden_states.shape
995
  hidden_states = hidden_states.view(-1, hidden_dim)
996
 
997
- # router_logits: (batch * sequence_length, n_experts)
998
- router_logits = self.gate(hidden_states)
999
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
1000
- topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
1001
- topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
1002
 
1003
  # we cast back to the input dtype
1004
  topk_weight = topk_weight.to(hidden_states.dtype)
@@ -1016,22 +782,36 @@ class GemmoeSparseMoeBlock(nn.Module):
1016
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
1017
 
1018
  final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
1019
- return final_hidden_states, router_logits
1020
 
1021
 
 
1022
  class GemmoeDecoderLayer(nn.Module):
1023
  def __init__(self, config: GemmoeConfig, layer_idx: int):
1024
  super().__init__()
1025
  self.hidden_size = config.hidden_size
1026
- self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
1027
 
1028
- if config.n_routed_experts is not None and \
1029
- layer_idx >= config.first_k_dense_replace and \
1030
- layer_idx % config.moe_layer_freq == 0:
1031
- self.block_sparse_moe = GemmoeSparseMoeBlock(config)
1032
- else:
1033
- self.mlp = GemmoeMLP(config)
1034
 
 
1035
  self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1036
  self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1037
 
@@ -1044,29 +824,9 @@ class GemmoeDecoderLayer(nn.Module):
1044
  output_attentions: Optional[bool] = False,
1045
  output_router_logits: Optional[bool] = False,
1046
  use_cache: Optional[bool] = False,
 
1047
  **kwargs,
1048
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1049
- """
1050
- Args:
1051
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1052
- attention_mask (`torch.FloatTensor`, *optional*):
1053
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
1054
- query_sequence_length, key_sequence_length)` if default attention is used.
1055
- output_attentions (`bool`, *optional*):
1056
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1057
- returned tensors for more detail.
1058
- use_cache (`bool`, *optional*):
1059
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1060
- (see `past_key_values`).
1061
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1062
- output_router_logits (`bool`, *optional*):
1063
- Whether or not to return the logits of all the routers. They are useful for computing the router loss,
1064
- and should not be returned during inference.
1065
- """
1066
- if "padding_mask" in kwargs:
1067
- warnings.warn(
1068
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
1069
- )
1070
  residual = hidden_states
1071
  hidden_states = self.input_layernorm(hidden_states)
1072
 
@@ -1078,6 +838,7 @@ class GemmoeDecoderLayer(nn.Module):
1078
  past_key_value=past_key_value,
1079
  output_attentions=output_attentions,
1080
  use_cache=use_cache,
 
1081
  **kwargs,
1082
  )
1083
  hidden_states = residual + hidden_states
@@ -1085,14 +846,12 @@ class GemmoeDecoderLayer(nn.Module):
1085
  # Fully Connected
1086
  residual = hidden_states
1087
  hidden_states = self.post_attention_layernorm(hidden_states)
1088
-
1089
- if hasattr(self, 'block_sparse_moe'):
1090
- hidden_states, router_logits = self.block_sparse_moe(hidden_states)
1091
- else:
1092
- hidden_states = self.mlp(hidden_states)
1093
-
1094
  hidden_states = residual + hidden_states
1095
 
 
 
 
1096
  outputs = (hidden_states,)
1097
 
1098
  if output_attentions:
@@ -1100,26 +859,13 @@ class GemmoeDecoderLayer(nn.Module):
1100
 
1101
  if use_cache:
1102
  outputs += (present_key_value,)
1103
-
1104
- if output_router_logits and hasattr(self, 'block_sparse_moe'):
1105
- outputs += (router_logits,)
1106
 
1107
  return outputs
1108
 
1109
  GEMMOE_START_DOCSTRING = r"""
1110
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1111
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1112
- etc.)
1113
-
1114
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1115
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1116
- and behavior.
1117
-
1118
- Parameters:
1119
- config ([`GemmoeConfig`]):
1120
- Model configuration class with all the parameters of the model. Initializing with a config file does not
1121
- load the weights associated with the model, only the configuration. Check out the
1122
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1123
  """
1124
 
1125
  @add_start_docstrings(
@@ -1128,94 +874,52 @@ GEMMOE_START_DOCSTRING,
1128
  )
1129
 
1130
  class GemmoePreTrainedModel(PreTrainedModel):
1131
- config_class = GemmoeConfig
1132
- base_model_prefix = "model"
1133
- supports_gradient_checkpointing = True
1134
- _no_split_modules = ["GemmoeDecoderLayer"]
1135
- _skip_keys_device_placement = "past_key_values"
1136
- _supports_flash_attn_2 = True
1137
- _supports_sdpa = True
1138
- _supports_cache_class = True
1139
-
1140
- def _init_weights(self, module):
1141
- std = self.config.initializer_range
1142
- if isinstance(module, nn.Linear):
1143
- module.weight.data.normal_(mean=0.0, std=std)
1144
- if module.bias is not None:
1145
- module.bias.data.zero_()
1146
- elif isinstance(module, nn.Embedding):
1147
- module.weight.data.normal_(mean=0.0, std=std)
1148
- if module.padding_idx is not None:
1149
- module.weight.data[module.padding_idx].zero_()
1150
-
1151
-
1152
- Gemmoe_INPUTS_DOCSTRING = r"""
1153
- Args:
1154
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1155
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1156
- it.
1157
-
1158
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1159
- [`PreTrainedTokenizer.__call__`] for details.
1160
-
1161
- [What are input IDs?](../glossary#input-ids)
1162
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1163
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1164
-
1165
- - 1 for tokens that are **not masked**,
1166
- - 0 for tokens that are **masked**.
1167
-
1168
- [What are attention masks?](../glossary#attention-mask)
1169
-
1170
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1171
- [`PreTrainedTokenizer.__call__`] for details.
1172
-
1173
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1174
- `past_key_values`).
1175
-
1176
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1177
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1178
- information on the default strategy.
1179
-
1180
- - 1 indicates the head is **not masked**,
1181
- - 0 indicates the head is **masked**.
1182
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1183
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1184
- config.n_positions - 1]`.
1185
-
1186
- [What are position IDs?](../glossary#position-ids)
1187
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1188
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1189
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1190
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1191
-
1192
- Two formats are allowed:
1193
- - a [`~cache_utils.Cache`] instance;
1194
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1195
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1196
- cache format.
1197
-
1198
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1199
- legacy cache format will be returned.
1200
-
1201
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1202
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1203
- of shape `(batch_size, sequence_length)`.
1204
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1205
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1206
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1207
- model's internal embedding lookup matrix.
1208
- use_cache (`bool`, *optional*):
1209
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1210
- `past_key_values`).
1211
- output_attentions (`bool`, *optional*):
1212
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1213
- tensors for more detail.
1214
- output_hidden_states (`bool`, *optional*):
1215
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1216
- more detail.
1217
- return_dict (`bool`, *optional*):
1218
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1219
  """
1220
 
1221
  @add_start_docstrings(
@@ -1224,168 +928,263 @@ GEMMOE_START_DOCSTRING,
1224
  )
1225
 
1226
  class GemmoeModel(GemmoePreTrainedModel):
1227
- """
1228
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmoeDecoderLayer`]
1229
-
1230
- Args:
1231
- config: GemmoeConfig
1232
- """
1233
-
1234
- def __init__(self, config: GemmoeConfig):
1235
- super().__init__(config)
1236
- self.padding_idx = config.pad_token_id
1237
- self.vocab_size = config.vocab_size
1238
-
1239
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1240
- self.layers = nn.ModuleList(
1241
- [GemmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1242
- )
1243
- self._use_sdpa = config._attn_implementation == "sdpa"
1244
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1245
- self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1246
-
1247
- self.gradient_checkpointing = False
1248
- # Initialize weights and apply final processing
1249
- self.post_init()
1250
-
1251
- def get_input_embeddings(self):
1252
- return self.embed_tokens
1253
-
1254
- def set_input_embeddings(self, value):
1255
- self.embed_tokens = value
1256
-
1257
- @add_start_docstrings_to_model_forward(Gemmoe_INPUTS_DOCSTRING)
1258
- def forward(
1259
- self,
1260
- input_ids: torch.LongTensor = None,
1261
- attention_mask: Optional[torch.Tensor] = None,
1262
- position_ids: Optional[torch.LongTensor] = None,
1263
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1264
- inputs_embeds: Optional[torch.FloatTensor] = None,
1265
- use_cache: Optional[bool] = None,
1266
- output_attentions: Optional[bool] = None,
1267
- output_hidden_states: Optional[bool] = None,
1268
- return_dict: Optional[bool] = None,
1269
- ) -> Union[Tuple, BaseModelOutputWithPast]:
1270
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1271
- output_hidden_states = (
1272
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1273
- )
1274
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1275
-
1276
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1277
-
1278
- # retrieve input_ids and inputs_embeds
1279
- if input_ids is not None and inputs_embeds is not None:
1280
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1281
- elif input_ids is not None:
1282
- batch_size, seq_length = input_ids.shape[:2]
1283
- elif inputs_embeds is not None:
1284
- batch_size, seq_length = inputs_embeds.shape[:2]
1285
- else:
1286
- raise ValueError("You have to specify either input_ids or inputs_embeds")
1287
-
1288
- if self.gradient_checkpointing and self.training:
1289
- if use_cache:
1290
- logger.warning_once(
1291
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
1292
- )
1293
- use_cache = False
1294
-
1295
- past_key_values_length = 0
1296
- if use_cache:
1297
- use_legacy_cache = not isinstance(past_key_values, Cache)
1298
- if use_legacy_cache:
1299
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1300
- past_key_values_length = past_key_values.get_usable_length(seq_length)
1301
-
1302
- if position_ids is None:
1303
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1304
- position_ids = torch.arange(
1305
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1306
- )
1307
- position_ids = position_ids.unsqueeze(0)
1308
-
1309
- if inputs_embeds is None:
1310
- inputs_embeds = self.embed_tokens(input_ids)
1311
-
1312
- if self._use_flash_attention_2:
1313
- # 2d mask is passed through the layers
1314
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1315
- elif self._use_sdpa and not output_attentions:
1316
- # output_attentions=True can not be supported when using SDPA, and we fall back on
1317
- # the manual implementation that requires a 4D causal mask in all cases.
1318
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1319
- attention_mask,
1320
- (batch_size, seq_length),
1321
- inputs_embeds,
1322
- past_key_values_length,
1323
- )
1324
- else:
1325
- # 4d mask is passed through the layers
1326
- attention_mask = _prepare_4d_causal_attention_mask(
1327
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1328
- )
1329
-
1330
- # embed positions
1331
- hidden_states = inputs_embeds
1332
 
1333
- # decoder layers
1334
- all_hidden_states = () if output_hidden_states else None
1335
- all_self_attns = () if output_attentions else None
1336
- next_decoder_cache = None
1337
-
1338
- for decoder_layer in self.layers:
1339
- if output_hidden_states:
1340
- all_hidden_states += (hidden_states,)
1341
-
1342
- if self.gradient_checkpointing and self.training:
1343
- layer_outputs = self._gradient_checkpointing_func(
1344
- decoder_layer.__call__,
1345
- hidden_states,
1346
- attention_mask,
1347
- position_ids,
1348
- past_key_values,
1349
- output_attentions,
1350
- use_cache,
1351
- )
1352
- else:
1353
- layer_outputs = decoder_layer(
1354
- hidden_states,
1355
- attention_mask=attention_mask,
1356
- position_ids=position_ids,
1357
- past_key_value=past_key_values,
1358
- output_attentions=output_attentions,
1359
- use_cache=use_cache,
1360
- )
1361
-
1362
- hidden_states = layer_outputs[0]
1363
-
1364
- if use_cache:
1365
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1366
 
1367
- if output_attentions:
1368
- all_self_attns += (layer_outputs[1],)
1369
 
1370
- hidden_states = self.norm(hidden_states)
 
 
1371
 
1372
- # add hidden states from the last decoder layer
1373
- if output_hidden_states:
1374
- all_hidden_states += (hidden_states,)
1375
 
1376
- next_cache = None
1377
- if use_cache:
1378
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1379
- if not return_dict:
1380
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1381
- return BaseModelOutputWithPast(
1382
- last_hidden_state=hidden_states,
1383
- past_key_values=next_cache,
1384
- hidden_states=all_hidden_states,
1385
- attentions=all_self_attns,
1386
- )
1387
 
1388
- class GemmoeForCausalLM(GemmoePreTrainedModel):
1389
  _tied_weights_keys = ["lm_head.weight"]
1390
 
1391
  def __init__(self, config):
@@ -1393,6 +1192,9 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1393
  self.model = GemmoeModel(config)
1394
  self.vocab_size = config.vocab_size
1395
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
 
 
1396
 
1397
  # Initialize weights and apply final processing
1398
  self.post_init()
@@ -1415,8 +1217,8 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1415
  def get_decoder(self):
1416
  return self.model
1417
 
1418
- @add_start_docstrings_to_model_forward(Gemmoe_INPUTS_DOCSTRING)
1419
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1420
  def forward(
1421
  self,
1422
  input_ids: torch.LongTensor = None,
@@ -1428,14 +1230,16 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1428
  use_cache: Optional[bool] = None,
1429
  output_attentions: Optional[bool] = None,
1430
  output_hidden_states: Optional[bool] = None,
 
1431
  return_dict: Optional[bool] = None,
1432
- ) -> Union[Tuple, CausalLMOutputWithPast]:
 
1433
  r"""
1434
  Args:
1435
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1436
- Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
1437
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1438
- (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
1439
 
1440
  Returns:
1441
 
@@ -1444,24 +1248,26 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1444
  ```python
1445
  >>> from transformers import AutoTokenizer, GemmoeForCausalLM
1446
 
1447
- >>> model = GemmoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1448
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1449
 
1450
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1451
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1452
 
1453
  >>> # Generate
1454
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1455
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1456
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1457
  ```"""
1458
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
1459
  output_hidden_states = (
1460
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1461
  )
1462
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1463
 
1464
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1465
  outputs = self.model(
1466
  input_ids=input_ids,
1467
  attention_mask=attention_mask,
@@ -1471,46 +1277,61 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1471
  use_cache=use_cache,
1472
  output_attentions=output_attentions,
1473
  output_hidden_states=output_hidden_states,
 
1474
  return_dict=return_dict,
 
1475
  )
1476
 
1477
  hidden_states = outputs[0]
1478
- if self.config.pretraining_tp > 1:
1479
- lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1480
- logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1481
- logits = torch.cat(logits, dim=-1)
1482
- else:
1483
- logits = self.lm_head(hidden_states)
1484
- logits = logits.float()
1485
 
1486
  loss = None
1487
  if labels is not None:
1488
- # Shift so that tokens < n predict n
1489
  shift_logits = logits[..., :-1, :].contiguous()
1490
  shift_labels = labels[..., 1:].contiguous()
1491
- # Flatten the tokens
1492
  loss_fct = CrossEntropyLoss()
1493
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
1494
  shift_labels = shift_labels.view(-1)
1495
- # Enable model parallelism
1496
  shift_labels = shift_labels.to(shift_logits.device)
1497
  loss = loss_fct(shift_logits, shift_labels)
1498
 
1499
  if not return_dict:
1500
  output = (logits,) + outputs[1:]
 
 
1501
  return (loss,) + output if loss is not None else output
1502
 
1503
- return CausalLMOutputWithPast(
1504
  loss=loss,
 
1505
  logits=logits,
1506
  past_key_values=outputs.past_key_values,
1507
  hidden_states=outputs.hidden_states,
1508
  attentions=outputs.attentions,
 
1509
  )
1510
 
1511
  def prepare_inputs_for_generation(
1512
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1513
  ):
 
1514
  if past_key_values is not None:
1515
  if isinstance(past_key_values, Cache):
1516
  cache_length = past_key_values.get_seq_length()
@@ -1520,19 +1341,11 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1520
  cache_length = past_length = past_key_values[0][0].shape[2]
1521
  max_cache_length = None
1522
 
1523
- # Keep only the unprocessed tokens:
1524
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1525
- # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1526
- # input)
1527
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1528
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1529
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1530
- # input_ids based on the past_length.
1531
  elif past_length < input_ids.shape[1]:
1532
  input_ids = input_ids[:, past_length:]
1533
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1534
-
1535
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1536
  if (
1537
  max_cache_length is not None
1538
  and attention_mask is not None
@@ -1542,26 +1355,37 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1542
 
1543
  position_ids = kwargs.get("position_ids", None)
1544
  if attention_mask is not None and position_ids is None:
1545
- # create position_ids on the fly for batch generation
1546
  position_ids = attention_mask.long().cumsum(-1) - 1
1547
  position_ids.masked_fill_(attention_mask == 0, 1)
1548
  if past_key_values:
1549
  position_ids = position_ids[:, -input_ids.shape[1] :]
1550
 
1551
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1552
  if inputs_embeds is not None and past_key_values is None:
1553
  model_inputs = {"inputs_embeds": inputs_embeds}
1554
  else:
1555
- model_inputs = {"input_ids": input_ids}
1556
 
1557
  model_inputs.update(
1558
  {
1559
- "position_ids": position_ids,
 
1560
  "past_key_values": past_key_values,
1561
  "use_cache": kwargs.get("use_cache"),
1562
  "attention_mask": attention_mask,
1563
  }
1564
  )
 
1565
  return model_inputs
1566
 
1567
  @staticmethod
@@ -1594,7 +1418,6 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1594
  self.num_labels = config.num_labels
1595
  self.model = GemmoeModel(config)
1596
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1597
-
1598
  # Initialize weights and apply final processing
1599
  self.post_init()
1600
 
@@ -1604,7 +1427,8 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1604
  def set_input_embeddings(self, value):
1605
  self.model.embed_tokens = value
1606
 
1607
- @add_start_docstrings_to_model_forward(Gemmoe_INPUTS_DOCSTRING)
 
1608
  def forward(
1609
  self,
1610
  input_ids: torch.LongTensor = None,
@@ -1618,14 +1442,25 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1618
  output_hidden_states: Optional[bool] = None,
1619
  return_dict: Optional[bool] = None,
1620
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1621
- r"""
1622
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1623
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers.,
1624
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1625
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1626
  """
1627
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1628
 
1629
  transformer_outputs = self.model(
1630
  input_ids,
1631
  attention_mask=attention_mask,
@@ -1651,9 +1486,8 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1651
  sequence_lengths = -1
1652
  else:
1653
  if input_ids is not None:
1654
- sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1655
- logits.device
1656
- )
1657
  else:
1658
  sequence_lengths = -1
1659
 
@@ -1682,6 +1516,7 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1682
  elif self.config.problem_type == "multi_label_classification":
1683
  loss_fct = BCEWithLogitsLoss()
1684
  loss = loss_fct(pooled_logits, labels)
 
1685
  if not return_dict:
1686
  output = (pooled_logits,) + transformer_outputs[1:]
1687
  return ((loss,) + output) if loss is not None else output
 
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
 
28
  from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
  from transformers.modeling_attn_mask_utils import (
 
 
31
  _prepare_4d_causal_attention_mask,
 
32
  )
33
+ from transformers.modeling_outputs import SequenceClassifierOutputWithPast, MoeModelOutputWithPast, MoeCausalLMOutputWithPast
34
  from transformers.modeling_utils import PreTrainedModel
35
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
36
  from transformers.utils import (
 
60
 
61
  _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
62
 
63
+
64
  logger = logging.get_logger(__name__)
65
 
66
  _CONFIG_FOR_DOC = "GemmoeConfig"
 
156
  max_seqlen_in_batch,
157
  )
158
 
159
 
160
 
161
  class GemmoeRMSNorm(nn.Module):
162
  def __init__(self, dim: int, eps: float = 1e-6):
163
  super().__init__()
 
206
  self.sin_cached[:seq_len],
207
  )
208
 
209
  def rotate_half(x):
210
  """Rotates half the hidden dims of the input."""
211
  x1 = x[..., : x.shape[-1] // 2]
212
  x2 = x[..., x.shape[-1] // 2 :]
213
  return torch.cat((-x2, x1), dim=-1)
214
 
215
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
216
+ """Applies Rotary Position Embedding to the query and key tensors."""
217
+ seq_len, dim = q.shape[-2], q.shape[-1]
218
+ cos = cos[:seq_len].view(1, 1, seq_len, dim)
219
+ sin = sin[:seq_len].view(1, 1, seq_len, dim)
220
  q_embed = (q * cos) + (rotate_half(q) * sin)
221
  k_embed = (k * cos) + (rotate_half(k) * sin)
222
  return q_embed, k_embed
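# Illustrative sketch (dummy shapes, not part of the committed code): how rotate_half
# and a cos/sin cache combine into a rotary position embedding via apply_rotary_pos_emb.
import torch

head_dim, seq_len = 4, 3
q = torch.randn(1, 1, seq_len, head_dim)
k = torch.randn(1, 1, seq_len, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)              # (seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin())
# Each (x_i, x_{i + head_dim/2}) pair is rotated by position * inv_freq[i]; the rotation is norm-preserving.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)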
223
 
224
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
225
  """
226
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
227
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
 
231
  return hidden_states
232
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
233
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
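# Illustrative check (dummy shapes, not part of the model code): the expand/reshape above
# matches torch.repeat_interleave along the key/value head dimension.
import torch

kv = torch.randn(2, 4, 5, 8)      # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 3
expanded = kv[:, :, None, :, :].expand(2, 4, n_rep, 5, 8).reshape(2, 4 * n_rep, 5, 8)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1))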
 
234
 
235
  class GemmoeAttention(nn.Module):
236
+ """
237
+ Multi-headed attention module for Gemmoe model.
238
+
239
+ Args:
240
+ config (GemmoeConfig): The configuration object for the Gemmoe model.
241
+ layer_idx (Optional[int]): The index of the layer. Default is None.
242
+ """
243
 
244
  def __init__(self, config: GemmoeConfig, layer_idx: Optional[int] = None):
245
  super().__init__()
 
247
  self.layer_idx = layer_idx
248
  if layer_idx is None:
249
  logger.warning_once(
250
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
251
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
252
  "when creating this class."
253
  )
 
254
  self.attention_dropout = config.attention_dropout
255
  self.hidden_size = config.hidden_size
256
  self.num_heads = config.num_attention_heads
257
+ self.head_dim = config.head_dim
258
  self.num_key_value_heads = config.num_key_value_heads
259
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
260
  self.max_position_embeddings = config.max_position_embeddings
261
  self.rope_theta = config.rope_theta
262
  self.is_causal = True
263
 
264
+ if self.hidden_size % self.num_heads != 0:
265
  raise ValueError(
266
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
267
  f" and `num_heads`: {self.num_heads})."
268
  )
 
269
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
270
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
271
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
272
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
273
+ self.rotary_emb = GemmoeRotaryEmbedding(
274
+ self.head_dim,
275
+ max_position_embeddings=self.max_position_embeddings,
276
+ base=self.rope_theta,
277
+ )
278
 
279
  def forward(
280
  self,
 
284
  past_key_value: Optional[Cache] = None,
285
  output_attentions: bool = False,
286
  use_cache: bool = False,
287
+ cache_position: Optional[torch.LongTensor] = None,
288
  **kwargs,
289
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
290
+ """
291
+ Forward pass of the attention module.
292
 
293
+ Args:
294
+ hidden_states (torch.Tensor): The input hidden states.
295
+ attention_mask (Optional[torch.Tensor]): The attention mask. Default is None.
296
+ position_ids (Optional[torch.LongTensor]): The position IDs. Default is None.
297
+ past_key_value (Optional[Cache]): The past key-value cache. Default is None.
298
+ output_attentions (bool): Whether to output the attention weights. Default is False.
299
+ use_cache (bool): Whether to use caching. Default is False.
300
+ cache_position (Optional[torch.LongTensor]): The cache position. Default is None.
301
+ **kwargs: Additional keyword arguments.
302
 
303
+ Returns:
304
+ Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
305
+ - The output hidden states.
306
+ - The attention weights (if `output_attentions=True`).
307
+ - The past key-value cache (if `use_cache=True`).
308
+ """
309
+ bsz, q_len, _ = hidden_states.size()
310
 
311
+ query_states = self.q_proj(hidden_states)
312
+ key_states = self.k_proj(hidden_states)
313
+ value_states = self.v_proj(hidden_states)
 
314
 
315
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
316
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
317
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
318
 
319
+ past_key_value = getattr(self, "past_key_value", past_key_value)
320
+
321
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
322
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
323
 
324
  if past_key_value is not None:
325
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
326
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
327
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
328
 
329
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
330
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
331
 
332
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
333
 
334
+ if attention_mask is not None: # no matter the length, we just slice it
335
+ if cache_position is not None:
336
+ causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
337
+ else:
338
+ causal_mask = attention_mask
339
+ attn_weights = attn_weights + causal_mask
340
 
341
  # upcast attention to fp32
342
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
343
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
344
+
345
  attn_output = torch.matmul(attn_weights, value_states)
346
 
347
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
351
  )
352
 
353
  attn_output = attn_output.transpose(1, 2).contiguous()
354
+ attn_output = attn_output.view(bsz, q_len, -1)
355
 
356
+ attn_output = self.o_proj(attn_output)
357
 
358
  if not output_attentions:
359
  attn_weights = None
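# Illustrative sketch (dummy sizes, an assumption-labelled aside rather than model code):
# how a precomputed additive causal mask is sliced with cache_position, as done above.
import torch

max_len = 8
full_mask = torch.triu(torch.full((max_len, max_len), float("-inf")), diagonal=1)[None, None, :, :]
cache_position = torch.arange(3)                     # prefill of 3 tokens
step_mask = full_mask[:, :, cache_position, :3]      # (1, 1, 3, 3); future positions are -inf
scores = torch.randn(1, 1, 3, 3)
masked_scores = scores + step_mask                   # added to raw attention scores before softmax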
 
366
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
367
  flash attention and deal with padding tokens in case the input contains any of them.
368
  """
 
369
  def __init__(self, *args, **kwargs):
370
  super().__init__(*args, **kwargs)
371
+ # TODO: Remove this attribute once Flash Attention for RoCm is bumped to 2.1.
 
 
 
372
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
373
 
374
  def forward(
 
379
  past_key_value: Optional[Cache] = None,
380
  output_attentions: bool = False,
381
  use_cache: bool = False,
382
+ cache_position: Optional[torch.LongTensor] = None,
383
  **kwargs,
384
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
385
  output_attentions = False
386
 
387
  bsz, q_len, _ = hidden_states.size()
 
397
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
398
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
399
 
400
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
401
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
 
 
402
 
403
+ past_key_value = getattr(self, "past_key_value", past_key_value)
404
  if past_key_value is not None:
405
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
406
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
407
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
408
 
409
  # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
 
419
  # cast them back in the correct dtype just to be sure everything works as expected.
420
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
421
  # in fp32. (GemmoeRMSNorm handles it correctly)
 
422
  input_dtype = query_states.dtype
423
  if input_dtype == torch.float32:
424
+ if torch.is_autocast_enabled():
425
+ target_dtype = torch.get_autocast_gpu_dtype()
426
  # Handle the case where the model is quantized
427
+ elif hasattr(self.config, "_pre_quantization_dtype"):
428
  target_dtype = self.config._pre_quantization_dtype
 
 
429
  else:
430
  target_dtype = self.q_proj.weight.dtype
431
 
 
434
  f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
435
  f" {target_dtype}."
436
  )
 
437
  query_states = query_states.to(target_dtype)
438
  key_states = key_states.to(target_dtype)
439
  value_states = value_states.to(target_dtype)
 
442
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
443
  )
444
 
445
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
446
  attn_output = self.o_proj(attn_output)
447
 
448
  if not output_attentions:
 
484
  query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
485
  query_states, key_states, value_states, attention_mask, query_length
486
  )
 
487
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
488
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
489
 
 
499
  softmax_scale=softmax_scale,
500
  causal=causal,
501
  )
 
502
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
503
  else:
504
  attn_output = flash_attn_func(
 
509
 
510
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
511
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
 
512
 
513
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
514
  key_layer = index_first_axis(
515
  key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
516
  )
517
  value_layer = index_first_axis(
518
  value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
519
  )
520
+
521
  if query_length == kv_seq_len:
522
  query_layer = index_first_axis(
523
  query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
 
549
  class GemmoeSdpaAttention(GemmoeAttention):
550
  """
551
  Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
552
+ GemmoeAttention, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
553
  SDPA API.
554
  """
555
 
556
+ def repeat_kv(self, x, n_rep):
557
+ """
558
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
559
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
560
+ """
561
+ batch, num_key_value_heads, slen, head_dim = x.shape
562
+ if n_rep == 1:
563
+ return x
564
+ x = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
565
+ return x.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
566
+
567
  def forward(
568
  self,
569
  hidden_states: torch.Tensor,
 
572
  past_key_value: Optional[Cache] = None,
573
  output_attentions: bool = False,
574
  use_cache: bool = False,
575
+ cache_position: Optional[torch.LongTensor] = None,
576
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
577
  if output_attentions:
578
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
579
+ logger.warning_once(
580
+ "GemmoeModel is using GemmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
581
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
582
+ )
583
+
584
  return super().forward(
585
  hidden_states=hidden_states,
586
  attention_mask=attention_mask,
 
588
  past_key_value=past_key_value,
589
  output_attentions=output_attentions,
590
  use_cache=use_cache,
591
+ cache_position=cache_position,
592
  )
593
+
594
  bsz, q_len, _ = hidden_states.size()
595
 
596
  query_states = self.q_proj(hidden_states)
 
601
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
602
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
603
 
604
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
605
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
 
 
 
606
 
607
+ past_key_value = getattr(self, "past_key_value", past_key_value)
608
  if past_key_value is not None:
609
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
610
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
611
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
612
 
613
+ key_states = self.repeat_kv(key_states, self.num_key_value_groups)
614
+ value_states = self.repeat_kv(value_states, self.num_key_value_groups)
615
 
616
+ causal_mask = attention_mask
617
+ if attention_mask is not None and cache_position is not None:
618
+ causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
619
+
620
+ # Ensure query, key, and value states have the same dtype
621
+ common_dtype = query_states.dtype
622
+ key_states = key_states.to(dtype=common_dtype)
623
+ value_states = value_states.to(dtype=common_dtype)
624
 
625
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
626
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
627
+ if query_states.device.type == "cuda" and causal_mask is not None:
628
  query_states = query_states.contiguous()
629
  key_states = key_states.contiguous()
630
  value_states = value_states.contiguous()
631
 
632
+ # Cast causal_mask to the same dtype as query_states
633
+ if causal_mask is not None:
634
+ causal_mask = causal_mask.to(dtype=query_states.dtype)
635
+
636
  attn_output = torch.nn.functional.scaled_dot_product_attention(
637
  query_states,
638
  key_states,
639
  value_states,
640
+ attn_mask=causal_mask,
641
  dropout_p=self.attention_dropout if self.training else 0.0,
 
 
642
  )
643
 
644
  attn_output = attn_output.transpose(1, 2).contiguous()
645
+ attn_output = attn_output.view(bsz, q_len, -1)
 
646
  attn_output = self.o_proj(attn_output)
647
 
648
  return attn_output, None, past_key_value
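# Illustrative sketch (dummy shapes, not the model's own code path): grouped key/value
# heads are expanded to the query head count and handed to scaled_dot_product_attention.
import torch
import torch.nn.functional as F

bsz, n_heads, n_kv_heads, q_len, head_dim = 1, 8, 2, 4, 16
n_rep = n_heads // n_kv_heads
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_kv_heads, q_len, head_dim)
v = torch.randn(bsz, n_kv_heads, q_len, head_dim)
k = k[:, :, None, :, :].expand(bsz, n_kv_heads, n_rep, q_len, head_dim).reshape(bsz, n_heads, q_len, head_dim)
v = v[:, :, None, :, :].expand(bsz, n_kv_heads, n_rep, q_len, head_dim).reshape(bsz, n_heads, q_len, head_dim)
causal = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal, dropout_p=0.0)
out = out.transpose(1, 2).reshape(bsz, q_len, n_heads * head_dim)   # ready for the output projection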
 
653
  "sdpa": GemmoeSdpaAttention,
654
  }
655
 
656
+ class GemmoeMLP(nn.Module):
657
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
658
  super().__init__()
659
+ self.config = config
660
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
661
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
662
+
663
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
664
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
665
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
666
+ self.act_fn = ACT2FN[config.hidden_act]
667
+
668
+ def forward(self, x):
669
+ if self.config.pretraining_tp > 1:
670
+ slice = self.intermediate_size // self.config.pretraining_tp
671
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
672
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
673
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
674
 
675
+ gate_proj = torch.cat(
676
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
677
+ )
678
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
679
 
680
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
681
+ down_proj = [
682
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
683
+ ]
684
+ down_proj = sum(down_proj)
685
+ else:
686
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
687
 
688
+ return down_proj
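# Illustrative sketch (dummy sizes; GELU assumed only for the example) of the gated
# feed-forward computed by the non tensor-parallel branch above.
import torch
import torch.nn as nn

hidden_size, intermediate_size = 16, 32
x = torch.randn(2, 3, hidden_size)
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
y = down_proj(nn.functional.gelu(gate_proj(x)) * up_proj(x))   # same shape as x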
 
 
 
689
 
690
+ class MoEGate(nn.Module):
691
+ def __init__(self, config):
692
+ super().__init__()
693
+ self.config = config
694
+ self.top_k = config.num_experts_per_tok
695
+ self.n_routed_experts = config.n_routed_experts
696
+
697
+ self.scoring_func = config.scoring_func
698
+ self.alpha = config.aux_loss_alpha
699
+ self.seq_aux = config.seq_aux
700
+
701
+ # topk selection algorithm
702
+ self.norm_topk_prob = config.norm_topk_prob
703
+ self.gating_dim = config.hidden_size
704
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
705
+ self.reset_parameters()
706
+
707
+ def reset_parameters(self) -> None:
708
+ import torch.nn.init as init
709
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
710
+
711
+ def forward(self, hidden_states):
712
+ bsz, seq_len, h = hidden_states.shape
713
+ ### compute gating score
714
+ hidden_states = hidden_states.view(-1, h)
715
+ logits = F.linear(hidden_states, self.weight, None)
716
+ if self.scoring_func == 'softmax':
717
+ scores = logits.softmax(dim=-1)
718
+ else:
719
+ raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
720
+
721
+ ### select top-k experts
722
+ topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
723
+
724
+ ### norm gate to sum 1
725
+ if self.top_k > 1 and self.norm_topk_prob:
726
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
727
+ topk_weight = topk_weight / denominator
728
+
729
+ ### expert-level computation auxiliary loss
730
+ if self.training and self.alpha > 0.0:
731
+ scores_for_aux = scores
732
+ aux_topk = self.top_k
733
+ # always compute aux loss based on the naive greedy topk method
734
+ topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
735
+ if self.seq_aux:
736
+ scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
737
+ ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
738
+ ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
739
+ aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
740
+ else:
741
+ mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
742
+ ce = mask_ce.float().mean(0)
743
+ Pi = scores_for_aux.mean(0)
744
+ fi = ce * self.n_routed_experts
745
+ aux_loss = (Pi * fi).sum() * self.alpha
746
+ else:
747
+ aux_loss = None
748
+ return topk_idx, topk_weight, aux_loss
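# Illustrative sketch (dummy sizes, standalone; not the class above) of softmax top-k
# routing: per-token expert scores, top-k selection, weights renormalized to sum to 1.
import torch
import torch.nn.functional as F

n_tokens, hidden_size, n_experts, top_k = 6, 16, 4, 2
tokens = torch.randn(n_tokens, hidden_size)
gate_weight = torch.randn(n_experts, hidden_size)
scores = F.linear(tokens, gate_weight).softmax(dim=-1)          # (n_tokens, n_experts)
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1)
topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)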
749
+
750
+
751
  class GemmoeSparseMoeBlock(nn.Module):
752
  def __init__(self, config):
753
  super().__init__()
 
756
  self.num_experts = config.num_local_experts
757
  self.top_k = 2
758
 
759
+ self.gate = MoEGate(config)
 
760
 
761
+ self.experts = nn.ModuleList([GemmoeMLP(config) for _ in range(self.num_experts)])
762
 
763
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
764
  batch_size, sequence_length, hidden_dim = hidden_states.shape
765
  hidden_states = hidden_states.view(-1, hidden_dim)
766
 
767
+ topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
 
 
 
 
768
 
769
  # we cast back to the input dtype
770
  topk_weight = topk_weight.to(hidden_states.dtype)
 
782
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
783
 
784
  final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
785
+ return final_hidden_states, aux_loss
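# Illustrative sketch (dense and unoptimized, dummy sizes) of the weighted combination
# of expert outputs that the block above performs with its routed tokens.
import torch
import torch.nn as nn

n_tokens, hidden_size, n_experts, top_k = 6, 16, 4, 2
x = torch.randn(n_tokens, hidden_size)
experts = nn.ModuleList([nn.Linear(hidden_size, hidden_size, bias=False) for _ in range(n_experts)])
scores = torch.randn(n_tokens, n_experts).softmax(dim=-1)
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1)     # (n_tokens, top_k)
expert_out = torch.stack(
    [experts[int(e)](x[t]) for t in range(n_tokens) for e in topk_idx[t]]
).view(n_tokens, top_k, hidden_size)
y = (expert_out * topk_weight.unsqueeze(-1)).sum(dim=1)         # (n_tokens, hidden_size)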
786
+
787
+ class AddAuxiliaryLoss(torch.autograd.Function):
788
+ """
789
+ Autograd helper that attaches an auxiliary (aux) loss to the forward output
790
+ so that the gradient of the aux loss is included during backpropagation.
791
+ """
792
+ @staticmethod
793
+ def forward(ctx, x, loss):
794
+ assert loss.numel() == 1
795
+ ctx.dtype = loss.dtype
796
+ ctx.required_aux_loss = loss.requires_grad
797
+ return x
798
 
799
+ @staticmethod
800
+ def backward(ctx, grad_output):
801
+ grad_loss = None
802
+ if ctx.required_aux_loss:
803
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
804
+ return grad_output, grad_loss
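# Illustrative sketch (dummy values) of what the autograd trick above achieves: the
# auxiliary loss receives a gradient even though only `x` flows into the main loss.
import torch

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
aux = (w ** 2).sum() * 0.01                    # stand-in auxiliary loss
out = AddAuxiliaryLoss.apply(x * 2.0, aux)     # forward returns x * 2.0 unchanged
out.sum().backward()
assert x.grad is not None and w.grad is not None   # w.grad comes only from the aux loss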
805
 
806
+
807
  class GemmoeDecoderLayer(nn.Module):
808
  def __init__(self, config: GemmoeConfig, layer_idx: int):
809
  super().__init__()
810
  self.hidden_size = config.hidden_size
 
811
 
812
+ self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
813
 
814
+ self.block_sparse_moe = GemmoeSparseMoeBlock(config)
815
  self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
816
  self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
817
 
 
824
  output_attentions: Optional[bool] = False,
825
  output_router_logits: Optional[bool] = False,
826
  use_cache: Optional[bool] = False,
827
+ cache_position: Optional[torch.LongTensor] = None,
828
  **kwargs,
829
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
830
  residual = hidden_states
831
  hidden_states = self.input_layernorm(hidden_states)
832
 
 
838
  past_key_value=past_key_value,
839
  output_attentions=output_attentions,
840
  use_cache=use_cache,
841
+ cache_position=cache_position,
842
  **kwargs,
843
  )
844
  hidden_states = residual + hidden_states
 
846
  # Fully Connected
847
  residual = hidden_states
848
  hidden_states = self.post_attention_layernorm(hidden_states)
849
+ hidden_states, aux_loss = self.block_sparse_moe(hidden_states)
850
  hidden_states = residual + hidden_states
851
 
852
+ if aux_loss is not None:
853
+ hidden_states = AddAuxiliaryLoss.apply(hidden_states, aux_loss)
854
+
855
  outputs = (hidden_states,)
856
 
857
  if output_attentions:
 
859
 
860
  if use_cache:
861
  outputs += (present_key_value,)
 
 
 
862
 
863
  return outputs
864
 
865
  GEMMOE_START_DOCSTRING = r"""
866
+ This model inherits from [PreTrainedModel]. Check the superclass documentation for the generic methods the
867
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
868
+ etc.)
  """
870
 
871
  @add_start_docstrings(
 
874
  )
875
 
876
  class GemmoePreTrainedModel(PreTrainedModel):
877
+ config_class = GemmoeConfig
878
+ base_model_prefix = "model"
879
+ supports_gradient_checkpointing = True
880
+ _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
881
+ _no_split_modules = ["GemmoeDecoderLayer"]
882
+ _skip_keys_device_placement = ["past_key_values", "causal_mask"]
883
+ _supports_flash_attn_2 = True
884
+ _supports_sdpa = True
885
+ _supports_cache_class = True
886
+
887
+ def _init_weights(self, module):
888
+ std = self.config.initializer_range
889
+ if isinstance(module, nn.Linear):
890
+ module.weight.data.normal_(mean=0.0, std=std)
891
+ if module.bias is not None:
892
+ module.bias.data.zero_()
893
+ elif isinstance(module, nn.Embedding):
894
+ module.weight.data.normal_(mean=0.0, std=std)
895
+ if module.padding_idx is not None:
896
+ module.weight.data[module.padding_idx].zero_()
897
+
898
+ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
899
+ if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
900
+ raise ValueError(
901
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
902
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
903
+ )
904
+ if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
905
+ causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
906
+ self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
907
+
908
+ for layer in self.model.layers:
909
+ weights = layer.self_attn.o_proj.weight
910
+ layer.self_attn.past_key_value = cache_cls(
911
+ self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
912
+ )
913
+
914
+ def _reset_cache(self):
915
+ for layer in self.model.layers:
916
+ layer.self_attn.past_key_value = None
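# Hedged usage sketch: one way _setup_cache / _reset_cache might be driven around a
# static-cache generation call. `model` and `tok` are hypothetical objects supplied by
# the caller; this is an assumption-labelled illustration, not a documented API recipe.
from transformers.cache_utils import StaticCache

def generate_with_static_cache(model, tok, prompt, max_new_tokens=32):
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    model._setup_cache(StaticCache, max_batch_size=1,
                       max_cache_len=inputs.input_ids.shape[1] + max_new_tokens)
    try:
        out = model.generate(**inputs, max_new_tokens=max_new_tokens)
    finally:
        model._reset_cache()
    return tok.batch_decode(out, skip_special_tokens=True)[0]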
917
+
918
+ GEMMOE_INPUTS_DOCSTRING = r"""
919
+ Args:
920
+ input_ids (torch.LongTensor of shape (batch_size, sequence_length)):
921
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
922
+ it.
923
  """
924
 
925
  @add_start_docstrings(
 
928
  )
929
 
930
  class GemmoeModel(GemmoePreTrainedModel):
931
+ """
932
+ Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a [GemmoeDecoderLayer].
+
+ Args:
933
+ config: GemmoeConfig
934
+ """
935
+
936
+
937
+ def __init__(self, config: GemmoeConfig):
938
+ super().__init__(config)
939
+ self.padding_idx = config.pad_token_id
940
+ self.vocab_size = config.vocab_size
941
+
942
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
943
+ self.layers = nn.ModuleList(
944
+ [GemmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
945
+ )
946
+
947
+ self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
948
+
949
+ self.gradient_checkpointing = False
950
+
951
+ # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
952
+ # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
953
+ causal_mask = torch.full(
954
+ (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
955
+ )
956
+ self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
957
+
958
+ # Initialize weights and apply final processing
959
+ self.post_init()
960
+
961
+ def get_input_embeddings(self):
962
+ return self.embed_tokens
963
+
964
+ def set_input_embeddings(self, value):
965
+ self.embed_tokens = value
966
+
967
+ @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
968
+ @replace_return_docstrings(output_type=MoeModelOutputWithPast, config_class=_CONFIG_FOR_DOC)
969
+ def forward(
970
+ self,
971
+ input_ids: torch.LongTensor = None,
972
+ attention_mask: Optional[torch.Tensor] = None,
973
+ position_ids: Optional[torch.LongTensor] = None,
974
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
975
+ inputs_embeds: Optional[torch.FloatTensor] = None,
976
+ use_cache: Optional[bool] = None,
977
+ output_attentions: Optional[bool] = None,
978
+ output_hidden_states: Optional[bool] = None,
979
+ output_router_logits: Optional[bool] = None,
980
+ return_dict: Optional[bool] = None,
981
+ cache_position: Optional[torch.LongTensor] = None,
982
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
983
+ """
984
+ Forward pass of the Gemmoe decoder model.
985
+
986
+ Args:
987
+ input_ids: Input token IDs.
988
+ attention_mask: Attention mask.
989
+ position_ids: Position IDs.
990
+ past_key_values: Past key-value pairs.
991
+ inputs_embeds: Input embeddings.
992
+ output_router_logits: Whether to output router logits.
+ cache_position: Position indices of the current tokens in the cache.
993
+ use_cache: Whether to use cache.
994
+ output_attentions: Whether to output attentions.
995
+ output_hidden_states: Whether to output hidden states.
996
+ return_dict: Whether to return a dictionary or tuple.
997
+
998
+ Returns:
999
+ A MoeModelOutputWithPast (or a plain tuple when return_dict=False) with the decoder outputs.
1000
+ """
1001
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1002
+ output_hidden_states = (
1003
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1004
+ )
1005
+ output_router_logits = (
1006
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1007
+ )
1008
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1009
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1010
+
1011
+ if (input_ids is None) ^ (inputs_embeds is not None):
1012
+ raise ValueError(
1013
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1014
+ )
1015
+
1016
+ if self.gradient_checkpointing and self.training and use_cache:
1017
+ logger.warning_once(
1018
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1019
+ )
1020
+ use_cache = False
1021
+
1022
+ if inputs_embeds is None:
1023
+ inputs_embeds = self.embed_tokens(input_ids)
1024
+
1025
+ past_seen_tokens = 0
1026
+ if use_cache: # kept for BC (cache positions)
1027
+ if not isinstance(past_key_values, StaticCache):
1028
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1029
+ past_seen_tokens = past_key_values.get_seq_length()
1030
+
1031
+ if cache_position is None:
1032
+ cache_position = torch.arange(
1033
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1034
+ )
1035
+
1036
+ if position_ids is None:
1037
+ position_ids = cache_position.unsqueeze(0)
1038
+
1039
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
1040
+
1041
+ hidden_states = inputs_embeds
1042
+
1043
+ # Scale the embeddings by sqrt(hidden_size), as in Gemma
1044
+ scale_factor = torch.tensor(math.sqrt(self.config.hidden_size), dtype=hidden_states.dtype)
1045
+ hidden_states = hidden_states * scale_factor
1046
+ # Decoder layers
1047
+ all_hidden_states = () if output_hidden_states else None
1048
+ all_self_attns = () if output_attentions else None
1049
+ all_router_logits = () if output_router_logits else None
1050
+ next_decoder_cache = None
1051
+
1052
+ for decoder_layer in self.layers:
1053
+ if output_hidden_states:
1054
+ all_hidden_states += (hidden_states,)
1055
+
1056
+ if self.gradient_checkpointing and self.training:
1057
+ layer_outputs = self._gradient_checkpointing_func(
1058
+ decoder_layer.__call__,
1059
+ hidden_states,
1060
+ causal_mask,
1061
+ position_ids,
1062
+ past_key_values,
1063
+ output_attentions,
1064
+ output_router_logits,
1065
+ use_cache,
1066
+ cache_position,
1067
+ )
1068
+ else:
1069
+ layer_outputs = decoder_layer(
1070
+ hidden_states,
1071
+ attention_mask=causal_mask,
1072
+ position_ids=position_ids,
1073
+ past_key_value=past_key_values,
1074
+ output_attentions=output_attentions,
1075
+ output_router_logits=output_router_logits,
1076
+ use_cache=use_cache,
1077
+ cache_position=cache_position,
1078
+ )
1079
+
1080
+ hidden_states = layer_outputs[0]
1081
+ if use_cache:
1082
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1083
+ if output_attentions:
1084
+ all_self_attns += (layer_outputs[1],)
1085
+ if output_router_logits:
1086
+ all_router_logits += (layer_outputs[-1],)
1087
+
1088
+ hidden_states = self.norm(hidden_states)
1089
+
1090
+ # Add hidden states from the last decoder layer
1091
+ if output_hidden_states:
1092
+ all_hidden_states += (hidden_states,)
1093
+
1094
+ next_cache = None
1095
+ if use_cache:
1096
+ next_cache = (
1097
+ next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1098
+ )
1099
+
1100
+ if not return_dict:
1101
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] if v is not None)
1102
+
1103
+ return MoeModelOutputWithPast(
1104
+ last_hidden_state=hidden_states,
1105
+ past_key_values=next_cache,
1106
+ hidden_states=all_hidden_states,
1107
+ attentions=all_self_attns,
1108
+ router_logits=all_router_logits
1109
+ )
1110
+
1111
+ def _update_causal_mask(self, attention_mask, input_tensor):
1112
+ """
1113
+ Update the causal mask based on the attention mask and input tensor.
1114
+
1115
+ Args:
1116
+ attention_mask (torch.Tensor): The attention mask.
1117
+ input_tensor (torch.Tensor): The input tensor.
1118
+
1119
+ Returns:
1120
+ torch.Tensor: The updated causal mask.
1121
+ """
1122
+
1123
+ if self.config._attn_implementation == "flash_attention_2":
1124
+ if attention_mask is not None and 0.0 in attention_mask:
1125
+ return attention_mask
1126
+ return None
1127
+
1128
+ batch_size, seq_length = input_tensor.shape[:2]
1129
+ dtype = input_tensor.dtype
1130
+ device = input_tensor.device
1131
+
1132
+ # support going beyond cached `max_position_embedding`
1133
+ if seq_length > self.causal_mask.shape[-1]:
1134
+ logger.info(f"Resizing causal mask buffer from {self.causal_mask.shape[-1]} to {2 * self.causal_mask.shape[-1]}")
1135
+ causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
1136
+ self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
1137
+
1138
+ # We use the current dtype to avoid any overflows
1139
+ min_dtype = torch.finfo(dtype).min
1140
+ causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
1141
+ causal_mask = causal_mask.to(dtype=dtype, device=device)
1142
+
1143
+ if attention_mask is not None and attention_mask.dim() == 2:
1144
+ mask_length = attention_mask.shape[-1]
1145
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1146
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1147
+
1148
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None:
1149
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1150
+ is_tracing = (
1151
+ torch.jit.is_tracing()
1152
+ or isinstance(input_tensor, torch.fx.Proxy)
1153
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
1154
+ )
1155
+
1156
+ if not is_tracing and torch.any(attention_mask != 1):
1157
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
1158
+ # using left padding. This is required by
1159
+ # F.scaled_dot_product_attention memory-efficient attention path.
1160
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1161
+ causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
1162
+
1163
+ return causal_mask
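# Illustrative sketch (dummy sizes) of the additive mask built above: an upper-triangular
# buffer scaled to the dtype minimum, with a 2-D padding mask folded in.
import torch

seq_len, dtype = 5, torch.float32
min_dtype = torch.finfo(dtype).min
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
causal = causal[None, None, :, :].to(dtype) * min_dtype          # 0 where attention is allowed
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])                 # last two tokens are padding
padding = causal.eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal = causal.masked_fill(padding, min_dtype)                  # padded keys are masked everywhere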
1164
 
1165
+ class GemmoeForCausalLM(GemmoePreTrainedModel):
1166
+ r"""
1167
+ The Gemmoe Model transformer with a language modeling head on top for causal language modeling (CLM).
1168
 
1169
+ Args:
1170
+ config (GemmoeConfig): The configuration object for the Gemmoe model.
1171
 
1172
+ Example usage:
1173
+ ```python
1174
+ >>> from transformers import AutoTokenizer, GemmoeForCausalLM
1175
 
1176
+ >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
1177
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")
 
1178
 
1179
+ >>> prompt = "What is your favorite condiment?"
1180
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1181
 
1182
+ >>> # Generate
1183
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1184
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1185
+ "What is your favorite condiment?"
1186
+ ```
1187
+ """
1188
  _tied_weights_keys = ["lm_head.weight"]
1189
 
1190
  def __init__(self, config):
 
1192
  self.model = GemmoeModel(config)
1193
  self.vocab_size = config.vocab_size
1194
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1195
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1196
+ self.num_experts = 8
1197
+ self.num_experts_per_tok = config.num_experts_per_tok
1198
 
1199
  # Initialize weights and apply final processing
1200
  self.post_init()
 
1217
  def get_decoder(self):
1218
  return self.model
1219
 
1220
+ @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
1221
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1222
  def forward(
1223
  self,
1224
  input_ids: torch.LongTensor = None,
 
1230
  use_cache: Optional[bool] = None,
1231
  output_attentions: Optional[bool] = None,
1232
  output_hidden_states: Optional[bool] = None,
1233
+ output_router_logits: Optional[bool] = None,
1234
  return_dict: Optional[bool] = None,
1235
+ cache_position: Optional[torch.LongTensor] = None,
1236
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1237
  r"""
1238
  Args:
1239
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1240
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1241
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1242
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1243
 
1244
  Returns:
1245
 
 
1248
  ```python
1249
  >>> from transformers import AutoTokenizer, GemmoeForCausalLM
1250
 
1251
+ >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
1252
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")
1253
 
1254
+ >>> prompt = "What is your favorite condiment?"
1255
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1256
 
1257
  >>> # Generate
1258
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1259
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1260
+ "What is your favorite condiment?"
1261
  ```"""
1262
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1263
+ output_router_logits = (
1264
+ output_router_logits if output_router_logits is not None else getattr(self.config, "output_router_logits", False)
1265
+ )
1266
  output_hidden_states = (
1267
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1268
  )
1269
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1270
 
 
1271
  outputs = self.model(
1272
  input_ids=input_ids,
1273
  attention_mask=attention_mask,
 
1277
  use_cache=use_cache,
1278
  output_attentions=output_attentions,
1279
  output_hidden_states=output_hidden_states,
1280
+ output_router_logits=output_router_logits,
1281
  return_dict=return_dict,
1282
+ cache_position=cache_position,
1283
  )
1284
 
1285
  hidden_states = outputs[0]
1286
+
1287
+ # Ensure hidden_states and lm_head have compatible dtypes
1288
+ hidden_states = hidden_states.to(dtype=self.lm_head.weight.dtype)
1289
+
1290
+ logits = self.lm_head(hidden_states)
 
 
1291
 
1292
  loss = None
1293
  if labels is not None:
 
1294
  shift_logits = logits[..., :-1, :].contiguous()
1295
  shift_labels = labels[..., 1:].contiguous()
 
1296
  loss_fct = CrossEntropyLoss()
1297
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
1298
  shift_labels = shift_labels.view(-1)
 
1299
  shift_labels = shift_labels.to(shift_logits.device)
1300
  loss = loss_fct(shift_logits, shift_labels)
1301
 
1302
+ aux_loss = None
1303
+ if output_router_logits:
1304
+ router_logits = outputs.router_logits if return_dict else outputs[-1]
1305
+ if router_logits is not None:
1306
+ aux_loss = load_balancing_loss_func(
1307
+ router_logits,
1308
+ self.num_experts,
1309
+ self.num_experts_per_tok,
1310
+ attention_mask,
1311
+ )
1312
+ if labels is not None:
1313
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
1314
+
1315
  if not return_dict:
1316
  output = (logits,) + outputs[1:]
1317
+ if aux_loss is not None:
1318
+ output = (aux_loss,) + output
1319
  return (loss,) + output if loss is not None else output
1320
 
1321
+ return MoeCausalLMOutputWithPast(
1322
  loss=loss,
1323
+ aux_loss=aux_loss,
1324
  logits=logits,
1325
  past_key_values=outputs.past_key_values,
1326
  hidden_states=outputs.hidden_states,
1327
  attentions=outputs.attentions,
1328
+ router_logits=outputs.router_logits,
1329
  )
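# Hedged sketch of a Switch-Transformers style load-balancing term like the one added to
# the LM loss above; an illustrative re-derivation, not the load_balancing_loss_func used
# by this file.
import torch
import torch.nn.functional as F

def load_balancing_sketch(router_logits: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
    # router_logits: (num_tokens, num_experts)
    probs = F.softmax(router_logits, dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)
    expert_mask = F.one_hot(selected, num_experts).float()        # (num_tokens, top_k, num_experts)
    tokens_per_expert = expert_mask.mean(dim=(0, 1))              # fraction of routings per expert
    router_prob_per_expert = probs.mean(dim=0)
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

aux = load_balancing_sketch(torch.randn(10, 8), num_experts=8, top_k=2)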
1330
 
1331
  def prepare_inputs_for_generation(
1332
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1333
  ):
1334
+ past_length = 0
1335
  if past_key_values is not None:
1336
  if isinstance(past_key_values, Cache):
1337
  cache_length = past_key_values.get_seq_length()
 
1341
  cache_length = past_length = past_key_values[0][0].shape[2]
1342
  max_cache_length = None
1343
 
 
 
 
 
1344
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1345
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
 
 
1346
  elif past_length < input_ids.shape[1]:
1347
  input_ids = input_ids[:, past_length:]
1348
+
 
 
1349
  if (
1350
  max_cache_length is not None
1351
  and attention_mask is not None
 
1355
 
1356
  position_ids = kwargs.get("position_ids", None)
1357
  if attention_mask is not None and position_ids is None:
 
1358
  position_ids = attention_mask.long().cumsum(-1) - 1
1359
  position_ids.masked_fill_(attention_mask == 0, 1)
1360
  if past_key_values:
1361
  position_ids = position_ids[:, -input_ids.shape[1] :]
1362
 
1363
+ if self.generation_config.cache_implementation == "static":
1364
+ cache_position = kwargs.get("cache_position", None)
1365
+ if cache_position is None:
1366
+ past_length = 0
1367
+ else:
1368
+ past_length = cache_position[-1] + 1
1369
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1370
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1371
+
1372
+ cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
1373
+
1374
  if inputs_embeds is not None and past_key_values is None:
1375
  model_inputs = {"inputs_embeds": inputs_embeds}
1376
  else:
1377
+ model_inputs = {"input_ids": input_ids.contiguous()}
1378
 
1379
  model_inputs.update(
1380
  {
1381
+ "position_ids": position_ids.contiguous(),
1382
+ "cache_position": cache_position,
1383
  "past_key_values": past_key_values,
1384
  "use_cache": kwargs.get("use_cache"),
1385
  "attention_mask": attention_mask,
1386
  }
1387
  )
1388
+
1389
  return model_inputs
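# Small illustration (dummy batch) of the position_ids construction used above for
# left-padded batched generation: cumulative sum of the mask minus one, with padding
# positions overwritten by a dummy value.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# position_ids is now [[1, 1, 0, 1, 2],
#                      [0, 1, 2, 3, 4]]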
1390
 
1391
  @staticmethod
 
1418
  self.num_labels = config.num_labels
1419
  self.model = GemmoeModel(config)
1420
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
1421
  # Initialize weights and apply final processing
1422
  self.post_init()
1423
 
 
1427
  def set_input_embeddings(self, value):
1428
  self.model.embed_tokens = value
1429
 
1430
+ @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
1431
+ @replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
1432
  def forward(
1433
  self,
1434
  input_ids: torch.LongTensor = None,
 
1442
  output_hidden_states: Optional[bool] = None,
1443
  return_dict: Optional[bool] = None,
1444
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1445
  """
1446
+ Forward pass of the sequence classification model.
1447
 
1448
+ Args:
1449
+ input_ids (torch.LongTensor, optional): Input token IDs.
1450
+ attention_mask (torch.Tensor, optional): Attention mask.
1451
+ position_ids (torch.LongTensor, optional): Position IDs.
1452
+ past_key_values (List[torch.FloatTensor], optional): Past key-value pairs.
1453
+ inputs_embeds (torch.FloatTensor, optional): Input embeddings.
1454
+ labels (torch.LongTensor, optional): Labels for sequence classification.
1455
+ use_cache (bool, optional): Whether to use cache.
1456
+ output_attentions (bool, optional): Whether to output attentions.
1457
+ output_hidden_states (bool, optional): Whether to output hidden states.
1458
+ return_dict (bool, optional): Whether to return a dictionary or tuple.
1459
+
1460
+ Returns:
1461
+ Union[Tuple, SequenceClassifierOutputWithPast]: Output of the sequence classification model.
1462
+ """
1463
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1464
  transformer_outputs = self.model(
1465
  input_ids,
1466
  attention_mask=attention_mask,
 
1486
  sequence_lengths = -1
1487
  else:
1488
  if input_ids is not None:
1489
+ sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1490
+ sequence_lengths = sequence_lengths.clamp(min=0).to(logits.device)
 
1491
  else:
1492
  sequence_lengths = -1
1493
 
 
1516
  elif self.config.problem_type == "multi_label_classification":
1517
  loss_fct = BCEWithLogitsLoss()
1518
  loss = loss_fct(pooled_logits, labels)
1519
+
1520
  if not return_dict:
1521
  output = (pooled_logits,) + transformer_outputs[1:]
1522
  return ((loss,) + output) if loss is not None else output
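# Illustrative sketch (dummy values) of the pooling used for classification above:
# take the logits at the last non-padding token of every (right-padded) sequence.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 3, 4, 2]])
logits = torch.randn(2, 5, 3)                                     # (batch, seq_len, num_labels)
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1  # -> tensor([2, 4])
pooled_logits = logits[torch.arange(2), sequence_lengths]         # (batch, num_labels)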