Fix rcps channel splitting bug
Browse filesFixes issue identified here: https://github.com/kuleshov-group/caduceus/issues/5
- modeling_caduceus.py +2 -2
modeling_caduceus.py
CHANGED
@@ -543,8 +543,8 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
|
|
543 |
)
|
544 |
hidden_states = torch.stack(
|
545 |
[
|
546 |
-
transformer_outputs[0][..., :self.config.d_model
|
547 |
-
torch.flip(transformer_outputs[0][..., self.config.d_model
|
548 |
],
|
549 |
dim=-1
|
550 |
)
|
|
|
543 |
)
|
544 |
hidden_states = torch.stack(
|
545 |
[
|
546 |
+
transformer_outputs[0][..., :self.config.d_model],
|
547 |
+
torch.flip(transformer_outputs[0][..., self.config.d_model:], dims=[1, 2])
|
548 |
],
|
549 |
dim=-1
|
550 |
)
|