Rocketknight1 HF staff commited on
Commit
e0522bc
1 Parent(s): 6e84b23

Update modeling_hyena.py

Browse files
Files changed (1) hide show
  1. modeling_hyena.py +1 -1
modeling_hyena.py CHANGED
@@ -463,7 +463,7 @@ class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):
463
  shift_labels = labels[..., 1:].contiguous()
464
  # Flatten the tokens
465
  loss_fct = nn.CrossEntropyLoss()
466
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
467
  shift_labels = shift_labels.view(-1)
468
  # Enable model parallelism
469
  shift_labels = shift_labels.to(shift_logits.device)
 
463
  shift_labels = labels[..., 1:].contiguous()
464
  # Flatten the tokens
465
  loss_fct = nn.CrossEntropyLoss()
466
+ shift_logits = shift_logits.view(-1, self.vocab_size)
467
  shift_labels = shift_labels.view(-1)
468
  # Enable model parallelism
469
  shift_labels = shift_labels.to(shift_logits.device)