a8nova committed
Commit 2c39b54
Parent: 35f1c60

Update modeling_openelm.py

Files changed (1)
  1. modeling_openelm.py +5 -3
modeling_openelm.py CHANGED
@@ -778,9 +778,11 @@ class OpenELMModel(OpenELMPreTrainedModel):
                 padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
                     :, None, None, :
                 ].eq(0.0)
-                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
+                # Clone first so the slice assignment below updates a copy of the mask.
+                causal_mask = causal_mask.clone()
+                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
 
         if self.config._attn_implementation == "sdpa" and attention_mask is not None:
             # For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
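
For context, here is a minimal standalone sketch of the pattern this commit adopts: build an additive causal mask, take a clone, then merge the 2-D padding mask into the copy with masked_fill. The setup (shapes, dtype, the toy attention_mask, building the mask via triu/expand) is an illustrative assumption rather than the OpenELM source; only the padding_mask and masked_fill lines mirror the diff above.

    import torch

    # Assumed toy dimensions, not taken from modeling_openelm.py.
    batch_size, q_len, kv_len = 2, 4, 4
    dtype = torch.float32
    min_dtype = torch.finfo(dtype).min

    # Additive causal mask: 0.0 where attention is allowed, min_dtype above
    # the diagonal; expanded to (batch, 1, q_len, kv_len) as a view.
    causal_mask = torch.triu(
        torch.full((q_len, kv_len), min_dtype, dtype=dtype), diagonal=1
    )
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    # 2-D padding mask from the tokenizer: 1 = real token, 0 = padding.
    attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
    mask_length = attention_mask.shape[-1]

    # Attendable positions (0.0 in the causal mask) that point at padding.
    padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
        :, None, None, :
    ].eq(0.0)

    # The commit's change: clone first, so the slice assignment writes into
    # a private copy rather than into the expanded view.
    causal_mask = causal_mask.clone()
    causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
        padding_mask, min_dtype
    )

Without the clone, the final assignment in this sketch raises a RuntimeError, because an in-place write into an expand()-produced view touches memory shared by many elements. That is presumably the failure mode the commit guards against, though whether the real call site builds causal_mask through expand() is not visible in this diff.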