Update bert_layers.py
bert_layers.py  CHANGED  (+33, -27)
@@ -410,13 +410,13 @@ class BertEncoder(nn.Module):
         attention_mask: torch.Tensor,
         output_all_encoded_layers: Optional[bool] = True,
         subset_mask: Optional[torch.Tensor] = None,
-    ) -> List[torch.Tensor]:
+    ) -> Tuple[List[torch.Tensor], torch.Tensor]:  # Modify return type to include attention weights
 
         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
         extended_attention_mask = extended_attention_mask.to(
             dtype=torch.float32)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-
+
         attention_mask_bool = attention_mask.bool()
         batch, seqlen = hidden_states.shape[:2]
         # Unpad inputs and mask. It will remove tokens that are padded.
@@ -426,7 +426,7 @@ class BertEncoder(nn.Module):
         # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
         hidden_states, indices, cu_seqlens, _ = unpad_input(
             hidden_states, attention_mask_bool)
-
+
         # Add alibi matrix to extended_attention_mask
         if self._current_alibi_size < seqlen:
             # Rebuild the alibi tensor when needed
@@ -440,17 +440,20 @@ class BertEncoder(nn.Module):
         alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
         attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
         alibi_attn_mask = attn_bias + alibi_bias
-
+
         all_encoder_layers = []
+        all_attention_weights = []  # List to store attention weights
+
         if subset_mask is None:
             for layer_module in self.layer:
-                hidden_states = layer_module(hidden_states,
-                                             cu_seqlens,
-                                             seqlen,
-                                             None,
-                                             indices,
-                                             attn_mask=attention_mask,
-                                             bias=alibi_attn_mask)
+                hidden_states, attention_weights = layer_module(hidden_states,
+                                                                cu_seqlens,
+                                                                seqlen,
+                                                                None,
+                                                                indices,
+                                                                attn_mask=attention_mask,
+                                                                bias=alibi_attn_mask)
+                all_attention_weights.append(attention_weights)  # Store attention weights
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             # Pad inputs and mask. It will insert back zero-padded tokens.
@@ -462,28 +465,31 @@ class BertEncoder(nn.Module):
         else:
             for i in range(len(self.layer) - 1):
                 layer_module = self.layer[i]
-                hidden_states = layer_module(hidden_states,
-                                             cu_seqlens,
-                                             seqlen,
-                                             None,
-                                             indices,
-                                             attn_mask=attention_mask,
-                                             bias=alibi_attn_mask)
+                hidden_states, attention_weights = layer_module(hidden_states,
+                                                                cu_seqlens,
+                                                                seqlen,
+                                                                None,
+                                                                indices,
+                                                                attn_mask=attention_mask,
+                                                                bias=alibi_attn_mask)
+                all_attention_weights.append(attention_weights)  # Store attention weights
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
                                        as_tuple=False).flatten()
-            hidden_states = self.layer[-1](hidden_states,
-                                           cu_seqlens,
-                                           seqlen,
-                                           subset_idx=subset_idx,
-                                           indices=indices,
-                                           attn_mask=attention_mask,
-                                           bias=alibi_attn_mask)
-
+            hidden_states, attention_weights = self.layer[-1](hidden_states,
+                                                              cu_seqlens,
+                                                              seqlen,
+                                                              subset_idx=subset_idx,
+                                                              indices=indices,
+                                                              attn_mask=attention_mask,
+                                                              bias=alibi_attn_mask)
+            all_attention_weights.append(attention_weights)  # Store attention weights
+
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
-        return all_encoder_layers
+        return all_encoder_layers, all_attention_weights  # Return both hidden states and attention weights
+
 
 
 class BertPooler(nn.Module):
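For this change to run, every layer in self.layer (including self.layer[-1]) must itself be updated to return a (hidden_states, attention_weights) tuple instead of a bare tensor; that companion change to the per-layer module is not part of this diff. Note also that all_attention_weights is a list with one entry per layer, so Tuple[List[torch.Tensor], List[torch.Tensor]] would describe the new return value more precisely than the annotation used above. A minimal, self-contained sketch of the per-layer contract the encoder loop now assumes, using torch.nn.MultiheadAttention as a stand-in for the fused attention actually used in bert_layers.py (all names below are illustrative, not the repository's own layer code):

import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Illustrative stand-in for a per-layer module that also returns attention weights.

    The real per-layer module in bert_layers.py works on unpadded inputs
    (cu_seqlens, indices, ...) and may use a fused kernel that never
    materializes attention probabilities; this toy only demonstrates the
    return contract assumed by the modified BertEncoder.forward.
    """

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, hidden_states: torch.Tensor):
        # need_weights=True makes MultiheadAttention return the attention
        # probabilities (averaged over heads by default) next to the output.
        attn_output, attention_weights = self.attn(hidden_states, hidden_states,
                                                   hidden_states, need_weights=True)
        # Return a tuple so the encoder can unpack:
        #     hidden_states, attention_weights = layer_module(...)
        return attn_output, attention_weights


# Quick shape check of the assumed contract.
layer = ToyLayer(hidden_size=16, num_heads=4)
x = torch.randn(2, 8, 16)  # [batch, seqlen, hidden]
out, weights = layer(x)
print(out.shape, weights.shape)  # torch.Size([2, 8, 16]) torch.Size([2, 8, 8])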
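Because the return type of BertEncoder.forward changes, downstream callers (for example a BertModel-style wrapper that currently does encoder_outputs = self.encoder(...)) will also need a two-value unpacking once this lands. Below is a small, self-contained sketch of how the collected weights could be post-processed, assuming each per-layer tensor has the common [batch, num_heads, seqlen, seqlen] layout; the actual shape depends on the attention implementation inside the layers, which this diff does not show:

import torch

# Stand-in for the list built by the modified BertEncoder.forward:
# one attention tensor per transformer layer (random data for illustration).
num_layers, batch, num_heads, seqlen = 2, 3, 4, 8
all_attention_weights = [
    torch.softmax(torch.randn(batch, num_heads, seqlen, seqlen), dim=-1)
    for _ in range(num_layers)
]

# Stack into a single [num_layers, batch, num_heads, seqlen, seqlen] tensor.
stacked = torch.stack(all_attention_weights, dim=0)

# Average over heads to get one [seqlen, seqlen] attention map per layer and example.
per_layer_maps = stacked.mean(dim=2)
print(per_layer_maps.shape)  # torch.Size([2, 3, 8, 8])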