yairschiff committed
Commit 5f7b219
1 Parent(s): a951351

Update modeling_rcps.py


Enable `prenorm=False` for `RCPSAddNormWrapper`, which prevents the residual from being returned.

Files changed (1)
  1. modeling_rcps.py +4 -3
modeling_rcps.py CHANGED
```diff
@@ -101,11 +101,12 @@ class RCPSAddNormWrapper(RCPSWrapper):
     def __init__(self, submodule: nn.Module):
         super().__init__(submodule)
 
-    def forward(self, x, residual=None):
+    def forward(self, x, residual=None, prenorm=True):
         """
         Args:
             x: Input tensor of shape (batch_size, seq_len, channels)
             residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
+            prenorm: Whether to return residual.
         """
         n_channels = x.shape[-1]
         if residual is None:
@@ -123,7 +124,7 @@ class RCPSAddNormWrapper(RCPSWrapper):
             residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)
             x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
 
-        return x, residual
+        return x if not prenorm else (x, residual)
 
 
 class RCPSMambaBlock(nn.Module):
@@ -159,7 +160,7 @@ class RCPSMambaBlock(nn.Module):
             inference_params: inference parameters for mixer.
         """
         if not self.fused_add_norm:
-            hidden_states, residual = self.norm(hidden_states, residual=residual)
+            hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)
             if self.residual_in_fp32:
                 residual = residual.to(torch.float32)
         else:
```
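For reference, here is a minimal usage sketch of the two call modes this commit enables. Only the `prenorm` keyword and the return shapes come from the diff above; the construction details (the import path, wrapping `nn.LayerNorm`, and the norm acting on half the channel dimension because of the forward/reverse-complement split) are assumptions for illustration.

```python
import torch
import torch.nn as nn

# Assumed import path for this repository's module.
from modeling_rcps import RCPSAddNormWrapper

# Assumption: the wrapped norm acts on half of the channels, since the wrapper
# splits the input into forward and reverse-complement halves before normalizing.
wrapper = RCPSAddNormWrapper(nn.LayerNorm(64))
x = torch.randn(2, 16, 128)  # (batch_size, seq_len, channels)

# Default behavior (prenorm=True): returns both the normalized output and the residual.
out, residual = wrapper(x, residual=None, prenorm=True)

# New option (prenorm=False): returns only the normalized output, without the residual.
out_only = wrapper(x, prenorm=False)
```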