tomeras1 mber commited on
Commit
86c5df0
1 Parent(s): 62df05e

Fix small typo in JambaDecoder (#24)

Browse files

- Fix small typo in JambaDecoder (3e2e3bf1a6540e78b3d6e8f58adbbab384d6d32d)


Co-authored-by: Moshe Berchansky <mber@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_jamba.py +1 -1
modeling_jamba.py CHANGED
@@ -1053,7 +1053,7 @@ class JambaMambaMixer(nn.Module):
1053
  ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
1054
  scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
1055
  scan_outputs.append(scan_output[:, :, 0])
1056
- scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size]
1057
  scan_output = scan_output + (hidden_states * self.D[None, :, None])
1058
  scan_output = (scan_output * self.act(gate))
1059
 
 
1053
  ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
1054
  scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
1055
  scan_outputs.append(scan_output[:, :, 0])
1056
+ scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediade_size, seq_len]
1057
  scan_output = scan_output + (hidden_states * self.D[None, :, None])
1058
  scan_output = (scan_output * self.act(gate))
1059