Text Generation
Transformers
PyTorch
English
llama
Inference Endpoints
text-generation-inference
arshzahed committed on
Commit
061211f
1 Parent(s): 3c84db1

Add _support_flash_attn_2 to Llama 2 32k (#37)

- Add _support_flash_attn_2 to Llama 2 32k (c761ba2f083d2de002465b0b74c438b8af1561aa)

Files changed (1):
  1. modeling_flash_llama.py +1 -0
modeling_flash_llama.py:

@@ -499,6 +499,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
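The one-line change in this commit flips a class-level capability flag that the transformers loading machinery consults before enabling Flash Attention 2 for a model class; if the flag is absent or False, the request is rejected. A minimal sketch of that gating pattern, using hypothetical stand-in classes rather than the real transformers code:

```python
# Hypothetical stand-ins illustrating the `_supports_flash_attn_2` gate;
# these are NOT the real transformers classes.
class PreTrainedModelStub:
    # Default: the architecture has not been validated for Flash Attention 2.
    _supports_flash_attn_2 = False

    @classmethod
    def check_flash_attn_2(cls):
        """Raise if this model class does not declare FA2 support."""
        if not cls._supports_flash_attn_2:
            raise ValueError(
                f"{cls.__name__} does not support Flash Attention 2"
            )
        return True


class FlashLlamaStub(PreTrainedModelStub):
    # Mirrors the one-line addition in the diff above.
    _supports_flash_attn_2 = True
```

Because the flag is a plain class attribute, subclasses opt in by shadowing it, while the base-class default keeps unvalidated architectures safely rejected.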