yairschiff committed
Commit: bceb79e
Parent: d8238ba

Update modeling_rcps.py


Enable prenorm=False to prevent RCPSAddNormWrapper from returning the residual

Files changed (1): modeling_rcps.py (+9 -3)
modeling_rcps.py CHANGED

@@ -101,11 +101,12 @@ class RCPSAddNormWrapper(RCPSWrapper):
     def __init__(self, submodule: nn.Module):
         super().__init__(submodule)
 
-    def forward(self, x, residual=None):
+    def forward(self, x, residual=None, prenorm=False):
         """
         Args:
             x: Input tensor of shape (batch_size, seq_len, channels)
             residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
+            prenorm: Whether to return residual.
         """
         n_channels = x.shape[-1]
         if residual is None:
@@ -123,7 +124,7 @@ class RCPSAddNormWrapper(RCPSWrapper):
             residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)
             x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
 
-        return x, residual
+        return x if not prenorm else (x, residual)
 
 
 class RCPSMambaBlock(nn.Module):
@@ -147,6 +148,11 @@ class RCPSMambaBlock(nn.Module):
         self.mixer = RCPSWrapper(mixer_cls(dim))
         norm_f = norm_cls(dim)
         self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
+        if self.fused_add_norm:
+            assert RMSNorm is not None, "RMSNorm import fails"
+            assert isinstance(
+                self.norm, (nn.LayerNorm, RMSNorm)
+            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
 
     def forward(
             self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
@@ -159,7 +165,7 @@ class RCPSMambaBlock(nn.Module):
             inference_params: inference parameters for mixer.
         """
         if not self.fused_add_norm:
-            hidden_states, residual = self.norm(hidden_states, residual=residual)
+            hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)
             if self.residual_in_fp32:
                 residual = residual.to(torch.float32)
         else:
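
For context, a minimal sketch of the calling convention this commit introduces. ToyAddNormWrapper is a hypothetical stand-in that reproduces only the return logic added here, not the reverse-complement splitting that RCPSAddNormWrapper actually performs:

import torch
import torch.nn as nn


class ToyAddNormWrapper(nn.Module):
    def __init__(self, norm: nn.Module):
        super().__init__()
        self.norm = norm

    def forward(self, x, residual=None, prenorm=False):
        # Accumulate the residual stream, then normalize it.
        residual = x if residual is None else residual + x
        x = self.norm(residual)
        # New behavior from this commit: the residual is only
        # returned when the caller opts in with prenorm=True.
        return x if not prenorm else (x, residual)


wrapper = ToyAddNormWrapper(nn.LayerNorm(8))
x = torch.randn(2, 4, 8)
out = wrapper(x)                          # prenorm=False (default): tensor only
out, residual = wrapper(x, prenorm=True)  # what RCPSMambaBlock.forward now passes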