ikuinen99 commited on
Commit
245b48f
1 Parent(s): 22e0bc4
Files changed (1) hide show
  1. imagebind/models/image_bind.py +141 -141
imagebind/models/image_bind.py CHANGED
@@ -269,12 +269,12 @@ class ImageBindModel(nn.Module):
269
  depth_stem=None,
270
  )
271
 
272
- text_preprocessor = TextPreprocessor(
273
- context_length=77,
274
- vocab_size=49408,
275
- embed_dim=text_embed_dim,
276
- causal_masking=True,
277
- )
278
 
279
  audio_stem = PatchEmbedGeneric(
280
  proj_stem=[
@@ -295,73 +295,73 @@ class ImageBindModel(nn.Module):
295
  audio_stem=audio_stem,
296
  )
297
 
298
- depth_stem = PatchEmbedGeneric(
299
- [
300
- nn.Conv2d(
301
- kernel_size=depth_kernel_size,
302
- in_channels=1,
303
- out_channels=depth_embed_dim,
304
- stride=depth_kernel_size,
305
- bias=False,
306
- ),
307
- ],
308
- norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
309
- )
310
-
311
- depth_preprocessor = RGBDTPreprocessor(
312
- img_size=[1, 224, 224],
313
- num_cls_tokens=1,
314
- pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
315
- rgbt_stem=None,
316
- depth_stem=depth_stem,
317
- )
318
-
319
- thermal_stem = PatchEmbedGeneric(
320
- [
321
- nn.Conv2d(
322
- kernel_size=thermal_kernel_size,
323
- in_channels=1,
324
- out_channels=thermal_embed_dim,
325
- stride=thermal_kernel_size,
326
- bias=False,
327
- ),
328
- ],
329
- norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
330
- )
331
- thermal_preprocessor = ThermalPreprocessor(
332
- img_size=[1, 224, 224],
333
- num_cls_tokens=1,
334
- pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
335
- thermal_stem=thermal_stem,
336
- )
337
-
338
- imu_stem = PatchEmbedGeneric(
339
- [
340
- nn.Linear(
341
- in_features=48,
342
- out_features=imu_embed_dim,
343
- bias=False,
344
- ),
345
- ],
346
- norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
347
- )
348
-
349
- imu_preprocessor = IMUPreprocessor(
350
- img_size=[6, 2000],
351
- num_cls_tokens=1,
352
- kernel_size=8,
353
- embed_dim=imu_embed_dim,
354
- pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
355
- imu_stem=imu_stem,
356
- )
357
 
358
  modality_preprocessors = {
359
  ModalityType.VISION: rgbt_preprocessor,
360
- ModalityType.TEXT: text_preprocessor,
361
  ModalityType.AUDIO: audio_preprocessor,
362
- ModalityType.DEPTH: depth_preprocessor,
363
- ModalityType.THERMAL: thermal_preprocessor,
364
- ModalityType.IMU: imu_preprocessor,
365
  }
366
 
367
  return nn.ModuleDict(modality_preprocessors)
@@ -424,14 +424,14 @@ class ImageBindModel(nn.Module):
424
  add_bias_kv=False,
425
  drop_path=0.0,
426
  )
427
- modality_trunks[ModalityType.TEXT] = instantiate_trunk(
428
- text_embed_dim,
429
- text_num_blocks,
430
- text_num_heads,
431
- pre_transformer_ln=False,
432
- add_bias_kv=False,
433
- drop_path=0.0,
434
- )
435
  modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
436
  audio_embed_dim,
437
  audio_num_blocks,
@@ -440,30 +440,30 @@ class ImageBindModel(nn.Module):
440
  add_bias_kv=True,
441
  drop_path=audio_drop_path,
442
  )
443
- modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
444
- depth_embed_dim,
445
- depth_num_blocks,
446
- depth_num_heads,
447
- pre_transformer_ln=False,
448
- add_bias_kv=True,
449
- drop_path=depth_drop_path,
450
- )
451
- modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
452
- thermal_embed_dim,
453
- thermal_num_blocks,
454
- thermal_num_heads,
455
- pre_transformer_ln=False,
456
- add_bias_kv=True,
457
- drop_path=thermal_drop_path,
458
- )
459
- modality_trunks[ModalityType.IMU] = instantiate_trunk(
460
- imu_embed_dim,
461
- imu_num_blocks,
462
- imu_num_heads,
463
- pre_transformer_ln=False,
464
- add_bias_kv=True,
465
- drop_path=imu_drop_path,
466
- )
467
 
