Pull request #2 — "Update modeling_pharia.py", opened by AzizBelaweid.
File changed: modeling_pharia.py (+15 −2).
@@ -764,9 +764,22 @@ class PhariaForCausalLM(PhariaPreTrainedModel):
Before (old lines 764–772):

    764
    765         hidden_states = outputs[0]
    766         logits = self.lm_head(hidden_states)
    767 -
    768         return CausalLMOutputWithPast(
    769 -           loss=
    770             logits=logits,
    771             past_key_values=outputs.past_key_values,
    772             hidden_states=outputs.hidden_states,
After (new lines 764–785):

    764
    765         hidden_states = outputs[0]
    766         logits = self.lm_head(hidden_states)
    767 +       loss = 0.0
    768 +       if labels is not None:
    769 +           # Shift logits and labels for causal language modeling
    770 +           shift_logits = logits[..., :-1, :].contiguous()
    771 +           shift_labels = outputs['labels'][..., 1:].contiguous()
    772 +
    773 +           # Flatten the tokens
    774 +           shift_logits = shift_logits.view(-1, shift_logits.size(-1))
    775 +           shift_labels = shift_labels.view(-1)
    776 +
    777 +           # Compute loss
    778 +           loss_fct = torch.nn.CrossEntropyLoss(ignore_index=1)  # Pad token ID for Pharia is 1
    779 +           loss = loss_fct(shift_logits, shift_labels)
    780 +
    781         return CausalLMOutputWithPast(
    782 +           loss=loss,
    783             logits=logits,
    784             past_key_values=outputs.past_key_values,
    785             hidden_states=outputs.hidden_states,

NOTE(review): line 771 of the proposed change takes the labels from `outputs['labels']`, i.e. from the model's own output object. Model outputs do not contain labels, so this would fail at runtime; the shifted labels should come from the `labels` argument guarded on line 768 (`shift_labels = labels[..., 1:].contiguous()`). Also confirm the claim on line 778 that Pharia's pad token ID is 1 before relying on `ignore_index=1` — the conventional choice is `ignore_index=-100`, which Hugging Face data collators use for masked positions.