Files changed (2) hide show
  1. mlp.py +8 -2
  2. modeling_bert.py +2 -1
mlp.py CHANGED
@@ -33,6 +33,7 @@ class GLUMLP(nn.Module):
33
  in_features,
34
  hidden_features,
35
  activation,
 
36
  return_residual=False,
37
  hidden_dropout_prob=0.1
38
  ):
@@ -52,14 +53,19 @@ class GLUMLP(nn.Module):
52
  self.wo = nn.Linear(hidden_features, in_features)
53
  self.dropout = nn.Dropout(hidden_dropout_prob)
54
  self.return_residual = return_residual
 
55
  #self.layernorm = nn.LayerNorm(in_features, eps=layer_norm_eps)
56
 
57
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
58
  residual_connection = hidden_states
59
  # compute the activation
60
  hidden_states = self.gated_layers(hidden_states)
61
- gated = hidden_states[:, : self.hidden_features]
62
- non_gated = hidden_states[:, self.hidden_features :]
 
 
 
 
63
  hidden_states = self.act(gated) * non_gated
64
  hidden_states = self.dropout(hidden_states)
65
  # multiply by the second matrix
 
33
  in_features,
34
  hidden_features,
35
  activation,
36
+ use_flash_attn,
37
  return_residual=False,
38
  hidden_dropout_prob=0.1
39
  ):
 
53
  self.wo = nn.Linear(hidden_features, in_features)
54
  self.dropout = nn.Dropout(hidden_dropout_prob)
55
  self.return_residual = return_residual
56
+ self.use_flash_attn = use_flash_attn
57
  #self.layernorm = nn.LayerNorm(in_features, eps=layer_norm_eps)
58
 
59
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
60
  residual_connection = hidden_states
61
  # compute the activation
62
  hidden_states = self.gated_layers(hidden_states)
63
+ if self.use_flash_attn:
64
+ gated = hidden_states[:, : self.hidden_features]
65
+ non_gated = hidden_states[:, self.hidden_features :]
66
+ else:
67
+ gated = hidden_states[:, :, : self.hidden_features]
68
+ non_gated = hidden_states[:, :, self.hidden_features :]
69
  hidden_states = self.act(gated) * non_gated
70
  hidden_states = self.dropout(hidden_states)
71
  # multiply by the second matrix
modeling_bert.py CHANGED
@@ -114,6 +114,7 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
114
  GLUMLP,
115
  hidden_features=inner_dim,
116
  activation=config.hidden_act,
 
117
  hidden_dropout_prob=config.hidden_dropout_prob,
118
  return_residual=return_residual,
119
  )
@@ -802,4 +803,4 @@ class BertForMaskedLM(BertPreTrainedModel):
802
  loss=masked_lm_loss,
803
  prediction_logits=prediction_scores,
804
  seq_relationship_logits=seq_relationship_score,
805
- )
 
114
  GLUMLP,
115
  hidden_features=inner_dim,
116
  activation=config.hidden_act,
117
+ use_flash_attn=config.use_flash_attn,
118
  hidden_dropout_prob=config.hidden_dropout_prob,
119
  return_residual=return_residual,
120
  )
 
803
  loss=masked_lm_loss,
804
  prediction_logits=prediction_scores,
805
  seq_relationship_logits=seq_relationship_score,
806
+ )