x54-729 commited on
Commit
e9b4825
1 Parent(s): 2a826c2

update modeling file to newest

Browse files
configuration_internlm2.py CHANGED
@@ -177,4 +177,4 @@ class InternLM2Config(PretrainedConfig):
177
  raise ValueError(
178
  f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
179
  f"of type {type(rope_scaling_factor)}"
180
- )
 
177
  raise ValueError(
178
  f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
179
  f"of type {type(rope_scaling_factor)}"
180
+ )
modeling_internlm2.py CHANGED
@@ -59,6 +59,10 @@ try:
59
  except:
60
  pass
61
 
 
 
 
 
62
 
63
  logger = logging.get_logger(__name__)
64
 
@@ -1093,7 +1097,11 @@ class InternLM2Model(InternLM2PreTrainedModel):
1093
  else:
1094
  causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1095
  if sequence_length != 1:
1096
- causal_mask = torch.triu(causal_mask, diagonal=1)
 
 
 
 
1097
  causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1098
  causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1099
  if attention_mask is not None:
 
59
  except:
60
  pass
61
 
62
+ try:
63
+ support_bf16_triu = torch.__version__ >= "2.1.0"
64
+ except Exception:
65
+ support_bf16_triu = False
66
 
67
  logger = logging.get_logger(__name__)
68
 
 
1097
  else:
1098
  causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1099
  if sequence_length != 1:
1100
+ if support_bf16_triu or dtype == torch.float32:
1101
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1102
+ else:
1103
+ triu_mask = torch.triu(torch.ones(causal_mask.size(), device=device), diagonal=1).bool()
1104
+ causal_mask.masked_fill_(~triu_mask, 0)
1105
  causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1106
  causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1107
  if attention_mask is not None: