Qubitium committed
Commit 83e3af5
1 Parent(s): 444fa87

Update modeling_dbrx.py

Files changed (1)
  1. modeling_dbrx.py +8 -22
modeling_dbrx.py CHANGED
@@ -1,3 +1,4 @@
+# code adapted from https://huggingface.co/fahadh4ilyas
 """PyTorch Dbrx model."""
 
 import math
@@ -244,28 +245,13 @@ def resolve_ffn_act_fn(
 # Copied from LLaMaAttention
 #############################################################################
 
-def get_max_seqlen_in_batch(attention_mask):
-    max_num = torch.max(attention_mask)
-    # attention_mask: B x N
-    counts = []
-    for i in range(1, max_num + 1):
-        counts.append(
-            torch.sum(attention_mask == i, axis=-1)
-        )  # shape: B, count length of data point maksed with i
-    result = torch.stack(counts, axis=1)
-    result = result.flatten()
-    return result[result.nonzero()].squeeze(-1).to(dtype=torch.int32)
-
-
-def _get_unpad_data(attention_mask):
-    seqlens_in_batch = get_max_seqlen_in_batch(
-        attention_mask
-    )  # attention_mask.sum(dim=-1, dtype=torch.int32)
+
+def _get_unpad_data(attention_mask: torch.Tensor):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(
-        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
-    )
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
+                       (1, 0))
     return (
         indices,
         cu_seqlens,
@@ -426,7 +412,7 @@ class DbrxFlashAttention2(DbrxAttention):
         **kwargs: Any,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[Tuple[torch.Tensor]]]:
-        logger.info(
+        logger.debug(
             'Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.'
         )
         output_attentions = False
@@ -1459,4 +1445,4 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
             reordered_past += (tuple(
                 past_state.index_select(0, beam_idx.to(past_state.device))
                 for past_state in layer_past),)
-        return reordered_past
+        return reordered_past
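
For quick reference, below is the rewritten helper as a standalone, runnable sketch with a toy input; the example mask and printed values are illustrative and not part of the commit. The new `_get_unpad_data` assumes a standard right-padded 0/1 attention mask, whereas the removed `get_max_seqlen_in_batch` appears to have been written for masks whose entries label positions with per-sequence ids.

# Standalone sketch of the new helper (toy mask below is an assumed example).
import torch
import torch.nn.functional as F

def _get_unpad_data(attention_mask: torch.Tensor):
    # Number of real (non-padded) tokens per row of a 0/1 padding mask.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of the non-padded positions across the whole batch.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths with a leading zero, the layout expected
    # by varlen flash-attention kernels.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
                       (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
print(indices)      # tensor([0, 1, 2, 4, 5])
print(cu_seqlens)   # tensor([0, 3, 5], dtype=torch.int32)
print(max_seqlen)   # 3

This mirrors the padding-mask-based `_get_unpad_data` used elsewhere in Transformers' flash-attention code paths, so the DBRX Flash Attention 2 path no longer depends on the custom sequence-id counting helper.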