yairschiff commited on
Commit
5799730
1 Parent(s): de5ba50

Fix rcps channel splitting bug

Browse files

Fixes issue identified here: https://github.com/kuleshov-group/caduceus/issues/5

Files changed (1) hide show
  1. modeling_caduceus.py +2 -2
modeling_caduceus.py CHANGED
@@ -543,8 +543,8 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
543
  )
544
  hidden_states = torch.stack(
545
  [
546
- transformer_outputs[0][..., :self.config.d_model // 2],
547
- torch.flip(transformer_outputs[0][..., self.config.d_model // 2:], dims=[1, 2])
548
  ],
549
  dim=-1
550
  )
 
543
  )
544
  hidden_states = torch.stack(
545
  [
546
+ transformer_outputs[0][..., :self.config.d_model],
547
+ torch.flip(transformer_outputs[0][..., self.config.d_model:], dims=[1, 2])
548
  ],
549
  dim=-1
550
  )