Rewrite image embedding to remove the in-place op

#53
Files changed (1) hide show
  1. image_embedding_phi3_v.py +129 -153
image_embedding_phi3_v.py CHANGED
@@ -13,8 +13,6 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
- from datetime import datetime
17
-
18
  import torch
19
  from torch import nn
20
  from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
@@ -28,6 +26,9 @@ except ImportError:
28
 
29
  logger = logging.get_logger(__name__)
30
 
 
 
 
31
  CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
32
  attention_dropout=0.0,
33
  dropout=0.0,
@@ -179,32 +180,44 @@ class Phi3ImageEmbedding(nn.Module):
179
  patch_feature = img_feature[:, 1:]
180
  return patch_feature
181
 
182
- if TYPE_FEATURE == "cls_patch":
183
- return img_feature
184
-
185
  raise NotImplementedError
186
 
187
- def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None) -> torch.FloatTensor:
188
-
189
- MAX_INPUT_ID = int(1e9)
190
- img_embeds = pixel_values
191
- img_sizes = image_sizes
192
-
193
- if self.img_features is not None:
194
- img_embeds = self.img_features.clone()
195
- self.img_features = None
196
-
197
- if self.img_sizes is not None:
198
- img_sizes = self.img_sizes
199
-
200
  input_shape = input_ids.size()
201
  input_ids = input_ids.view(-1, input_shape[-1])
202
 
203
- with torch.no_grad():
204
- positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- select = False
207
 
 
 
 
 
 
 
 
208
  if isinstance(self.img_projection, nn.Sequential):
209
  target_device = self.img_projection[0].bias.device
210
  target_dtype = self.img_projection[0].bias.dtype
@@ -212,135 +225,98 @@ class Phi3ImageEmbedding(nn.Module):
212
  target_device = self.img_projection.bias.device
213
  target_dtype = self.img_projection.bias.dtype
214
 
