YenChunChen committed
Commit bfc0c05
1 Parent(s): 5dcfdfb

make flash attention usage configurable from user code

Files changed (2)
  1. config.json +1 -2
  2. modeling_phi3_v.py +7 -6
config.json CHANGED
@@ -143,6 +143,5 @@
   "torch_dtype": "bfloat16",
   "transformers_version": "4.38.1",
   "use_cache": true,
-  "vocab_size": 32064,
-  "_attn_implementation": "flash_attention_2"
+  "vocab_size": 32064
 }
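
With `_attn_implementation` no longer pinned in config.json, the attention backend becomes a load-time choice. A minimal sketch of the resulting usage, assuming the usual `trust_remote_code` loading path (the checkpoint id is illustrative, not taken from this commit):

    from transformers import AutoModelForCausalLM

    model_id = "microsoft/Phi-3-vision-128k-instruct"  # illustrative checkpoint id

    # Opt in to flash attention when flash-attn is installed...
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation="flash_attention_2",
    )

    # ...or fall back to the eager implementation on hardware without it.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation="eager",
    )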
modeling_phi3_v.py CHANGED
@@ -40,7 +40,6 @@ from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -49,11 +48,13 @@ from .configuration_phi3_v import Phi3VConfig
 from .image_embedding_phi3_v import Phi3ImageEmbedding
 
 
-if is_flash_attn_2_available():
+try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+except ImportError:
+    pass
 
 logger = logging.get_logger(__name__)
 
@@ -1000,8 +1001,8 @@ PHI3V_INPUTS_DOCSTRING = r"""
         is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
         model's internal embedding lookup matrix.
     pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
-        The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`].
-        See [`Phi3ImageProcessor.__call__`] for details.
+        The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`].
+        See [`Phi3ImageProcessor.__call__`] for details.
     image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
         The sizes of the images in the batch, being (height, width) for each image.
     use_cache (`bool`, *optional*):
@@ -1046,7 +1047,7 @@ class Phi3VModel(Phi3VPreTrainedModel):
             **config.embd_layer
         }
         self.vision_embed_tokens = Phi3ImageEmbedding(config, wte=self.embed_tokens, **embedding_config)
-        # # set wte the same for vision embedding
+        # # set wte the same for vision embedding
         # self.vision_embed_tokens.wte.weight = self.embed_tokens.weight
 
         self.layers = nn.ModuleList(
@@ -1629,4 +1630,4 @@ class Phi3VForTokenClassification(Phi3VPreTrainedModel):
             logits=logits,
             hidden_states=model_outputs.hidden_states,
             attentions=model_outputs.attentions,
-        )
+        )
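
The guarded import replaces the `is_flash_attn_2_available()` check, so importing the module no longer requires flash-attn at all. One caveat: with a bare `except ImportError: pass`, `_flash_supports_window_size` stays undefined when flash-attn is missing, which is only safe as long as the flash-attention code path is never selected. A self-contained sketch of the same pattern with explicit fallbacks (the `None`/`False` defaults are an addition for illustration, not part of the commit):

    import inspect

    flash_attn_func = None
    _flash_supports_window_size = False
    try:
        # Optional dependency: only needed when _attn_implementation="flash_attention_2".
        from flash_attn import flash_attn_func

        _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
    except ImportError:
        # flash-attn is not installed; the eager attention path remains usable.
        pass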