jaandoui committed on
Commit c16263c
1 Parent(s): 45d7a2a

Update bert_layers.py

Files changed (1)
  1. bert_layers.py +33 -27
bert_layers.py CHANGED
@@ -410,13 +410,13 @@ class BertEncoder(nn.Module):
         attention_mask: torch.Tensor,
         output_all_encoded_layers: Optional[bool] = True,
         subset_mask: Optional[torch.Tensor] = None,
-    ) -> List[torch.Tensor]:
+    ) -> Tuple[List[torch.Tensor], torch.Tensor]:  # Modify return type to include attention weights
 
         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
         extended_attention_mask = extended_attention_mask.to(
             dtype=torch.float32)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-
+
         attention_mask_bool = attention_mask.bool()
         batch, seqlen = hidden_states.shape[:2]
         # Unpad inputs and mask. It will remove tokens that are padded.
@@ -426,7 +426,7 @@ class BertEncoder(nn.Module):
         # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
         hidden_states, indices, cu_seqlens, _ = unpad_input(
             hidden_states, attention_mask_bool)
-
+
         # Add alibi matrix to extended_attention_mask
         if self._current_alibi_size < seqlen:
             # Rebuild the alibi tensor when needed
@@ -440,17 +440,20 @@ class BertEncoder(nn.Module):
         alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
         attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
         alibi_attn_mask = attn_bias + alibi_bias
-
+
         all_encoder_layers = []
+        all_attention_weights = []  # List to store attention weights
+
         if subset_mask is None:
             for layer_module in self.layer:
-                hidden_states = layer_module(hidden_states,
-                                             cu_seqlens,
-                                             seqlen,
-                                             None,
-                                             indices,
-                                             attn_mask=attention_mask,
-                                             bias=alibi_attn_mask)
+                hidden_states, attention_weights = layer_module(hidden_states,
+                                                                cu_seqlens,
+                                                                seqlen,
+                                                                None,
+                                                                indices,
+                                                                attn_mask=attention_mask,
+                                                                bias=alibi_attn_mask)
+                all_attention_weights.append(attention_weights)  # Store attention weights
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             # Pad inputs and mask. It will insert back zero-padded tokens.
@@ -462,28 +465,31 @@ class BertEncoder(nn.Module):
         else:
             for i in range(len(self.layer) - 1):
                 layer_module = self.layer[i]
-                hidden_states = layer_module(hidden_states,
-                                             cu_seqlens,
-                                             seqlen,
-                                             None,
-                                             indices,
-                                             attn_mask=attention_mask,
-                                             bias=alibi_attn_mask)
+                hidden_states, attention_weights = layer_module(hidden_states,
+                                                                cu_seqlens,
+                                                                seqlen,
+                                                                None,
+                                                                indices,
+                                                                attn_mask=attention_mask,
+                                                                bias=alibi_attn_mask)
+                all_attention_weights.append(attention_weights)  # Store attention weights
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
                                        as_tuple=False).flatten()
-            hidden_states = self.layer[-1](hidden_states,
-                                           cu_seqlens,
-                                           seqlen,
-                                           subset_idx=subset_idx,
-                                           indices=indices,
-                                           attn_mask=attention_mask,
-                                           bias=alibi_attn_mask)
-
+            hidden_states, attention_weights = self.layer[-1](hidden_states,
+                                                              cu_seqlens,
+                                                              seqlen,
+                                                              subset_idx=subset_idx,
+                                                              indices=indices,
+                                                              attn_mask=attention_mask,
+                                                              bias=alibi_attn_mask)
+            all_attention_weights.append(attention_weights)  # Store attention weights
+
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
-        return all_encoder_layers
+        return all_encoder_layers, all_attention_weights  # Return both hidden states and attention weights
+
 
 
 class BertPooler(nn.Module):
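
With this change, BertEncoder.forward returns a pair instead of a bare list: the encoded layers plus a list holding one attention-weight tensor per layer (so the annotated Tuple[List[torch.Tensor], torch.Tensor] is, in practice, a tuple of two lists). Note that the unpacking `hidden_states, attention_weights = layer_module(...)` assumes each layer module itself returns such a pair; that change is not part of this diff. The sketch below is purely illustrative and only mirrors the new return contract so existing call sites can be updated: fake_encoder is a stub standing in for the modified encoder (it does no real attention, no unpadding, and no ALiBi), and all names and shapes in it are assumptions, not code from this commit.

import torch
from typing import List, Tuple


def fake_encoder(
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        num_layers: int = 12,
        num_heads: int = 12,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # Stub with the same return contract as the modified BertEncoder.forward:
    # a list of encoded layers plus one attention-weight tensor per layer.
    # The random softmax tensors only fix the shapes (batch, heads, seq, seq).
    batch, seqlen, _ = hidden_states.shape
    all_encoder_layers = [hidden_states]  # as with output_all_encoded_layers=False, keep only the last layer
    all_attention_weights = [
        torch.softmax(torch.randn(batch, num_heads, seqlen, seqlen), dim=-1)
        for _ in range(num_layers)
    ]
    return all_encoder_layers, all_attention_weights


hidden = torch.randn(2, 16, 768)   # (batch, seqlen, hidden)
mask = torch.ones(2, 16)           # 1 = real token, 0 = padding

# Call sites that previously did `all_encoder_layers = encoder(...)` now unpack a pair.
all_encoder_layers, all_attention_weights = fake_encoder(hidden, mask)
sequence_output = all_encoder_layers[-1]
last_layer_attention = all_attention_weights[-1]
print(len(all_attention_weights), sequence_output.shape, last_layer_attention.shape)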