215
- if len(positions.tolist()) > 0:
216
- with torch.no_grad():
217
- g_values = abs(input_ids[positions[:, 0], positions[:, 1]])
218
-
219
- if self.use_hd_transform and img_sizes is not None and len(img_sizes):
220
- hd_transform = True
221
- assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform'
222
- # img_embeds: (num_images, max_num_crops, 3, H, W)
223
- # img_sizes: (num_images, 2).view(1, -1)
224
-
225
- start_time = datetime.now()
226
- bs = img_embeds.shape[0]
227
- # Nx(HW)xC
228
- img_features = self.get_img_features(img_embeds.flatten(0, 1))
229
- base_feat_height = base_feat_width = int(img_features.shape[1] ** 0.5)
230
-
231
- assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform'
232
-
233
- # bs x max_num_crops x (24x24) x C
234
- img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
235
- C = self.image_dim_out
236
- H = base_feat_height
237
-
238
- output_imgs = []
239
- output_len = []
240
- # training is tensor, inference is list
241
- if isinstance(img_sizes, torch.Tensor):
242
- img_sizes = img_sizes.view(-1, 2)
243
- for _bs in range(bs):
244
- h, w = img_sizes[_bs]
245
- h = h // 336
246
- w = w // 336
247
- B_ = h * w
248
-
249
- # 1 x (24x24) x 1024
250
- global_img_feature = img_features[_bs, :1]
251
-
252
- # 1 x 12 x 12 x 4096
253
- glb_img = global_img_feature.reshape(1,H,H,C).reshape(1,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(1,H//2,H//2,4*C).contiguous()
254
- temp_glb_GN = self.sub_GN.repeat(1, H//2, 1, 1)
255
-
256
- # 1 x 156 x 4096
257
- glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1,-1,4*C)
258
-
259
- # (max_num_crops-1) x (12x12) x C
260
- sub_img = img_features[_bs, 1:]
261
- # 16x574x1024
262
- # get rid of padding sub_img
263
- sub_img = sub_img[:B_]
264
-
265
- # (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
266
- sub_img = sub_img.reshape(B_,H,H,C).reshape(B_,H//2,2,H//2,2,C).contiguous().permute(0,1,3,2,4,5).reshape(B_,-1,4*C).contiguous()
267
- sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute(0,1,3,2,4,5).reshape(1,h*12,w*12,4*C)
268
- temp_sub_GN = self.sub_GN.repeat(1, h*12, 1, 1)
269
- sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1,-1,4*C)
270
- # (1, num_img_tokens, 1024*4)
271
-
272
- # glb + sub
273
- if self.hd_transform_order == 'glb_sub':
274
- output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
275
- elif self.hd_transform_order == 'sub_glb':
276
- output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
277
- else:
278
- raise NotImplementedError(f'hd_transform_order = {self.hd_transform_order}, not implemented')
279
-
280
- temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
281
- assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}'
282
- output_len.append(temp_len)
283
-
284
- num_img_tokens = output_len
285
- img_set_tensor = []
286
- for _output_img in output_imgs:
287
- img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype))
288
- img_set_tensor.append(img_feature_proj)
289
- logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}')
290
- elif img_embeds.ndim == 4:
291
- selected_g_values = g_values[::self.num_img_tokens]
292
- assert len(img_embeds) == len(selected_g_values), f'img_embeds size: {img_embeds.size()}, selected_g_values size: {len(selected_g_values)}, selected_g_value {selected_g_values}'
293
- start_time = datetime.now()
294
- tt = (
295
- self.get_img_features(img_embeds)
296
- .to(target_device)
297
- .to(target_dtype)
298
- .reshape(-1, self.image_dim_out)
299
- )
300
- logger.info(f'img_embeds size: {img_embeds.size()}, loading time {datetime.now() - start_time}')
301
- img_set_tensor = self.img_projection(tt) # adapted visual features.
302
- elif img_embeds.ndim == 3:
303
- selected_g_values = g_values[::self.num_img_tokens]
304
- assert len(img_embeds) == len(selected_g_values), f'img_embeds size: {img_embeds.size()}, selected_g_values size: {len(selected_g_values)}, selected_g_value {selected_g_values}'
305
- tt = (
306
- img_embeds
307
- .to(target_device)
308
- .to(target_dtype)
309
- .view(-1, self.image_dim_out)
310
- )
311
- img_set_tensor = self.img_projection(tt) # adapted visual features.
312
- else:
313
- raise NotImplementedError
314
- select = True
315
-
316
- with torch.no_grad():
317
- input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
318
-
319
- hidden_states = self.wte(input_ids)
320
-
321
- if select:
322
- if hd_transform:
323
- idx = 0
324
- for i, cnt in enumerate(num_img_tokens):
325
- hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = (
326
- img_set_tensor[i]
327
- .to(hidden_states.dtype)
328
- .to(hidden_states.device)
329
- )
330
- idx += cnt
331
- else:
332
- idx = 0
333
- assert len(selected_g_values) * self.num_img_tokens == len(img_set_tensor), f'len(selected_g_values) * self.num_img_tokens = {len(selected_g_values) * self.num_img_tokens}, len(img_set_tensor) = {len(img_set_tensor)}'
334
- for i, g in enumerate(selected_g_values):
335
- cnt = self.num_img_tokens
336
- hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = (
337
- img_set_tensor[i * cnt : (i + 1) * cnt]
338
- .to(hidden_states.dtype)
339
- .to(hidden_states.device)
340
- )
341
- idx += cnt
342
-
343
- if self.drop is not None:
344
- hidden_states = self.drop(hidden_states)
345
-
346
- return hidden_states
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
 
 
16
  import torch
17
  from torch import nn
18
  from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
 
26
 
27
  logger = logging.get_logger(__name__)
28
 
29
+
30
+ MAX_INPUT_ID = int(1e9)
31
+
32
  CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
33
  attention_dropout=0.0,
34
  dropout=0.0,
 
180
  patch_feature = img_feature[:, 1:]
181
  return patch_feature
182
 
 
 
 
183
  raise NotImplementedError
184
 
185
+ def forward(
186
+ self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None
187
+ ) -> torch.FloatTensor:
 
 
 
 
 
 
 
 
 
 
188
  input_shape = input_ids.size()
189
  input_ids = input_ids.view(-1, input_shape[-1])
190
 
191
+ # positions for image tokens
192
+ positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
193
+ has_image = len(positions[0].tolist()) > 0
194
+ input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach()
195
+ hidden_states = self.wte(input_ids)
196
+
197
+ if has_image:
198
+ assert self.use_hd_transform
199
+ num_images, num_crops, c, h, w = pixel_values.shape
200
+ assert c == 3 and h == w == 336
201
+ img_features = self.get_img_features(pixel_values.flatten(0, 1)).reshape(
202
+ num_images, num_crops, -1, self.image_dim_out
203
+ )
204
+ image_features_proj = self.hd_feature_transform(img_features, image_sizes)
205
+ hidden_states = hidden_states.index_put(
206
+ positions, image_features_proj, accumulate=False
207
+ )
208
+
209
+ if self.drop is not None:
210
+ hidden_states = self.drop(hidden_states)
211
 
