Use Flash Attention for CLIP image encoder

#52
Files changed (1)
  1. image_embedding_phi3_v.py +62 -17
image_embedding_phi3_v.py CHANGED
@@ -13,13 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
+from datetime import datetime
+
 import torch
-import torch.nn as nn
-from transformers import CLIPVisionModel, PretrainedConfig
-from transformers import CLIPVisionConfig
+from torch import nn
+from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
+from transformers.models.clip.modeling_clip import CLIPAttention
 from transformers.utils import logging
-from datetime import datetime
+
+try:
+    from flash_attn import flash_attn_func
+except ImportError:
+    pass
 
 logger = logging.get_logger(__name__)
 
@@ -37,9 +42,42 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
     num_channels=3,
     num_hidden_layers=24,
     patch_size=14,
-    projection_dim=768
+    projection_dim=768
 )
 
+class CLIPAttentionFA2(CLIPAttention):
+    """Add flash attention 2 to CLIPAttention. (This is only used in the vision encoder)"""
+
+    def forward(self,
+        hidden_states,
+        attention_mask=None,
+        causal_attention_mask=None,
+        output_attentions=False,
+    ):
+        """Input shape: Batch x Time x Channel"""
+
+        assert attention_mask is None, "CLIPAttentionFA2 does not support attention_mask"
+        assert causal_attention_mask is None, "CLIPAttentionFA2 does not support causal_attention_mask"
+        assert output_attentions is False, "CLIPAttentionFA2 does not support output_attentions"
+
+        bsz, tgt_len, embed_dim = hidden_states.size()
+        query_states = self.q_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
+        key_states = self.k_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
+        value_states = self.v_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
+
+        attn_output = flash_attn_func(
+            query_states,
+            key_states,
+            value_states,
+            dropout_p=self.dropout if self.training else 0.0,
+            softmax_scale=self.scale,
+            causal=False,
+        ).reshape(bsz, tgt_len, embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+        return attn_output, None
+
+
 class Phi3ImageEmbedding(nn.Module):
     """Phi3 Image embedding."""
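Note on the CLIPAttentionFA2 hunk above: flash_attn_func consumes projections laid out as (batch, seq_len, num_heads, head_dim), which is why the q/k/v tensors are only reshaped here rather than transposed to (batch, num_heads, seq_len, head_dim) as in the eager CLIPAttention path. A minimal sanity-check sketch, assuming a CUDA device, bf16 inputs, and the flash-attn package installed (the shapes mirror CLIP ViT-L/14-336: 24x24 patches plus one CLS token, 16 heads of size 64):

import torch
from flash_attn import flash_attn_func

bsz, seq, heads, head_dim = 2, 577, 16, 64   # 24*24 patches + 1 CLS token
q = torch.randn(bsz, seq, heads, head_dim, device='cuda', dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# flash_attn_func takes (batch, seq, heads, head_dim) directly, no transpose needed
out_fa = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=head_dim ** -0.5, causal=False)

# PyTorch SDPA expects (batch, heads, seq, head_dim), so transpose in and out
out_ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

print(torch.allclose(out_fa, out_ref, atol=1e-2, rtol=1e-2))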
 
@@ -65,6 +103,13 @@ class Phi3ImageEmbedding(nn.Module):
             self.img_processor = CLIPVisionModel(clip_config)
             image_dim_out = config.img_processor['image_dim_out']
             self.num_img_tokens = config.img_processor['num_img_tokens']
+
+            # FA2 in CLIP
+            if config._attn_implementation == 'flash_attention_2':
+                for layer in self.img_processor.vision_model.encoder.layers:
+                    clip_fa2 = CLIPAttentionFA2(clip_config)
+                    del layer.self_attn
+                    layer.self_attn = clip_fa2
         else:
             raise NotImplementedError(f'img_processor = {config.img_processor}, not implemented')
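The swap in __init__ above replaces each encoder layer's self_attn with a freshly constructed CLIPAttentionFA2. Because the class subclasses CLIPAttention and adds no new parameters, the checkpoint's q_proj/k_proj/v_proj/out_proj weights still load onto the new modules by name when from_pretrained restores the state dict afterwards. If the same swap were applied to a model whose weights are already loaded, the projections would have to be copied over explicitly; a rough sketch (illustrative only, not part of this PR; model and clip_config are assumed to be in scope):

import torch

for layer in model.img_processor.vision_model.encoder.layers:
    fa2 = CLIPAttentionFA2(clip_config)
    fa2.load_state_dict(layer.self_attn.state_dict())  # reuse pretrained q/k/v/out_proj weights
    fa2.to(dtype=torch.bfloat16, device='cuda')         # flash-attn kernels need fp16/bf16 on GPU
    layer.self_attn = fa2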
 
@@ -157,15 +202,15 @@ class Phi3ImageEmbedding(nn.Module):
 
         with torch.no_grad():
             positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False)
-
+
         select = False
 
-        if isinstance(self.img_projection, nn.Sequential):
-            target_device = self.img_projection[0].bias.device
-            target_dtype = self.img_projection[0].bias.dtype
-        else: # It's a single nn.Linear layer
-            target_device = self.img_projection.bias.device
-            target_dtype = self.img_projection.bias.dtype
+        if isinstance(self.img_projection, nn.Sequential):
+            target_device = self.img_projection[0].bias.device
+            target_dtype = self.img_projection[0].bias.dtype
+        else: # It's a single nn.Linear layer
+            target_device = self.img_projection.bias.device
+            target_dtype = self.img_projection.bias.dtype
 
         if len(positions.tolist()) > 0:
             with torch.no_grad():
@@ -197,7 +242,7 @@ class Phi3ImageEmbedding(nn.Module):
                     img_sizes = img_sizes.view(-1, 2)
                 for _bs in range(bs):
                     h, w = img_sizes[_bs]
-                    h = h // 336
+                    h = h // 336
                     w = w // 336
                     B_ = h * w
 
@@ -235,7 +280,7 @@ class Phi3ImageEmbedding(nn.Module):
                     temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
                     assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}'
                     output_len.append(temp_len)
-
+
                 num_img_tokens = output_len
                 img_set_tensor = []
                 for _output_img in output_imgs:
@@ -267,10 +312,10 @@ class Phi3ImageEmbedding(nn.Module):
             else:
                 raise NotImplementedError
             select = True
-
+
         with torch.no_grad():
             input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
-
+
         hidden_states = self.wte(input_ids)
 
         if select:
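For completeness, the new attention path is only taken when the config requests the flash_attention_2 implementation. A minimal usage sketch, assuming flash-attn is installed and a CUDA GPU is available (the checkpoint id is illustrative, not taken from this PR):

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = 'microsoft/Phi-3-vision-128k-instruct'   # illustrative checkpoint id
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,                    # pulls in the repo's image_embedding_phi3_v.py
    torch_dtype=torch.bfloat16,                # flash-attn requires fp16/bf16 activations
    _attn_implementation='flash_attention_2',  # routes the CLIP encoder through CLIPAttentionFA2
    device_map='cuda',
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)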