Commit
•
ff9ec2f
1
Parent(s):
ad78f30
Update modeling_hyena.py
Browse files- modeling_hyena.py +1 -1
modeling_hyena.py
CHANGED
@@ -463,7 +463,7 @@ class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
|
|
463 |
shift_labels = labels[..., 1:].contiguous()
|
464 |
# Flatten the tokens
|
465 |
loss_fct = nn.CrossEntropyLoss()
|
466 |
-
shift_logits = shift_logits.view(-1, self.
|
467 |
shift_labels = shift_labels.view(-1)
|
468 |
# Enable model parallelism
|
469 |
shift_labels = shift_labels.to(shift_logits.device)
|
|
|
463 |
shift_labels = labels[..., 1:].contiguous()
|
464 |
# Flatten the tokens
|
465 |
loss_fct = nn.CrossEntropyLoss()
|
466 |
+
shift_logits = shift_logits.view(-1, self.vocab_size)
|
467 |
shift_labels = shift_labels.view(-1)
|
468 |
# Enable model parallelism
|
469 |
shift_labels = shift_labels.to(shift_logits.device)
|