Update modeling_cohere.py

#8
by ybelkada HF staff - opened
Files changed (2) hide show
  1. modeling_cohere.py +3 -0
  2. requirements.txt +1 -0
modeling_cohere.py CHANGED
@@ -52,6 +52,9 @@ from transformers.utils import (
52
  )
53
  from .configuration_cohere import CohereConfig
54
 
 
 
 
55
 
56
  logger = logging.get_logger(__name__)
57
 
 
52
  )
53
  from .configuration_cohere import CohereConfig
54
 
55
+ if is_flash_attn_2_available():
56
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
57
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
58
 
59
  logger = logging.get_logger(__name__)
60
 
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ flash_attn