yairschiff committed on
Commit
c699cf9
1 Parent(s): 5d97a93

Update modeling_caduceus.py

Browse files

Prevent the final add-norm from returning the residual in the RCPS path (pass prenorm=False)

Files changed (1) hide show
  1. modeling_caduceus.py +2 -1
modeling_caduceus.py CHANGED
@@ -213,7 +213,8 @@ class CaduceusMixerModel(nn.Module):
213
 
214
  if not self.fused_add_norm:
215
  if self.rcps:
216
- hidden_states = self.norm_f(hidden_states, residual=residual)
 
217
  else:
218
  residual = (hidden_states + residual) if residual is not None else hidden_states
219
  hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
 
213
 
214
  if not self.fused_add_norm:
215
  if self.rcps:
216
+ # Set prenorm=False here since we don't need the residual
217
+ hidden_states = self.norm_f(hidden_states, residual=residual, prenorm=False)
218
  else:
219
  residual = (hidden_states + residual) if residual is not None else hidden_states
220
  hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))