adalbertojunior commited on
Commit
835bd36
1 Parent(s): dd90667

Update modeling_cohere.py

Browse files
Files changed (1) hide show
  1. modeling_cohere.py +4 -1
modeling_cohere.py CHANGED
@@ -52,7 +52,10 @@ from transformers.utils import (
52
  )
53
  from .configuration_cohere import CohereConfig
54
 
55
-
 
 
 
56
  logger = logging.get_logger(__name__)
57
 
58
  _CONFIG_FOR_DOC = "CohereConfig"
 
52
  )
53
  from .configuration_cohere import CohereConfig
54
 
55
+ if is_flash_attn_2_available():
56
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
57
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
58
+
59
  logger = logging.get_logger(__name__)
60
 
61
  _CONFIG_FOR_DOC = "CohereConfig"