212
+ return hidden_states
213
 
214
+ def hd_feature_transform(self, image_features, image_sizes):
215
+ """
216
+ image_features: (num_images, num_crops+1, 24*24, 1024)
217
+ """
218
+ assert (
219
+ self.hd_transform_order == 'sub_glb'
220
+ ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
221
  if isinstance(self.img_projection, nn.Sequential):
222
  target_device = self.img_projection[0].bias.device
223
  target_dtype = self.img_projection[0].bias.dtype
 
225
  target_device = self.img_projection.bias.device
226
  target_dtype = self.img_projection.bias.dtype
227
 
228
+ global_image_features = image_features[:, 0] # (num_images, 24*24, 1024)
229
+ # global feature can be viewed as a special HD case with num_crops 1x1
230
+ global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1)
231
+ global_image_features_hd_newline = self.add_image_newline(global_image_features_hd)
232
+
233
+ all_image_embeddings = []
234
+ # need a for loop to process each image because of different image sizes
235
+ # (patch arrangement is different for each image)
236
+ for i, img_size in enumerate(image_sizes):
237
+ h, w = img_size
238
+ h_crop = h // 336
239
+ w_crop = w // 336
240
+ num_crops = h_crop * w_crop
241
+
242
+ # NOTE: real num_crops is padded
243
+ # (num_crops, 24*24, 1024)
244
+ sub_image_features = image_features[i, 1 : 1 + num_crops]
245
+ sub_image_features_hd = self.reshape_hd_patches_2x2merge(
246
+ sub_image_features, h_crop, w_crop
247
+ )
248
+ sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd)
249
+
250
+ # [sub features, separator, global features]
251
+ all_image_embeddings.extend(
252
+ [
253
+ sub_image_features_hd_newline.squeeze(0), # (h_crop*12*(w_crop*12+1), 4096)
254
+ self.glb_GN.squeeze(0),
255
+ global_image_features_hd_newline[i],
256
+ ]
257
+ )
258
+
259
+ image_features_proj = self.img_projection(
260
+ torch.cat(all_image_embeddings, dim=0).to(target_device).to(target_dtype)
261
+ )
262
+
263
+ return image_features_proj
264
+
265
+ def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
266
+ """
267
+ image_features: (num_images*num_crops, 24*24, 1024)
268
+ output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops
269
+ """
270
+ N, L, C = image_features.shape
271
+ assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0
272
+ num_images = N // (h_crop * w_crop)
273
+ H = int(L**0.5)
274
+ image_features_hd = (
275
+ image_features.reshape(N, H, H, C) # N, 24, 24, 1024
276
+ .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
277
+ .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
278
+ .reshape(N, -1, 4 * C) # N, 144, 4096
279
+ .reshape(
280
+ num_images, h_crop, w_crop, H // 2, H // 2, -1
281
+ ) # n_img, h_crop, w_crop, 12, 12, 4096
282
+ .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
283
+ .reshape(
284
+ num_images, h_crop * H // 2, w_crop * H // 2, 4 * C
285
+ ) # n_img, h_crop*12, w_crop*12, 4096
286
+ )
287
+
288
+ # alternative implementation using einops
289
+ # from einops import rearrange
290
+ # image_features_nhwc = rearrange(
291
+ # image_features,
292
+ # 'N (H W) c -> N H W c',
293
+ # H=H,
294
+ # W=H,
295
+ # )
296
+ # image_features_2x2merge = rearrange(
297
+ # image_features_nhwc,
298
+ # 'N (h h_pool) (w w_pool) c -> N h w (h_pool w_pool c)',
299
+ # h_pool=2,
300
+ # w_pool=2,
301
+ # )
302
+ # image_features_hd = rearrange(
303
+ # image_features_2x2merge,
304
+ # '(n_img h_crop w_crop) h w C -> n_img (h_crop h) (w_crop w) C',
305
+ # h_crop=h_crop,
306
+ # w_crop=w_crop,
307
+ # )
308
+
309
+ return image_features_hd
310
+
311
+ def add_image_newline(self, image_features_hd):
312
+ """
313
+ image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
314
+ output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
315
+ """
316
+ num_images, h, w, hid_dim = image_features_hd.shape
317
+ # add the newline token to the HD image feature patches
318
+ newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim)
319
+ image_features_hd_newline = torch.cat(
320
+ [image_features_hd, newline_embeddings], dim=2
321
+ ).reshape(num_images, -1, hid_dim)
322
+ return image_features_hd_newline