Upload modeling_fuxitranyu.py with huggingface_hub
modeling_fuxitranyu.py CHANGED (+2 -1)
@@ -30,6 +30,7 @@ from transformers.utils import (
 )
 from .configuration_fuxitranyu import FuxiTranyuConfig
 
+
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -1250,10 +1251,10 @@ class FuxiTranyuForCausalLM(FuxiTranyuPreTrainedModel):
             logits = torch.cat(logits, dim=-1)
         else:
             logits = self.lm_head(hidden_states)
-        logits = logits.float()
 
         loss = None
         if labels is not None:
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()