rulerman commited on
Commit
8008b40
·
1 Parent(s): dbd498f
Files changed (1) hide show
  1. generation_utils.py +1 -1
generation_utils.py CHANGED
@@ -82,7 +82,7 @@ def load_model(
82
  spt_config_path,
83
  spt_checkpoint_path,
84
  torch_dtype=torch.bfloat16,
85
- attn_implementation="flash_attention_2",
86
  ):
87
  from transformers import AutoTokenizer
88
 
 
82
  spt_config_path,
83
  spt_checkpoint_path,
84
  torch_dtype=torch.bfloat16,
85
+ attn_implementation="sdpa",
86
  ):
87
  from transformers import AutoTokenizer
88