Text Generation
Transformers
Safetensors
fuxitranyu
conversational
custom_code
rrjin committed on
Commit
8ad3d9d
1 Parent(s): acc1bef

Upload modeling_fuxitranyu.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fuxitranyu.py +2 -1
modeling_fuxitranyu.py CHANGED
@@ -30,6 +30,7 @@ from transformers.utils import (
30
  )
31
  from .configuration_fuxitranyu import FuxiTranyuConfig
32
 
 
33
  try:
34
  from flash_attn import flash_attn_func, flash_attn_varlen_func
35
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -1250,10 +1251,10 @@ class FuxiTranyuForCausalLM(FuxiTranyuPreTrainedModel):
1250
  logits = torch.cat(logits, dim=-1)
1251
  else:
1252
  logits = self.lm_head(hidden_states)
1253
- logits = logits.float()
1254
 
1255
  loss = None
1256
  if labels is not None:
 
1257
  # Shift so that tokens < n predict n
1258
  shift_logits = logits[..., :-1, :].contiguous()
1259
  shift_labels = labels[..., 1:].contiguous()
 
30
  )
31
  from .configuration_fuxitranyu import FuxiTranyuConfig
32
 
33
+
34
  try:
35
  from flash_attn import flash_attn_func, flash_attn_varlen_func
36
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
 
1251
  logits = torch.cat(logits, dim=-1)
1252
  else:
1253
  logits = self.lm_head(hidden_states)
 
1254
 
1255
  loss = None
1256
  if labels is not None:
1257
+ logits = logits.float()
1258
  # Shift so that tokens < n predict n
1259
  shift_logits = logits[..., :-1, :].contiguous()
1260
  shift_labels = labels[..., 1:].contiguous()