Make flash attention configurable in user code
#26
by
YenChunChen
- opened
- README.md +1 -19
- modeling_phi3_v.py +7 -6
README.md
CHANGED
@@ -105,7 +105,7 @@ from transformers import AutoProcessor
|
|
105 |
|
106 |
model_id = "microsoft/Phi-3-vision-128k-instruct"
|
107 |
|
108 |
-
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto")
|
109 |
|
110 |
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
111 |
|
@@ -217,24 +217,6 @@ Note that by default, the Phi-3-Vision-128K model uses flash attention, which re
|
|
217 |
* NVIDIA A6000
|
218 |
* NVIDIA H100
|
219 |
|
220 |
-
### Running on Windows or without flash attention
|
221 |
-
To enable the model on these enviroment here are steps that you may consider to follow:
|
222 |
-
|
223 |
-
Step 1: comment flash attention import code in modeling_phi3_v.py from line 52 to line 56.
|
224 |
-
```python
|
225 |
-
# if is_flash_attn_2_available():
|
226 |
-
# from flash_attn import flash_attn_func, flash_attn_varlen_func
|
227 |
-
# from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
228 |
-
|
229 |
-
# _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
230 |
-
```
|
231 |
-
|
232 |
-
Step 2: change _"_attn_implementation"_ from _"flash_attention_2"_ to _"eager"_ in config.json or disable flash attention when you create the model as below.
|
233 |
-
|
234 |
-
```python
|
235 |
-
model = AutoModelForCausalLM.from_pretrained('microsoft/Phi-3-vision-128k-instruct', device_map="cuda", trust_remote_code=True, torch_dtype="auto", _attn_implementation="eager")
|
236 |
-
```
|
237 |
-
|
238 |
## License
|
239 |
|
240 |
The model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/resolve/main/LICENSE).
|
|
|
105 |
|
106 |
model_id = "microsoft/Phi-3-vision-128k-instruct"
|
107 |
|
108 |
+
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto", _attn_implementation='flash_attention_2') # use _attn_implementation='eager' to disable flash attention
|
109 |
|
110 |
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
111 |
|
|
|
217 |
* NVIDIA A6000
|
218 |
* NVIDIA H100
|
219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
## License
|
221 |
|
222 |
The model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/resolve/main/LICENSE).
|
modeling_phi3_v.py
CHANGED
@@ -40,7 +40,6 @@ from transformers.utils import (
|
|
40 |
add_code_sample_docstrings,
|
41 |
add_start_docstrings,
|
42 |
add_start_docstrings_to_model_forward,
|
43 |
-
is_flash_attn_2_available,
|
44 |
is_flash_attn_greater_or_equal_2_10,
|
45 |
logging,
|
46 |
replace_return_docstrings,
|
@@ -49,11 +48,13 @@ from .configuration_phi3_v import Phi3VConfig
|
|
49 |
from .image_embedding_phi3_v import Phi3ImageEmbedding
|
50 |
|
51 |
|
52 |
-
|
53 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
54 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
55 |
|
56 |
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
|
|
|
|
57 |
|
58 |
logger = logging.get_logger(__name__)
|
59 |
|
@@ -1000,8 +1001,8 @@ PHI3V_INPUTS_DOCSTRING = r"""
|
|
1000 |
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
1001 |
model's internal embedding lookup matrix.
|
1002 |
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
1003 |
-
The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`].
|
1004 |
-
See [`Phi3ImageProcessor.__call__`] for details.
|
1005 |
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
|
1006 |
The sizes of the images in the batch, being (height, width) for each image.
|
1007 |
use_cache (`bool`, *optional*):
|
@@ -1046,7 +1047,7 @@ class Phi3VModel(Phi3VPreTrainedModel):
|
|
1046 |
**config.embd_layer
|
1047 |
}
|
1048 |
self.vision_embed_tokens = Phi3ImageEmbedding(config, wte=self.embed_tokens, **embedding_config)
|
1049 |
-
# # set wte the same for vision embedding
|
1050 |
# self.vision_embed_tokens.wte.weight = self.embed_tokens.weight
|
1051 |
|
1052 |
self.layers = nn.ModuleList(
|
@@ -1629,4 +1630,4 @@ class Phi3VForTokenClassification(Phi3VPreTrainedModel):
|
|
1629 |
logits=logits,
|
1630 |
hidden_states=model_outputs.hidden_states,
|
1631 |
attentions=model_outputs.attentions,
|
1632 |
-
)
|
|
|
40 |
add_code_sample_docstrings,
|
41 |
add_start_docstrings,
|
42 |
add_start_docstrings_to_model_forward,
|
|
|
43 |
is_flash_attn_greater_or_equal_2_10,
|
44 |
logging,
|
45 |
replace_return_docstrings,
|
|
|
48 |
from .image_embedding_phi3_v import Phi3ImageEmbedding
|
49 |
|
50 |
|
51 |
+
try:
|
52 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
53 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
54 |
|
55 |
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
56 |
+
except ImportError:
|
57 |
+
pass
|
58 |
|
59 |
logger = logging.get_logger(__name__)
|
60 |
|
|
|
1001 |
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
1002 |
model's internal embedding lookup matrix.
|
1003 |
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
1004 |
+
The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`].
|
1005 |
+
See [`Phi3ImageProcessor.__call__`] for details.
|
1006 |
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
|
1007 |
The sizes of the images in the batch, being (height, width) for each image.
|
1008 |
use_cache (`bool`, *optional*):
|
|
|
1047 |
**config.embd_layer
|
1048 |
}
|
1049 |
self.vision_embed_tokens = Phi3ImageEmbedding(config, wte=self.embed_tokens, **embedding_config)
|
1050 |
+
# # set wte the same for vision embedding
|
1051 |
# self.vision_embed_tokens.wte.weight = self.embed_tokens.weight
|
1052 |
|
1053 |
self.layers = nn.ModuleList(
|
|
|
1630 |
logits=logits,
|
1631 |
hidden_states=model_outputs.hidden_states,
|
1632 |
attentions=model_outputs.attentions,
|
1633 |
+
)
|