468
  return nn.ModuleDict(modality_trunks)
469
 
@@ -486,12 +486,12 @@ class ImageBindModel(nn.Module):
486
  nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
487
  )
488
 
489
- modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
490
- proj=nn.Sequential(
491
- nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
492
- nn.Linear(text_embed_dim, out_embed_dim, bias=False),
493
- )
494
- )
495
 
496
  modality_heads[ModalityType.AUDIO] = nn.Sequential(
497
  nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
@@ -499,24 +499,24 @@ class ImageBindModel(nn.Module):
499
  nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
500
  )
501
 
502
- modality_heads[ModalityType.DEPTH] = nn.Sequential(
503
- nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
504
- SelectElement(index=0) if use_selection else nn.Identity(),
505
- nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
506
- )
507
-
508
- modality_heads[ModalityType.THERMAL] = nn.Sequential(
509
- nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
510
- SelectElement(index=0) if use_selection else nn.Identity(),
511
- nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
512
- )
513
-
514
- modality_heads[ModalityType.IMU] = nn.Sequential(
515
- nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
516
- SelectElement(index=0) if use_selection else nn.Identity(),
517
- nn.Dropout(p=0.5),
518
- nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
519
- )
520
 
521
  return nn.ModuleDict(modality_heads)
522
 
@@ -524,25 +524,25 @@ class ImageBindModel(nn.Module):
524
  modality_postprocessors = {}
525
 
526
  modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
527
- modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
528
- Normalize(dim=-1), LearnableLogitScaling(learnable=True)
529
- )
530
  modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
531
  Normalize(dim=-1),
532
  LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
533
  )
534
- modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
535
- Normalize(dim=-1),
536
- LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
537
- )
538
- modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
539
- Normalize(dim=-1),
540
- LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
541
- )
542
- modality_postprocessors[ModalityType.IMU] = nn.Sequential(
543
- Normalize(dim=-1),
544
- LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
545
- )
546
 
547
  return nn.ModuleDict(modality_postprocessors)
548
 
@@ -612,7 +612,7 @@ def imagebind_huge(pretrained=False, freeze_imagebind=False, with_head=True, use
612
  progress=True,
613
  )
614
 
615
- model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth"))
616
 
617
  if use_blip_vision:
618
  from bubogpt.models.eva_vit import create_eva_vit_g
 
269
  depth_stem=None,
270
  )
271
 
272
+ # text_preprocessor = TextPreprocessor(
273
+ # context_length=77,
274
+ # vocab_size=49408,
275
+ # embed_dim=text_embed_dim,
276
+ # causal_masking=True,
277
+ # )
278
 
279
  audio_stem = PatchEmbedGeneric(
280
  proj_stem=[
 
295
  audio_stem=audio_stem,
296
  )
297
 
298
+ # depth_stem = PatchEmbedGeneric(
299
+ # [
300
+ # nn.Conv2d(
301
+ # kernel_size=depth_kernel_size,
302
+ # in_channels=1,
303
+ # out_channels=depth_embed_dim,
304
+ # stride=depth_kernel_size,
305
+ # bias=False,
306
+ # ),
307
+ # ],
308
+ # norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
309
+ # )
310
+ #
311
+ # depth_preprocessor = RGBDTPreprocessor(
312
+ # img_size=[1, 224, 224],
313
+ # num_cls_tokens=1,
314
+ # pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
315
+ # rgbt_stem=None,
316
+ # depth_stem=depth_stem,
317
+ # )
318
+ #
319
+ # thermal_stem = PatchEmbedGeneric(
320
+ # [
321
+ # nn.Conv2d(
322
+ # kernel_size=thermal_kernel_size,
323
+ # in_channels=1,
324
+ # out_channels=thermal_embed_dim,
325
+ # stride=thermal_kernel_size,
326
+ # bias=False,
327
+ # ),
328
+ # ],
329
+ # norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
330
+ # )
331
+ # thermal_preprocessor = ThermalPreprocessor(
332
+ # img_size=[1, 224, 224],
333
+ # num_cls_tokens=1,
334
+ # pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
335
+ # thermal_stem=thermal_stem,
336
+ # )
337
+ #
338
+ # imu_stem = PatchEmbedGeneric(
339
+ # [
340
+ # nn.Linear(
341
+ # in_features=48,
342
+ # out_features=imu_embed_dim,
343
+ # bias=False,
344
+ # ),
345
+ # ],
346
+ # norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
347
+ # )
348
+ #
349
+ # imu_preprocessor = IMUPreprocessor(
350
+ # img_size=[6, 2000],
351
+ # num_cls_tokens=1,
352
+ # kernel_size=8,
353
+ # embed_dim=imu_embed_dim,
354
+ # pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
355
+ # imu_stem=imu_stem,
356
+ # )
357
 
358
  modality_preprocessors = {
359
  ModalityType.VISION: rgbt_preprocessor,
360
+ # ModalityType.TEXT: text_preprocessor,
361
  ModalityType.AUDIO: audio_preprocessor,
362
+ # ModalityType.DEPTH: depth_preprocessor,
363
+ # ModalityType.THERMAL: thermal_preprocessor,
364
+ # ModalityType.IMU: imu_preprocessor,
365
  }
366
 
367
  return nn.ModuleDict(modality_preprocessors)
 
424
  add_bias_kv=False,
425
  drop_path=0.0,
426
  )
427
+ # modality_trunks[ModalityType.TEXT] = instantiate_trunk(
428
+ # text_embed_dim,
429
+ # text_num_blocks,
430
+ # text_num_heads,
431
+ # pre_transformer_ln=False,
432
+ # add_bias_kv=False,
433
+ # drop_path=0.0,
434
+ # )
435
  modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
436
  audio_embed_dim,
437
  audio_num_blocks,
 
440
  add_bias_kv=True,
441
  drop_path=audio_drop_path,
442
  )
443
+ # modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
444
+ # depth_embed_dim,
445
+ # depth_num_blocks,
446
+ # depth_num_heads,
447
+ # pre_transformer_ln=False,
448
+ # add_bias_kv=True,
449
+ # drop_path=depth_drop_path,
450
+ # )
451
+ # modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
452
+ # thermal_embed_dim,
453
+ # thermal_num_blocks,
454
+ # thermal_num_heads,
455
+ # pre_transformer_ln=False,
456
+ # add_bias_kv=True,
457
+ # drop_path=thermal_drop_path,
458
+ # )
459
+ # modality_trunks[ModalityType.IMU] = instantiate_trunk(
460
+ # imu_embed_dim,
461
+ # imu_num_blocks,
462
+ # imu_num_heads,
463
+ # pre_transformer_ln=False,
464
+ # add_bias_kv=True,
465
+ # drop_path=imu_drop_path,
466
+ # )
467
 
468
  return nn.ModuleDict(modality_trunks)
469
 
 
486
  nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
487
  )
488
 
489
+ # modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
490
+ # proj=nn.Sequential(
491
+ # nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
492
+ # nn.Linear(text_embed_dim, out_embed_dim, bias=False),
493
+ # )
494
+ # )
495
 
496
  modality_heads[ModalityType.AUDIO] = nn.Sequential(
497
  nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
 
499
  nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
500
  )
501
 
502
+ # modality_heads[ModalityType.DEPTH] = nn.Sequential(
503
+ # nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
504
+ # SelectElement(index=0) if use_selection else nn.Identity(),
505
+ # nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
506
+ # )
507
+ #
508
+ # modality_heads[ModalityType.THERMAL] = nn.Sequential(
509
+ # nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
510
+ # SelectElement(index=0) if use_selection else nn.Identity(),
511
+ # nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
512
+ # )
513
+ #
514
+ # modality_heads[ModalityType.IMU] = nn.Sequential(
515
+ # nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
516
+ # SelectElement(index=0) if use_selection else nn.Identity(),
517
+ # nn.Dropout(p=0.5),
518
+ # nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
519
+ # )
520
 
521
  return nn.ModuleDict(modality_heads)
522
 
 
524
  modality_postprocessors = {}
525
 
526
  modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
527
+ # modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
528
+ # Normalize(dim=-1), LearnableLogitScaling(learnable=True)
529
+ # )
530
  modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
531
  Normalize(dim=-1),
532
  LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
533
  )
534
+ # modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
535
+ # Normalize(dim=-1),
536
+ # LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
537
+ # )
538
+ # modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
539
+ # Normalize(dim=-1),
540
+ # LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
541
+ # )
542
+ # modality_postprocessors[ModalityType.IMU] = nn.Sequential(
543
+ # Normalize(dim=-1),
544
+ # LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
545
+ # )
546
 
547
  return nn.ModuleDict(modality_postprocessors)
548
 
 
612
  progress=True,
613
  )
614
 
615
+ model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth"), strict=False)
616
 
617
  if use_blip_vision:
618
  from bubogpt.models.eva_vit import create_eva_vit_g