Haobo Yuan committed on
Commit
9cc3eb2
·
1 Parent(s): cdf83ef
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -4
  2. README.md +3 -4
  3. app/assets/sa_01.jpg +3 -0
  4. app/assets/sa_224028.jpg +3 -0
  5. app/assets/sa_227490.jpg +3 -0
  6. app/assets/sa_228025.jpg +3 -0
  7. app/assets/sa_234958.jpg +3 -0
  8. app/assets/sa_235005.jpg +3 -0
  9. app/assets/sa_235032.jpg +3 -0
  10. app/assets/sa_235036.jpg +3 -0
  11. app/assets/sa_235086.jpg +3 -0
  12. app/assets/sa_235094.jpg +3 -0
  13. app/assets/sa_235113.jpg +3 -0
  14. app/assets/sa_235130.jpg +3 -0
  15. app/configs/sam_r50x16_fpn.py +81 -0
  16. app/configs/sam_vith.py +38 -0
  17. app/models/last_layer.py +20 -0
  18. app/models/model.py +92 -0
  19. app/models/openclip_backbone.py +292 -0
  20. app/models/ovsam_head.py +226 -0
  21. app/models/sam_backbone.py +113 -0
  22. app/models/sam_mask_decoder.py +140 -0
  23. app/models/sam_pe.py +152 -0
  24. app/models/transformer_neck.py +158 -0
  25. ext/class_names/imagenet_21k_names.py +0 -0
  26. ext/class_names/lvis_list.py +242 -0
  27. ext/meta/sam_meta.py +41 -0
  28. ext/open_clip/__init__.py +15 -0
  29. ext/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  30. ext/open_clip/coca_model.py +458 -0
  31. ext/open_clip/constants.py +2 -0
  32. ext/open_clip/factory.py +387 -0
  33. ext/open_clip/generation_utils.py +0 -0
  34. ext/open_clip/hf_configs.py +56 -0
  35. ext/open_clip/hf_model.py +193 -0
  36. ext/open_clip/loss.py +216 -0
  37. ext/open_clip/model.py +473 -0
  38. ext/open_clip/model_configs/EVA01-g-14-plus.json +18 -0
  39. ext/open_clip/model_configs/EVA01-g-14.json +18 -0
  40. ext/open_clip/model_configs/EVA02-B-16.json +18 -0
  41. ext/open_clip/model_configs/EVA02-E-14-plus.json +18 -0
  42. ext/open_clip/model_configs/EVA02-E-14.json +18 -0
  43. ext/open_clip/model_configs/EVA02-L-14-336.json +18 -0
  44. ext/open_clip/model_configs/EVA02-L-14.json +18 -0
  45. ext/open_clip/model_configs/RN101-quickgelu.json +22 -0
  46. ext/open_clip/model_configs/RN101.json +21 -0
  47. ext/open_clip/model_configs/RN50-quickgelu.json +22 -0
  48. ext/open_clip/model_configs/RN50.json +21 -0
  49. ext/open_clip/model_configs/RN50x16.json +21 -0
  50. ext/open_clip/model_configs/RN50x4.json +21 -0
.gitattributes CHANGED
@@ -17,10 +17,6 @@
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +29,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
 
 
20
  *.rar filter=lfs diff=lfs merge=lfs -text
21
  *.safetensors filter=lfs diff=lfs merge=lfs -text
22
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ *.pickle filter=lfs diff=lfs merge=lfs -text
33
+ *.pkl filter=lfs diff=lfs merge=lfs -text
34
+ *.pt filter=lfs diff=lfs merge=lfs -text
35
+ *.pth filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,11 @@
1
  ---
2
- title: Ovsam
3
  emoji: 📚
4
  colorFrom: green
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.13.0
8
- app_file: app.py
9
  pinned: false
 
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Open-Vocabulary SAM
3
  emoji: 📚
4
  colorFrom: green
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.13.0
8
+ app_file: main.py
9
  pinned: false
10
+ python_version: "3.10"
11
  ---
 
 
app/assets/sa_01.jpg ADDED

Git LFS Details

  • SHA256: bdb5acb53dfc78e74008d113b22f5a2fb1e2c7b33cb8eadf4983d709bfe366ba
  • Pointer size: 131 Bytes
  • Size of remote file: 335 kB
app/assets/sa_224028.jpg ADDED

Git LFS Details

  • SHA256: 09236dc4305d0603ec94ae3c1e2ac89fdf992a694f39734dc33b5d91773d103f
  • Pointer size: 131 Bytes
  • Size of remote file: 611 kB
app/assets/sa_227490.jpg ADDED

Git LFS Details

  • SHA256: 36530d85ea2ad1b62b655318426842327f6493bc344f5bf69449113e47fece33
  • Pointer size: 131 Bytes
  • Size of remote file: 667 kB
app/assets/sa_228025.jpg ADDED

Git LFS Details

  • SHA256: d766b2af59c8c8a2f319af16447c9c866cba4b436eba243b910a0d106aef7268
  • Pointer size: 131 Bytes
  • Size of remote file: 621 kB
app/assets/sa_234958.jpg ADDED

Git LFS Details

  • SHA256: cdc12e95824716fe9f271d5db027f9a169cb2f44128ec6fd82f8169303980345
  • Pointer size: 131 Bytes
  • Size of remote file: 477 kB
app/assets/sa_235005.jpg ADDED

Git LFS Details

  • SHA256: 32f949ba190d4e304314c299d04fccf64c2f9985c2aaec20425b81b8953f70e7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.74 MB
app/assets/sa_235032.jpg ADDED

Git LFS Details

  • SHA256: 00ac4b97397914081793265b1b2dc33ed942bebbf6a94997f36cca3708bc8d20
  • Pointer size: 132 Bytes
  • Size of remote file: 1.55 MB
app/assets/sa_235036.jpg ADDED

Git LFS Details

  • SHA256: e93a33e3c1a254a3651296d5482c30fcc381b1ac052b5a31fce6cd7cb74d17ee
  • Pointer size: 131 Bytes
  • Size of remote file: 717 kB
app/assets/sa_235086.jpg ADDED

Git LFS Details

  • SHA256: 2c8be10dc14f2833853110c62c4d2217f4d8d3303966fb4d32b12b2231c4013a
  • Pointer size: 131 Bytes
  • Size of remote file: 488 kB
app/assets/sa_235094.jpg ADDED

Git LFS Details

  • SHA256: ef85bf49bf46045882d9b055c129e698e9f8d0d13d7068812482cad443909088
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
app/assets/sa_235113.jpg ADDED

Git LFS Details

  • SHA256: 53a92d6f0b1cb0a0178507c179f0f7ebf3260c8855113a698db01a8a09afd5c3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.4 MB
app/assets/sa_235130.jpg ADDED

Git LFS Details

  • SHA256: 786f47750bb852fc2a90e2d0a4e5f838a6c2601a278a4ff107d2c321cdf02991
  • Pointer size: 131 Bytes
  • Size of remote file: 787 kB
app/configs/sam_r50x16_fpn.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmcv.ops import RoIAlign
2
+ from mmdet.models import FPN, SingleRoIExtractor
3
+
4
+ from app.models.model import SAMSegmentor
5
+ from app.models.openclip_backbone import OpenCLIPBackbone
6
+ from app.models.ovsam_head import OVSAMHead
7
+ from app.models.sam_pe import SAMPromptEncoder
8
+ from app.models.transformer_neck import MultiLayerTransformerNeck
9
+
# Open-Vocabulary SAM configuration.
# A frozen OpenCLIP RN50x16 backbone feeds three consumers:
#   * ``neck``      - produces the image embedding given to the mask decoder,
#   * ``fpn_neck``  - produces the multi-scale features used for RoI extraction,
#   * ``mask_decoder`` (OVSAMHead) - predicts masks and open-vocabulary labels.
model = dict(
    type=SAMSegmentor,
    data_preprocessor=None,
    # Makes SAMSegmentor.extract_masks hand the backbone features and the
    # backbone module itself to the mask decoder.
    enable_backbone=True,
    backbone=dict(
        type=OpenCLIPBackbone,
        model_name='RN50x16',
        # fix=True puts the backbone in eval mode and disables gradients.
        fix=True,
        init_cfg=dict(
            # 'clip_pretrain' builds the model through
            # open_clip.create_model_from_pretrained with this checkpoint tag.
            type='clip_pretrain',
            checkpoint='openai'
        )
    ),
    neck=dict(
        type=MultiLayerTransformerNeck,
        input_size=(1024, 1024),
        # Same channel widths OpenCLIPBackbone reports for 'rn50x16'.
        in_channels=[384, 768, 1536, 3072],
        strides=[4, 8, 16, 32],
        layer_ids=(0, 1, 2, 3),
        embed_channels=1280,
        out_channels=256,
        fix=True,
        init_cfg=dict(
            type='Pretrained',
            # Relative path; resolved against the process working directory.
            checkpoint='./models/sam2clip_vith_rn50.pth',
            prefix='neck_student',
        )
    ),
    prompt_encoder=dict(
        type=SAMPromptEncoder,
        model_name='vit_h',
        fix=True,
        init_cfg=dict(
            type='sam_pretrain',
            checkpoint='vit_h'
        )
    ),
    fpn_neck=dict(
        type=FPN,
        in_channels=[384, 768, 1536, 3072],
        out_channels=256,
        num_outs=4,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='./models/R50x16_fpn_lvis_norare_v3det.pth',
            prefix='fpn_neck',
        ),
    ),
    mask_decoder=dict(
        type=OVSAMHead,
        model_name='vit_h',
        with_label_token=True,
        # Derive the RoI boxes from the predicted masks instead of expecting
        # them from the caller.
        gen_box=True,
        # OVSAMHead loads the class embeddings from './models/<name>.pth'.
        ov_classifier_name='RN50x16_LVISV1Dataset',
        roi_extractor=dict(
            type=SingleRoIExtractor,
            roi_layer=dict(type=RoIAlign, output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]
        ),
        # The head is the only component in this config that is not frozen.
        fix=False,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='./models/ovsam_R50x16_lvisnorare.pth',
            prefix='mask_decoder',
        ),
        # The RoI projection conv is loaded separately, from the FPN checkpoint.
        load_roi_conv=dict(
            checkpoint='./models/R50x16_fpn_lvis_norare_v3det.pth',
            prefix='roi_conv',
        )
    )
)
app/configs/sam_vith.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.models.last_layer import LastLayerNeck
2
+ from app.models.model import SAMSegmentor
3
+ from app.models.sam_backbone import SAMBackbone
4
+ from app.models.sam_mask_decoder import SAMMaskDecoder
5
+ from app.models.sam_pe import SAMPromptEncoder
6
+
# Plain SAM ViT-H configuration: SAM image encoder, last-layer neck, SAM prompt
# encoder and SAM mask decoder. Every learnable component is frozen
# (fix=True) and initialised with init_cfg type 'sam_pretrain'.
model = dict(
    type=SAMSegmentor,
    data_preprocessor=None,
    backbone=dict(
        type=SAMBackbone,
        model_name='vit_h',
        fix=True,
        init_cfg=dict(
            type='sam_pretrain',
            # For SAMBackbone this is a key into checkpoint_dict, not a path.
            checkpoint='vit_h'
        )
    ),
    # Parameter-free neck that returns the last backbone feature.
    neck=dict(type=LastLayerNeck),
    prompt_encoder=dict(
        type=SAMPromptEncoder,
        model_name='vit_h',
        fix=True,
        init_cfg=dict(
            type='sam_pretrain',
            checkpoint='vit_h'
        )
    ),
    mask_decoder=dict(
        type=SAMMaskDecoder,
        model_name='vit_h',
        fix=True,
        init_cfg=dict(
            type='sam_pretrain',
            checkpoint='vit_h'
        )
    )
)
app/models/last_layer.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ from mmengine.model import BaseModule
4
+ from torch import Tensor
5
+
6
+ from mmdet.registry import MODELS
7
+
8
+
@MODELS.register_module()
class LastLayerNeck(BaseModule):
    r"""Pass-through neck that keeps only the final backbone feature.

    The module owns no parameters; it selects the last element of the
    feature sequence it is given.
    """

    def __init__(self) -> None:
        # Nothing to initialise, so no init_cfg is forwarded to BaseModule.
        super().__init__(init_cfg=None)

    def forward(self, inputs: Tuple[Tensor]) -> Tensor:
        """Return the last entry of ``inputs``."""
        last_index = -1
        return inputs[last_index]
app/models/model.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ from mmengine.model import BaseModel
3
+
4
+ from mmdet.registry import MODELS
5
+ from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
6
+
7
+
@MODELS.register_module()
class SAMSegmentor(BaseModel):
    """Prompt-driven segmentor assembled from registry-built components.

    The model is used in two steps: :meth:`extract_feat` computes the image
    features once, then :meth:`extract_masks` can be called with different
    prompts against that cached result.

    Args:
        backbone: Config of the image backbone.
        neck: Config of the neck that turns backbone features into the image
            embedding consumed by the mask decoder.
        prompt_encoder: Config of the prompt encoder.
        mask_decoder: Config of the mask decoder.
        data_preprocessor: Optional data preprocessor config.
        fpn_neck: Optional config of an additional neck applied to the
            backbone features; its output is stored as ``fpn_feat``.
        init_cfg: Optional initialisation config forwarded to ``BaseModel``.
        use_clip_feat, use_head_feat, use_gt_prompt, use_point: Stored as
            attributes; not read by the methods of this class.
        enable_backbone: When True, the backbone features, the backbone module
            and the FPN features are forwarded to the mask decoder.
    """

    # Not referenced by the methods of this class.
    MASK_THRESHOLD = 0.5

    def __init__(
        self,
        backbone: ConfigType,
        neck: ConfigType,
        prompt_encoder: ConfigType,
        mask_decoder: ConfigType,
        data_preprocessor: OptConfigType = None,
        fpn_neck: OptConfigType = None,
        init_cfg: OptMultiConfig = None,
        use_clip_feat: bool = False,
        use_head_feat: bool = False,
        use_gt_prompt: bool = False,
        use_point: bool = False,
        enable_backbone: bool = False,
    ) -> None:
        super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.pe = MODELS.build(prompt_encoder)
        self.mask_decoder = MODELS.build(mask_decoder)
        if fpn_neck is not None:
            self.fpn_neck = MODELS.build(fpn_neck)
        else:
            self.fpn_neck = None

        self.use_clip_feat = use_clip_feat
        self.use_head_feat = use_head_feat
        self.use_gt_prompt = use_gt_prompt
        self.use_point = use_point

        self.enable_backbone = enable_backbone

    def extract_feat(self, inputs):
        """Run backbone and necks and return their outputs as a dict.

        Returns:
            dict with keys ``backbone_feat``, ``neck_feat`` and ``fpn_feat``
            (``None`` when no FPN neck is configured).
        """
        backbone_feat = self.backbone(inputs)
        neck_feat = self.neck(backbone_feat)
        if self.fpn_neck is not None:
            fpn_feat = self.fpn_neck(backbone_feat)
        else:
            fpn_feat = None

        return dict(
            backbone_feat=backbone_feat,
            neck_feat=neck_feat,
            fpn_feat=fpn_feat
        )

    def extract_masks(self, feat_cache, prompts):
        """Decode masks (and class scores, if available) for ``prompts``.

        Args:
            feat_cache: The dict returned by :meth:`extract_feat`.
            prompts: Prompt container; membership of ``'point_coords'`` and
                ``'bboxes'`` decides which prompt types are encoded.

        Returns:
            Tuple ``(masks, cls_pred)``. ``masks`` is a NumPy array of sigmoid
            probabilities, upsampled 4x from the decoder output. ``cls_pred``
            is a CPU tensor of softmax scores with the last class column
            dropped, or ``None`` when the decoder returns no class prediction.
        """
        sparse_embed, dense_embed = self.pe(
            prompts,
            image_size=(1024, 1024),
            with_points='point_coords' in prompts,
            with_bboxes='bboxes' in prompts,
        )

        kwargs = dict()
        if self.enable_backbone:
            kwargs['backbone_feats'] = feat_cache['backbone_feat']
            kwargs['backbone'] = self.backbone
            # NOTE(review): indentation was lost in the reviewed view; this
            # line is assumed to belong to the enable_backbone branch - confirm.
            kwargs['fpn_feats'] = feat_cache['fpn_feat']
        low_res_masks, _, cls_pred = self.mask_decoder(
            image_embeddings=feat_cache['neck_feat'],
            image_pe=self.pe.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embed,
            dense_prompt_embeddings=dense_embed,
            multi_mask_output=False,
            **kwargs
        )
        masks = F.interpolate(
            low_res_masks,
            scale_factor=4.,
            mode='bilinear',
            align_corners=False,
        )

        masks = masks.sigmoid()
        # A decoder may return no class prediction (OVSAMHead yields None
        # without a label token); pass that through instead of failing on
        # ``None.softmax``.
        if cls_pred is not None:
            # The last column is dropped; presumably it is the extra
            # background entry appended by the head - confirm.
            cls_pred = cls_pred.softmax(-1)[..., :-1].detach().cpu()
        return masks.detach().cpu().numpy(), cls_pred

    def forward(self, inputs):
        """Identity; inference goes through extract_feat / extract_masks."""
        return inputs
app/models/openclip_backbone.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ import torch.nn as nn
6
+ from mmdet.registry import MODELS
7
+
8
+ from mmengine.model import BaseModule
9
+ from mmengine.dist import get_dist_info
10
+ from mmengine.logging import MMLogger
11
+
12
+ import ext.open_clip as open_clip
13
+ from utils.load_checkpoint import load_checkpoint_with_prefix
14
+
15
+
@MODELS.register_module()
class OpenCLIPBackbone(BaseModule):
    """Visual tower of an OpenCLIP model exposed as a four-stage backbone.

    Please refer to:
    https://github.com/mlfoundations/open_clip/tree/5f7892b672b21e6853d0f6c11b18dda9bcf36c8d#pretrained-model-interface
    for the supported models and checkpoints.

    Args:
        img_size: Accepted for interface compatibility; not read by this class.
        model_name: OpenCLIP architecture name. Supported: ``convnext_*`` with
            ``_base`` / ``_large`` / ``_xxlarge``, and ``RN50``, ``RN101``
            (optionally ``-quickgelu``), ``RN50x4``, ``RN50x16``, ``RN50x64``.
        fix: When True the whole module is kept in eval mode with gradients
            disabled.
        fix_layers: Optional stage indices (0-based) to freeze when ``fix`` is
            False.
        init_cfg: Required dict with ``type`` (``'clip_pretrain'``,
            ``'image_pretrain'`` or ``'Pretrained'``) and ``checkpoint``;
            ``'Pretrained'`` additionally needs ``prefix``.

    Raises:
        AssertionError: If ``init_cfg`` is missing or its type is unsupported.
        NotImplementedError: If ``model_name`` is not a supported architecture.
    """
    STAGES = 4

    def __init__(
        self,
        img_size: int = 1024,
        model_name: str = '',
        fix: bool = True,
        fix_layers: Optional[List] = None,
        init_cfg=None,
    ):
        # Checked in two steps: a single assert whose message formats
        # init_cfg['type'] would raise TypeError when init_cfg is None.
        assert init_cfg is not None, "init_cfg is required."
        assert init_cfg['type'] in ['clip_pretrain', 'image_pretrain', 'Pretrained'], \
            f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()
        rank, world_size = get_dist_info()

        if world_size > 1:
            # Rank 0 builds the model once before every rank proceeds;
            # presumably this populates the download cache - confirm.
            if rank == 0:
                if init_cfg['type'] == 'clip_pretrain':
                    _ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
                                                               return_transform=False, logger=self.logger)
                elif init_cfg['type'] == 'image_pretrain':
                    _ = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger)
            dist.barrier()

        # Get the clip model
        if init_cfg['type'] == 'clip_pretrain':
            clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
                                                                return_transform=False, logger=self.logger)
        elif init_cfg['type'] == 'image_pretrain':
            clip_model = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger)
        elif init_cfg['type'] == 'Pretrained':
            # Architecture only; weights are loaded from the checkpoint below.
            clip_model = open_clip.create_model(model_name, pretrained_image=False, logger=self.logger)
        else:
            raise NotImplementedError

        self.out_indices = (0, 1, 2, 3)
        model_name_lower = model_name.lower()
        if 'convnext_' in model_name_lower:
            model_type = 'convnext'
            if '_base' in model_name_lower:
                output_channels = [128, 256, 512, 1024]
                feat_size = 0
            elif '_large' in model_name_lower:
                output_channels = [192, 384, 768, 1536]
                feat_size = 0
            elif '_xxlarge' in model_name_lower:
                output_channels = [384, 768, 1536, 3072]
                feat_size = 0
            else:
                raise NotImplementedError(f"{model_name} not supported yet.")
        elif 'rn' in model_name_lower:
            model_type = 'resnet'
            if model_name_lower.replace('-quickgelu', '') in ['rn50', 'rn101']:
                output_channels = [256, 512, 1024, 2048]
                feat_size = 7
            elif model_name_lower == 'rn50x4':
                output_channels = [320, 640, 1280, 2560]
                feat_size = 9
            elif model_name_lower == 'rn50x16':
                output_channels = [384, 768, 1536, 3072]
                feat_size = 12
            elif model_name_lower == 'rn50x64':
                output_channels = [512, 1024, 2048, 4096]
                feat_size = 14
            else:
                raise NotImplementedError(f"{model_name} not supported yet.")
        else:
            raise NotImplementedError(f"{model_name} not supported yet.")

        self.model_name = model_name
        self.fix = fix
        self.model_type = model_type
        self.output_channels = output_channels
        self.feat_size = feat_size

        # Get the visual model. Sub-modules are registered in a fixed order
        # (stem, avgpool, layer1..4, norm_pre, head); keep it when editing.
        if self.model_type == 'resnet':
            self.stem = nn.Sequential(*[
                clip_model.visual.conv1, clip_model.visual.bn1, clip_model.visual.act1,
                clip_model.visual.conv2, clip_model.visual.bn2, clip_model.visual.act2,
                clip_model.visual.conv3, clip_model.visual.bn3, clip_model.visual.act3,
            ])
        elif self.model_type == 'convnext':
            self.stem = clip_model.visual.trunk.stem
        else:
            raise ValueError

        if self.model_type == 'resnet':
            self.avgpool = clip_model.visual.avgpool
        elif self.model_type == 'convnext':
            self.avgpool = nn.Identity()
        else:
            raise ValueError

        # Stage names, in forward order; the modules themselves are attributes.
        self.res_layers = []
        for i in range(self.STAGES):
            if self.model_type == 'resnet':
                layer_name = f'layer{i + 1}'
                layer = getattr(clip_model.visual, layer_name)
            elif self.model_type == 'convnext':
                layer_name = f'layer{i + 1}'
                layer = clip_model.visual.trunk.stages[i]
            else:
                raise ValueError
            self.add_module(layer_name, layer)
            self.res_layers.append(layer_name)

        if self.model_type == 'resnet':
            self.norm_pre = nn.Identity()
        elif self.model_type == 'convnext':
            self.norm_pre = clip_model.visual.trunk.norm_pre

        if self.model_type == 'resnet':
            self.head = clip_model.visual.attnpool
        elif self.model_type == 'convnext':
            self.head = nn.Sequential(*[
                clip_model.visual.trunk.head,
                clip_model.visual.head,
            ])

        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)

        # Must be set before the first call to self.train(), which reads it.
        self.fix_layers = fix_layers

        if not self.fix:
            self.train()
            # norm_pre and head are always frozen, even in trainable mode.
            for name, param in self.norm_pre.named_parameters():
                param.requires_grad = False
            for name, param in self.head.named_parameters():
                param.requires_grad = False
            if self.fix_layers is not None:
                for i, layer_name in enumerate(self.res_layers):
                    if i in self.fix_layers:
                        res_layer = getattr(self, layer_name)
                        for name, param in res_layer.named_parameters():
                            param.requires_grad = False

        if self.fix:
            self.train(mode=False)
            for name, param in self.named_parameters():
                param.requires_grad = False

    def init_weights(self):
        """Log the init config; weights are already loaded in ``__init__``."""
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
        """Set the training mode while keeping frozen parts in eval mode.

        With ``fix`` the whole module stays in eval mode regardless of
        ``mode``; otherwise the stages listed in ``fix_layers`` are forced
        back to eval mode.

        Raises:
            ValueError: If ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
            if self.fix_layers is not None:
                for i, layer_name in enumerate(self.res_layers):
                    if i in self.fix_layers:
                        res_layer = getattr(self, layer_name)
                        res_layer.train(mode=False)
        return self

    def forward_func(self, x):
        """Run stem, pooling and the four stages; return the stage outputs."""
        x = self.stem(x)
        x = self.avgpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x).contiguous()
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def get_clip_feature(self, backbone_feat):
        """Return the feature to pool from (``norm_pre`` applied for ConvNeXt)."""
        if self.model_type == 'resnet':
            return backbone_feat
        elif self.model_type == 'convnext':
            return self.norm_pre(backbone_feat)
        raise NotImplementedError

    def forward_feat(self, features):
        """Apply the CLIP head to ``features``.

        ConvNeXt expects a 3-D ``(batch, num_query, channel)`` tensor, ResNet
        a 4-D tensor; the shape unpacking doubles as that rank check.
        """
        if self.model_type == 'convnext':
            batch, num_query, channel = features.shape
            features = features.reshape(batch * num_query, channel, 1, 1)
            features = self.head(features)
            return features.view(batch, num_query, features.shape[-1])
        elif self.model_type == 'resnet':
            num_query, channel, seven, seven = features.shape
            features = self.head(features)
            return features

    def forward(self, x):
        """Return the tuple of stage features; no autograd graph when fixed."""
        if self.fix:
            with torch.no_grad():
                outs = self.forward_func(x)
        else:
            outs = self.forward_func(x)
        return outs

    def get_text_model(self):
        """Build the matching text encoder from the same name and init config."""
        return OpenCLIPBackboneText(
            self.model_name,
            init_cfg=self.init_cfg
        )
234
+
235
+
@MODELS.register_module()
class OpenCLIPBackboneText(BaseModule):
    """Text tower of an OpenCLIP model.

    Args:
        model_name: OpenCLIP architecture name.
        init_cfg: Required dict with ``type='clip_pretrain'`` and
            ``checkpoint`` (the pretrained tag handed to open_clip).

    Raises:
        AssertionError: If ``init_cfg`` is missing or not ``'clip_pretrain'``.
    """

    def __init__(
        self,
        model_name: str = '',
        init_cfg=None,
    ):
        # Checked in two steps: a single assert whose message formats
        # init_cfg['type'] would raise TypeError when init_cfg is None.
        assert init_cfg is not None, "init_cfg is required."
        assert init_cfg['type'] == 'clip_pretrain', f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()
        rank, world_size = get_dist_info()

        if world_size > 1:
            # Rank 0 builds the model once before every rank proceeds;
            # presumably this populates the download cache - confirm.
            if rank == 0:
                _ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False,
                                                           logger=self.logger)
            dist.barrier()

        # Get the clip model
        clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False,
                                                            logger=self.logger)

        # Get the textual model
        self.text_tokenizer = open_clip.get_tokenizer(model_name)
        self.text_transformer = clip_model.transformer
        self.text_token_embedding = clip_model.token_embedding
        self.text_pe = clip_model.positional_embedding
        self.text_ln_final = clip_model.ln_final
        self.text_proj = clip_model.text_projection

        self.register_buffer('text_attn_mask', clip_model.attn_mask)

        self.param_dtype = torch.float32
        self.model_name = model_name

    def init_weights(self):
        """Log the init config; weights are already loaded in ``__init__``."""
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    # Copied from
    # https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L343
    @torch.no_grad()
    def forward(self, text):
        """Tokenize ``text`` and return one projected embedding per sequence."""
        text_tokens = self.text_tokenizer(text).to(device=self.text_proj.device)
        x = self.text_token_embedding(text_tokens).to(self.param_dtype)
        x = x + self.text_pe.to(self.param_dtype)
        # The transformer is fed with the first two axes swapped, then the
        # result is swapped back.
        x = x.permute(1, 0, 2)
        x = self.text_transformer(x, attn_mask=self.text_attn_mask)
        x = x.permute(1, 0, 2)
        x = self.text_ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text_tokens.argmax(dim=-1)] @ self.text_proj
        return x
app/models/ovsam_head.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ from typing import Literal, Tuple, List, Optional
4
+
5
+ import torch
6
+ from mmcv.cnn import ConvModule
7
+ from mmdet.structures.bbox import bbox2roi
8
+ from mmdet.structures.mask import mask2bbox
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ from mmengine import MMLogger
12
+ from mmengine.model import BaseModule
13
+ from mmdet.registry import MODELS
14
+
15
+ from ext.sam import MaskDecoder
16
+ from ext.sam.mask_decoder import MLP as SAMMLP
17
+ from ext.meta.sam_meta import meta_dict, checkpoint_dict
18
+ from utils.load_checkpoint import load_checkpoint_with_prefix
19
+
20
+
@MODELS.register_module()
class OVSAMHead(BaseModule):
    """SAM mask decoder extended with an open-vocabulary classification branch.

    Args:
        model_name: SAM variant; selects ``prompt_embed_dim`` from ``meta_dict``.
        with_label_token: Enable the label token, the class-embedding buffer
            and the classification output.
        ov_classifier_name: Basename of the class-embedding file, loaded from
            ``./models/<name>.pth``.
        logit: Log of the logit scale; defaults to 4.6052 (``exp`` ~= 100).
        roi_extractor: Config of the RoI extractor applied to the FPN features.
        fix: When True the head is kept in eval mode with gradients disabled.
        init_cfg: Required dict; only ``type='Pretrained'`` (with
            ``checkpoint`` and ``prefix``) is implemented.
        cur_mask: Number of mask tokens; ``cur_mask - 1`` multimask outputs.
        roi_extractor_single: Optional config of a second RoI extractor applied
            to the last backbone feature.
        load_roi_conv: Optional dict (``checkpoint``, ``prefix``) from which
            the ``roi_conv`` weights are loaded.
        gen_box: Derive the RoI boxes from the predicted masks.

    Raises:
        AssertionError: If ``init_cfg`` is missing or of an unsupported type.
        NotImplementedError: If ``init_cfg['type']`` is ``'sam_pretrain'``.
        ValueError: If ``roi_extractor`` is given without ``with_label_token``.
    """

    def __init__(
        self,
        model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
        with_label_token: bool = False,
        ov_classifier_name: Optional[str] = None,
        logit: Optional[float] = None,
        roi_extractor=None,
        fix: bool = True,
        init_cfg=None,
        cur_mask=1,
        roi_extractor_single=None,
        load_roi_conv=None,
        gen_box=False,
    ):
        # Checked in two steps: a single assert whose message formats
        # init_cfg['type'] would raise TypeError when init_cfg is None.
        assert init_cfg is not None, "init_cfg is required."
        assert init_cfg['type'] in ['sam_pretrain', 'Pretrained'], f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()
        if roi_extractor_single is not None:
            self.roi_extractor_single = MODELS.build(roi_extractor_single)
            self.roi_merge_proj = nn.Linear(768 * 2, 768)
        else:
            self.roi_extractor_single = None
            self.roi_merge_proj = None

        mask_decoder = MaskDecoder(
            num_multimask_outputs=cur_mask - 1,
            transformer_dim=meta_dict[model_name]['prompt_embed_dim'],
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            with_iou=False
        )

        if self.init_cfg['type'] == 'sam_pretrain':
            raise NotImplementedError

        self.mask_decoder = mask_decoder

        self.with_label_token = with_label_token

        # NOTE(review): indentation was lost in the reviewed view; everything
        # down to label_mlp is assumed to belong to this branch - confirm.
        if self.with_label_token:
            ov_path = os.path.join(os.path.expanduser('./models/'), f"{ov_classifier_name}.pth")
            # Load onto the CPU: the tensor is concatenated with a CPU tensor
            # below, and the file may have been saved from another device.
            cls_embed = torch.load(ov_path, map_location='cpu')
            cls_embed_norm = cls_embed.norm(p=2, dim=-1)
            assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm))

            _dim = cls_embed.size(2)
            _prototypes = cls_embed.size(1)
            # Append one all-zero entry along dim 0 (an extra, last class).
            back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cpu')
            cls_embed = torch.cat([
                cls_embed, back_token.repeat(_prototypes, 1)[None]
            ], dim=0)
            # Stored as (dim, classes + 1, prototypes) to match forward_logit.
            self.register_buffer('cls_embed', cls_embed.permute(2, 0, 1).contiguous(), persistent=False)

            if logit is None:
                logit_scale = torch.tensor(4.6052, dtype=torch.float32)
            else:
                logit_scale = torch.tensor(logit, dtype=torch.float32)
            self.register_buffer('logit_scale', logit_scale, persistent=False)

            transformer_dim = self.mask_decoder.mask_tokens.weight.shape[1]
            self.label_token = nn.Embedding(1, transformer_dim)
            self.label_mlp = SAMMLP(transformer_dim, transformer_dim, _dim, 3)

        self.gen_box = gen_box

        if roi_extractor is not None:
            if not self.with_label_token:
                # The projection width comes from the class embeddings, which
                # are only loaded with a label token (this was a NameError).
                raise ValueError("roi_extractor requires with_label_token=True.")
            self.roi = MODELS.build(roi_extractor)
            self.roi_conv = nn.Sequential(*[
                ConvModule(in_channels=self.roi.out_channels, out_channels=_dim, kernel_size=1, bias=False)
            ])
        else:
            self.roi = None

        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)
            # Loaded after the main checkpoint, so these weights take precedence.
            if roi_extractor is not None and load_roi_conv is not None:
                checkpoint_path = load_roi_conv['checkpoint']
                state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=load_roi_conv['prefix'])
                self.roi_conv.load_state_dict(state_dict, strict=True)

        self.fix = fix

        if self.fix:
            self.train(mode=False)
            for name, param in self.named_parameters():
                param.requires_grad = False

    def init_weights(self):
        """Log the init config; weights are already loaded in ``__init__``."""
        self.logger.info(f"Init Config for {self.__class__.__name__}")
        self.logger.info(self.init_cfg)

    def forward_logit(self, cls_embd):
        """Score embeddings against the class prototypes.

        ``cls_embd`` is L2-normalised, compared with every prototype of every
        class, reduced with a max over prototypes and multiplied by
        ``exp(logit_scale)``.
        """
        cls_pred = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed)
        cls_pred = cls_pred.max(-1).values
        cls_pred = self.logit_scale.exp() * cls_pred
        return cls_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        fpn_feats: List[torch.Tensor],
        roi_list: Optional[List[torch.Tensor]],
        backbone_feature: torch.Tensor,
        backbone=None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details.

        Returns:
            ``(masks, None, cls_pred)``; ``cls_pred`` is ``None`` without a
            label token.
        """
        num_instances = int(sparse_prompt_embeddings.size(0))
        # Concatenate output tokens: the label token first, then mask tokens.
        output_tokens = torch.cat([
            self.label_token.weight,
            self.mask_decoder.mask_tokens.weight], dim=0
        )
        output_tokens = output_tokens.unsqueeze(0).expand(num_instances, -1, -1)
        queries = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # image_embeddings is not repeated here; forward() already repeats it
        # once per prompt.
        image_embeddings = image_embeddings + dense_prompt_embeddings
        pos_img = torch.repeat_interleave(image_pe, num_instances, dim=0)
        b, c, h, w = image_embeddings.shape

        # Run the transformer
        queries, mask_feats = self.mask_decoder.transformer(image_embeddings, pos_img, queries)
        label_query = queries[:, 0, :]
        mask_embeds = queries[:, 1:(1 + self.mask_decoder.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        mask_feats = mask_feats.transpose(1, 2).view(b, c, h, w)
        mask_feats = self.mask_decoder.output_upscaling(mask_feats)
        mask_queries_list: List[torch.Tensor] = [
            self.mask_decoder.output_hypernetworks_mlps[i](mask_embeds[:, i, :])
            for i in range(self.mask_decoder.num_mask_tokens)
        ]
        mask_queries = torch.stack(mask_queries_list, dim=1)
        b, c, h, w = mask_feats.shape
        masks = (mask_queries @ mask_feats.view(b, c, h * w)).view(b, -1, h, w)

        # Generate class labels
        if self.with_label_token:
            assert self.mask_decoder.num_mask_tokens == 1
            # One mask token only, so the embedding just gains a token axis.
            cls_embed = self.label_mlp(label_query).unsqueeze(1)
            if self.gen_box:
                # Boxes come from the thresholded first mask, scaled by 4;
                # presumably to input-image pixels - confirm.
                bboxes = mask2bbox(masks.sigmoid()[:, 0] > 0.5) * 4
                roi_list = bbox2roi([bboxes])
            # NOTE(review): indentation was lost in the reviewed view; the RoI
            # steps are assumed to run whether or not gen_box is set - confirm.
            roi_feats = self.roi(fpn_feats, roi_list)
            roi_feats = self.roi_conv(roi_feats)
            roi_feats = roi_feats.mean(dim=-1).mean(dim=-1)
            if self.roi_extractor_single is not None:
                roi_feats_clip = self.roi_extractor_single(
                    backbone.get_clip_feature(backbone_feature[-1:]), roi_list
                )
                roi_feats_clip = backbone.forward_feat(roi_feats_clip)
                roi_feats = self.roi_merge_proj(torch.cat([roi_feats, roi_feats_clip], dim=-1))
            # The label-token embedding is multiplied by zero, so the scores
            # depend on the RoI features only. NOTE(review): presumably this
            # keeps label_mlp in the autograd graph - confirm.
            roi_feats = roi_feats[:, None] + 0 * cls_embed
            cls_pred = self.forward_logit(roi_feats)
        else:
            cls_pred = None
        return masks, None, cls_pred

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multi_mask_output: bool,
        data_samples=None,
        fpn_feats=None,
        backbone_feats=None,
        backbone=None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """Predict masks and class scores for every prompt.

        The image embedding is repeated once per prompt before decoding.
        ``data_samples`` is accepted but not used.

        Returns:
            ``(masks, None, cls_pred)``. With ``multi_mask_output`` the masks
            from index 1 onward are returned, otherwise only mask 0.
        """
        num_prompts = len(sparse_prompt_embeddings)
        image_embeddings = torch.repeat_interleave(image_embeddings, num_prompts, dim=0)

        masks, _, cls_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            fpn_feats=fpn_feats,
            roi_list=None,
            backbone_feature=backbone_feats,
            backbone=backbone,
        )

        # Select the correct mask or masks for output
        if multi_mask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]

        # Prepare output
        return masks, None, cls_pred
app/models/sam_backbone.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Literal
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from mmdet.registry import MODELS
7
+
8
+ from mmengine.model import BaseModule
9
+ from mmengine.logging import MMLogger
10
+
11
+ from ext.sam import ImageEncoderViT
12
+ from ext.meta.sam_meta import meta_dict, checkpoint_dict
13
+ from utils.load_checkpoint import load_checkpoint_with_prefix
14
+
15
+
16
@MODELS.register_module()
class SAMBackbone(BaseModule):
    """SAM ``ImageEncoderViT`` regrouped into sequential stages.

    The encoder's transformer blocks are split into stages named ``layer1``,
    ``layer2``, ... - one stage per entry of the selected model's
    ``encoder_global_attn_indexes``. Each stage ends at (and includes) the
    block at that index.

    Args:
        model_name: Key into ``meta_dict`` selecting the encoder hyper-parameters.
        fix: If True the module stays in eval mode, every parameter gets
            ``requires_grad = False`` and ``forward`` runs under
            ``torch.no_grad()``.
        init_cfg: Required dict with keys ``type`` and ``checkpoint``.
            With ``type='sam_pretrain'``, ``checkpoint`` is looked up in
            ``checkpoint_dict`` and the ``image_encoder`` weights are loaded
            into the raw encoder before it is regrouped. With
            ``type='Pretrained'``, ``checkpoint`` is passed as-is to
            ``load_checkpoint_with_prefix`` together with ``init_cfg['prefix']``
            and the result is loaded into this (regrouped) module.
    """

    def __init__(
            self,
            model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
            fix: bool = True,
            init_cfg=None,
    ):
        # Validate in two steps: a single combined assert would evaluate
        # init_cfg['type'] while building its message and crash with a
        # TypeError when init_cfg is None.
        assert init_cfg is not None, "init_cfg is required."
        assert init_cfg['type'] in ['sam_pretrain', 'Pretrained'], \
            f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        # The base class gets init_cfg=None; weights are loaded here instead.
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()

        backbone_meta = meta_dict[model_name]

        backbone = ImageEncoderViT(
            depth=backbone_meta['encoder_depth'],
            embed_dim=backbone_meta['encoder_embed_dim'],
            num_heads=backbone_meta['encoder_num_heads'],
            patch_size=backbone_meta['vit_patch_size'],
            img_size=backbone_meta['image_size'],
            global_attn_indexes=backbone_meta['encoder_global_attn_indexes'],
            out_chans=backbone_meta['prompt_embed_dim'],
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            qkv_bias=True,
            use_rel_pos=True,
            mlp_ratio=4,
            window_size=14,
        )
        if self.init_cfg['type'] == 'sam_pretrain':
            # Load into the original encoder layout, before regrouping.
            checkpoint_path = checkpoint_dict[pretrained]
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix='image_encoder')
            backbone.load_state_dict(state_dict, strict=True)

        self.stem = backbone.patch_embed
        self.pos_embed = backbone.pos_embed

        # Regroup the blocks into stages. Blocks located after the last listed
        # index are not assigned to any stage.
        self.res_layers = []
        last_pos = 0
        for idx, cur_pos in enumerate(backbone_meta['encoder_global_attn_indexes']):
            blocks = backbone.blocks[last_pos:cur_pos + 1]
            layer_name = f'layer{idx + 1}'
            self.add_module(layer_name, nn.Sequential(*blocks))
            self.res_layers.append(layer_name)
            last_pos = cur_pos + 1

        self.out_proj = backbone.neck

        if self.init_cfg['type'] == 'Pretrained':
            # Load into the regrouped layout (stem / layerN / out_proj names).
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)

        self.model_name = model_name
        self.fix = fix
        self.model_type = 'vit'
        self.output_channels = None
        self.out_indices = (0, 1, 2, 3)
        if self.fix:
            self.train(mode=False)
            for name, param in self.named_parameters():
                param.requires_grad = False

    def init_weights(self):
        """Log the init config; weights were already loaded in ``__init__``."""
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
        """Set the training mode, forcing eval mode while ``self.fix`` is True.

        Raises:
            ValueError: If ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
        return self

    def forward_func(self, x):
        """Run stem, positional embedding and all stages.

        Returns:
            Tuple with one feature map per stage whose index is in
            ``self.out_indices``. Each map is permuted with ``(0, 3, 1, 2)``
            (presumably channels-last -> channels-first; confirm against
            ``ImageEncoderViT``), and ``out_proj`` is applied to the last one.
        """
        x = self.stem(x)
        x = x + self.pos_embed
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x.permute(0, 3, 1, 2).contiguous())
        # Only the final collected feature goes through the encoder neck.
        outs[-1] = self.out_proj(outs[-1])
        return tuple(outs)

    def forward(self, x):
        """Compute multi-stage features, without autograd when ``fix`` is set."""
        if self.fix:
            with torch.no_grad():
                outs = self.forward_func(x)
        else:
            outs = self.forward_func(x)
        return outs
app/models/sam_mask_decoder.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Tuple, List
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from mmdet.structures import SampleList
6
+ from mmengine import MMLogger
7
+ from mmengine.model import BaseModule
8
+ from mmdet.registry import MODELS
9
+
10
+ from ext.sam import MaskDecoder
11
+ from ext.meta.sam_meta import meta_dict, checkpoint_dict
12
+ from utils.load_checkpoint import load_checkpoint_with_prefix
13
+
14
+
15
@MODELS.register_module()
class SAMMaskDecoder(BaseModule):
    """Wrapper around SAM's ``MaskDecoder`` with checkpoint loading and freezing.

    Args:
        model_name: Key into ``meta_dict``; its ``prompt_embed_dim`` is used as
            the decoder's ``transformer_dim``.
        fix: If True the module stays in eval mode and every parameter gets
            ``requires_grad = False``.
        init_cfg: Required dict with keys ``type`` and ``checkpoint``.
            With ``type='sam_pretrain'``, ``checkpoint`` is looked up in
            ``checkpoint_dict`` and the ``mask_decoder`` weights are loaded
            into the inner decoder. With ``type='Pretrained'``, ``checkpoint``
            is passed as-is to ``load_checkpoint_with_prefix`` together with
            ``init_cfg['prefix']`` and the result is loaded into this module.
    """

    def __init__(
            self,
            model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
            fix: bool = True,
            init_cfg=None,
    ):
        # Validate in two steps: a single combined assert would evaluate
        # init_cfg['type'] while building its message and crash with a
        # TypeError when init_cfg is None.
        assert init_cfg is not None, "init_cfg is required."
        assert init_cfg['type'] in ['sam_pretrain', 'Pretrained'], \
            f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        # The base class gets init_cfg=None; weights are loaded here instead.
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()

        mask_decoder = MaskDecoder(
            num_multimask_outputs=3,
            transformer_dim=meta_dict[model_name]['prompt_embed_dim'],
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        )
        if self.init_cfg['type'] == 'sam_pretrain':
            checkpoint_path = checkpoint_dict[pretrained]
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix='mask_decoder')
            mask_decoder.load_state_dict(state_dict, strict=True)

        self.mask_decoder = mask_decoder
        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)

        self.fix = fix
        if self.fix:
            self.train(mode=False)
            for name, param in self.named_parameters():
                param.requires_grad = False

    def init_weights(self):
        """Log the init config; weights were already loaded in ``__init__``."""
        self.logger.info(f"Init Config for {self.__class__.__name__}")
        self.logger.info(self.init_cfg)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
        """Set the training mode, forcing eval mode while ``self.fix`` is True.

        Without this override a ``.train()`` call on a parent model would put
        a frozen decoder back into training mode.

        Raises:
            ValueError: If ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
        return self

    def forward_logit(self, cls_embd):
        """Score normalised embeddings against ``self.cls_embed``.

        NOTE(review): this class never assigns ``cls_embed`` or
        ``logit_scale``; unless a subclass provides them, calling this method
        raises ``AttributeError``.
        """
        cls_pred = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed)
        # Keep the best score over the last ('p') axis.
        cls_pred = cls_pred.max(-1).values
        cls_pred = self.logit_scale.exp() * cls_pred
        return cls_pred

    def predict_masks(
            self,
            image_embeddings: torch.Tensor,
            image_pe: torch.Tensor,
            sparse_prompt_embeddings: torch.Tensor,
            dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details.

        ``image_embeddings`` must already be expanded to one entry per prompt
        (``forward`` does this); it is unpacked as a 4-D ``(b, c, h, w)``
        tensor.

        Returns:
            ``(masks, iou_pred, None)``. The third slot is always ``None``
            despite the annotation.
        """
        num_instances = int(sparse_prompt_embeddings.shape[0])
        # Concatenate output tokens: [iou token, mask tokens], shared by all
        # instances, followed by each instance's sparse prompt tokens.
        output_tokens = torch.cat([self.mask_decoder.iou_token.weight, self.mask_decoder.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(num_instances, -1, -1)
        queries = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        image_embeddings = image_embeddings + dense_prompt_embeddings
        pos_img = torch.repeat_interleave(image_pe, num_instances, dim=0)
        b, c, h, w = image_embeddings.shape

        # Run the transformer
        queries, mask_feats = self.mask_decoder.transformer(image_embeddings, pos_img, queries)
        iou_query = queries[:, 0, :]
        mask_embeds = queries[:, 1:(1 + self.mask_decoder.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        mask_feats = mask_feats.transpose(1, 2).view(b, c, h, w)
        mask_feats = self.mask_decoder.output_upscaling(mask_feats)
        mask_queries_list: List[torch.Tensor] = []
        for i in range(self.mask_decoder.num_mask_tokens):
            mask_queries_list.append(self.mask_decoder.output_hypernetworks_mlps[i](mask_embeds[:, i, :]))
        mask_queries = torch.stack(mask_queries_list, dim=1)
        # Shapes changed after upscaling; re-read them for the matmul below.
        b, c, h, w = mask_feats.shape
        masks = (mask_queries @ mask_feats.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.mask_decoder.iou_prediction_head(iou_query)

        return masks, iou_pred, None

    def forward(
            self,
            image_embeddings: torch.Tensor,
            image_pe: torch.Tensor,
            sparse_prompt_embeddings: torch.Tensor,
            dense_prompt_embeddings: torch.Tensor,
            multi_mask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Predict masks and IoU scores for a set of prompts.

        Args:
            image_embeddings: Image features; repeated
                ``len(sparse_prompt_embeddings)`` times along dim 0.
            image_pe: Positional encoding for the image features.
            sparse_prompt_embeddings: Sparse prompt embeddings; its length is
                the number of prompts.
            dense_prompt_embeddings: Dense prompt embeddings added to the
                image features.
            multi_mask_output: If True keep outputs ``1:``; otherwise keep
                only output ``0``.

        Returns:
            ``(masks, iou_pred, cls_pred)``; ``cls_pred`` is whatever
            ``predict_masks`` returns in its third slot (``None`` here).
        """
        num_prompts = len(sparse_prompt_embeddings)
        image_embeddings = torch.repeat_interleave(image_embeddings, num_prompts, dim=0)
        masks, iou_pred, cls_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multi_mask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred, cls_pred

    def forward_train(
            self,
            image_embeddings: torch.Tensor,
            image_pe: torch.Tensor,
            sparse_prompt_embeddings: torch.Tensor,
            dense_prompt_embeddings: torch.Tensor,
            batch_ind_list: List[int],
            data_samples: SampleList,
    ):
        """Training entry point; not implemented for this decoder.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
app/models/sam_pe.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Literal
2
+
3
+ import torch
4
+ from mmengine import MMLogger
5
+
6
+ from mmdet.registry import MODELS
7
+ from mmengine.model import BaseModule
8
+ from mmengine.structures import InstanceData
9
+
10
+ from ext.sam import PromptEncoder
11
+ from ext.meta.sam_meta import meta_dict, checkpoint_dict
12
+ from utils.load_checkpoint import load_checkpoint_with_prefix
13
+
14
+
15
@MODELS.register_module()
class SAMPromptEncoder(BaseModule):
    """SAM ``PromptEncoder`` unpacked into this module's own sub-modules.

    A ``PromptEncoder`` is built, loaded from the SAM checkpoint and its
    positional-encoding, mask and point layers are then attached directly to
    this module.

    Args:
        model_name: Key into ``meta_dict`` providing ``image_embedding_size``
            and ``image_size``.
        fix: If True the module stays in eval mode and every parameter gets
            ``requires_grad = False``.
        init_cfg: Required dict with ``type == 'sam_pretrain'`` and a
            ``checkpoint`` key that is looked up in ``checkpoint_dict``.
    """

    def __init__(
            self,
            model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
            fix: bool = True,
            init_cfg=None,
    ):
        # Validate in two steps: a single combined assert would evaluate
        # init_cfg['type'] while building its message and crash with a
        # TypeError when init_cfg is None.
        assert init_cfg is not None, "init_cfg is required."
        assert init_cfg['type'] == 'sam_pretrain', f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        # The base class gets init_cfg=None; weights are loaded here instead.
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()

        backbone_meta = meta_dict[model_name]
        checkpoint_path = checkpoint_dict[pretrained]

        prompt_encoder = PromptEncoder(
            embed_dim=256,
            image_embedding_size=(backbone_meta['image_embedding_size'], backbone_meta['image_embedding_size']),
            input_image_size=(backbone_meta['image_size'], backbone_meta['image_size']),
            mask_in_chans=16,
        )
        state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix='prompt_encoder')
        prompt_encoder.load_state_dict(state_dict, strict=True)

        # meta
        self.embed_dim = prompt_encoder.embed_dim
        self.input_image_size = prompt_encoder.input_image_size
        self.image_embedding_size = prompt_encoder.image_embedding_size
        self.num_point_embeddings = 4
        self.mask_input_size = prompt_encoder.mask_input_size

        # positional encoding
        self.pe_layer = prompt_encoder.pe_layer

        # mask encoding
        self.mask_downscaling = prompt_encoder.mask_downscaling
        self.no_mask_embed = prompt_encoder.no_mask_embed

        # point encoding
        self.point_embeddings = prompt_encoder.point_embeddings
        self.not_a_point_embed = prompt_encoder.not_a_point_embed

        self.fix = fix
        if self.fix:
            self.train(mode=False)
            for name, param in self.named_parameters():
                param.requires_grad = False

    @property
    def device(self):
        """Device of the ``no_mask_embed`` weight, used for new tensors."""
        return self.no_mask_embed.weight.device

    def init_weights(self):
        """Log the init config; weights were already loaded in ``__init__``."""
        self.logger.info(f"Init Config for {self.__class__.__name__}")
        self.logger.info(self.init_cfg)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
        """Set the training mode, forcing eval mode while ``self.fix`` is True.

        Raises:
            ValueError: If ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
        return self

    def _embed_boxes(self, bboxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """Embeds box prompts.

        Each box is reshaped into two corner points; the first corner gets
        ``point_embeddings[2]`` added and the second ``point_embeddings[3]``.
        """
        bboxes = bboxes + 0.5  # Shift to center of pixel
        coords = bboxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(coords, image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def get_dense_pe(self) -> torch.Tensor:
        """Return the positional encoding for the image embedding grid,
        with a leading dimension of size 1 added."""
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
            self,
            points: torch.Tensor,
            labels: torch.Tensor,
            pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts.

        Labels select the learned embedding added to each point's positional
        encoding: ``-1`` -> ``not_a_point_embed`` (positional part zeroed),
        ``0`` -> ``point_embeddings[0]``, ``1`` -> ``point_embeddings[1]``.
        When ``pad`` is True one extra ``-1``-labelled point is appended per
        instance.
        """
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def forward(
            self,
            instances: InstanceData,
            image_size: Tuple[int, int],
            with_points: bool = False,
            with_bboxes: bool = False,
            with_masks: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode the prompts stored on ``instances``.

        Args:
            instances: Prompt container; must hold ``point_coords``,
                ``bboxes`` and/or ``masks`` for the enabled prompt kinds.
            image_size: Image size used to normalise box coordinates.
                NOTE(review): points are encoded against
                ``self.input_image_size`` instead - confirm this is intended.
            with_points: Encode ``instances.point_coords`` (every point is
                given label 1).
            with_bboxes: Encode ``instances.bboxes``.
            with_masks: Encode ``instances.masks.masks``; otherwise the
                ``no_mask_embed`` weight is broadcast over the embedding grid.

        Returns:
            ``(sparse_embeddings, dense_embeddings)``.
        """
        assert with_points or with_bboxes or with_masks
        bs = len(instances)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self.device)
        if with_points:
            assert 'point_coords' in instances
            coords = instances.point_coords
            labels = torch.ones_like(coords)[:, :, 0]
            # Pad with a "not a point" token only when no box is supplied.
            point_embeddings = self._embed_points(coords, labels, pad=not with_bboxes)
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)

        if with_bboxes:
            assert 'bboxes' in instances
            box_embeddings = self._embed_boxes(
                instances.bboxes, image_size=image_size
            )
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if with_masks:
            assert 'masks' in instances
            dense_embeddings = self._embed_masks(instances.masks.masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )
        return sparse_embeddings, dense_embeddings
app/models/transformer_neck.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Tuple, List, Optional
3
+
4
+ import torch
5
+ from torch import Tensor, nn
6
+
7
+ from mmengine.model import BaseModule, normal_init
8
+ from mmdet.registry import MODELS
9
+ from mmdet.models.layers import PatchEmbed
10
+
11
+ from ext.meta.sam_meta import checkpoint_dict
12
+ from ext.sam.common import LayerNorm2d
13
+ from ext.sam.image_encoder import Block
14
+
15
+ from utils.load_checkpoint import load_checkpoint_with_prefix
16
+
17
+
18
@MODELS.register_module()
class MultiLayerTransformerNeck(BaseModule):
    """Fuse multi-level features at stride ``STRIDE`` with transformer blocks.

    Each selected input level is brought to the common ``STRIDE`` grid by a
    ``PatchEmbed``, tagged with a learned per-level encoding, summed with the
    other levels, offset by ``pos_embed``, passed through five ``Block``
    layers and projected by a conv/LayerNorm2d neck.

    Args:
        input_size: Input image size ``(H, W)``; the transformer grid is
            ``(H // STRIDE, W // STRIDE)``.
        in_channels: Channel count of every input level.
        embed_channels: Transformer embedding width.
        out_channels: Channel count of the returned feature map.
        layer_ids: Indices of the input levels that are used.
        strides: Stride of every input level.
        embedding_path: Optional ``'sam_<key>'`` string; ``<key>`` is looked up
            in ``checkpoint_dict`` and that checkpoint's image-encoder
            ``pos_embed`` is used. Otherwise a zero buffer is created (to be
            filled when a checkpoint is loaded).
        fix: If True the module stays in eval mode and every parameter gets
            ``requires_grad = False``.
        init_cfg: Optional dict with ``type == 'Pretrained'``, ``checkpoint``
            and ``prefix``; when given, weights are loaded in ``__init__``.
    """
    STRIDE = 16

    def __init__(
            self,
            input_size: Tuple[int, int],
            in_channels: List[int],
            embed_channels: int,
            out_channels: int,
            layer_ids: Tuple[int, ...] = (0, 1, 2, 3),
            strides: Tuple[int, ...] = (4, 8, 16, 32),
            embedding_path: Optional[str] = None,
            fix=False,
            init_cfg=None
    ) -> None:
        # The base class gets init_cfg=None; weights are loaded here instead.
        super().__init__(init_cfg=None)

        self.transformer_size = (input_size[0] // self.STRIDE, input_size[1] // self.STRIDE)
        self.layer_ids = layer_ids

        self.patch_embeds = nn.ModuleList()
        for idx, in_ch in enumerate(in_channels):
            if idx in layer_ids:
                if strides[idx] > self.STRIDE:
                    # Coarser than the target grid: transposed conv.
                    patch_embed = PatchEmbed(
                        conv_type=nn.ConvTranspose2d,
                        in_channels=in_ch,
                        embed_dims=embed_channels,
                        kernel_size=strides[idx] // self.STRIDE,
                        stride=strides[idx] // self.STRIDE,
                        input_size=(input_size[0] // strides[idx], input_size[1] // strides[idx])
                    )
                else:
                    # Same or finer than the target grid: PatchEmbed's default
                    # conv type with kernel == stride.
                    patch_embed = PatchEmbed(
                        in_channels=in_ch,
                        embed_dims=embed_channels,
                        kernel_size=self.STRIDE // strides[idx],
                        stride=self.STRIDE // strides[idx],
                        input_size=(input_size[0] // strides[idx], input_size[1] // strides[idx])
                    )
                self.patch_embeds.append(patch_embed)
            else:
                # Placeholder so patch_embeds stays index-aligned with inputs.
                self.patch_embeds.append(nn.Identity())

        if embedding_path is not None:
            assert embedding_path.startswith('sam_')
            embedding_ckpt = embedding_path.split('_', maxsplit=1)[1]
            path = checkpoint_dict[embedding_ckpt]
            state_dict = load_checkpoint_with_prefix(path, prefix='image_encoder')
            pos_embed = state_dict['pos_embed']
        else:
            # For loading from checkpoint
            pos_embed = torch.zeros(1, input_size[0] // self.STRIDE, input_size[1] // self.STRIDE, embed_channels)

        self.register_buffer('pos_embed', pos_embed)

        self.level_encoding = nn.Embedding(len(layer_ids), embed_channels)

        depth = 5
        global_attn_indexes = [4]
        window_size = 14

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_channels,
                num_heads=16,
                mlp_ratio=4,
                qkv_bias=True,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                use_rel_pos=True,
                rel_pos_zero_init=True,
                # Blocks listed in global_attn_indexes get window_size=0.
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=self.transformer_size,
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_channels),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_channels),
        )

        self.fix = fix
        if self.fix:
            self.train(mode=False)
            for name, param in self.named_parameters():
                param.requires_grad = False

        if init_cfg is not None:
            assert init_cfg['type'] == 'Pretrained'
            checkpoint_path = init_cfg['checkpoint']
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)
            # Mark as initialised so init_weights() leaves the loaded
            # weights alone.
            self._is_init = True

    def init_weights(self):
        """Randomly initialise ``level_encoding`` (normal, mean 0, std 1).

        Does nothing when pretrained weights were already loaded in
        ``__init__`` (``_is_init`` is True); re-initialising would overwrite
        the loaded ``level_encoding``.
        """
        if getattr(self, '_is_init', False):
            return
        normal_init(self.level_encoding, mean=0, std=1)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
        """Set the training mode, forcing eval mode while ``self.fix`` is True.

        Raises:
            ValueError: If ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
        return self

    def forward(self, inputs: Tuple[Tensor]) -> Tensor:
        """Fuse the selected input levels into a single feature map.

        Args:
            inputs: One feature map per level, index-aligned with the
                ``in_channels`` / ``strides`` given at construction.

        Returns:
            The fused feature map after the neck projection.
        """
        input_embeddings = []
        level_cnt = 0
        for idx, feat in enumerate(inputs):
            if idx not in self.layer_ids:
                continue
            feat, size = self.patch_embeds[idx](feat)
            # Restore the spatial grid from the flattened token axis
            # (presumably giving (B, H, W, C); confirm against PatchEmbed).
            feat = feat.unflatten(1, size)
            # level_cnt counts used levels, matching level_encoding's rows.
            feat = feat + self.level_encoding.weight[level_cnt]
            input_embeddings.append(feat)
            level_cnt += 1

        feat = sum(input_embeddings)
        feat = feat + self.pos_embed
        for block in self.blocks:
            feat = block(feat)
        feat = feat.permute(0, 3, 1, 2).contiguous()
        feat = self.neck(feat)
        return feat
ext/class_names/imagenet_21k_names.py ADDED
The diff for this file is too large to render. See raw diff
 
ext/class_names/lvis_list.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LVIS_CLASSES = ('aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
2
+ 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
3
+ 'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
4
+ 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor',
5
+ 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer',
6
+ 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy',
7
+ 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel',
8
+ 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon',
9
+ 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo',
10
+ 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow',
11
+ 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap',
12
+ 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)',
13
+ 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)',
14
+ 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie',
15
+ 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper',
16
+ 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt',
17
+ 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor',
18
+ 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath',
19
+ 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card',
20
+ 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket',
21
+ 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry',
22
+ 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg',
23
+ 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase',
24
+ 'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle',
25
+ 'bottle_opener', 'bouquet', 'bow_(weapon)',
26
+ 'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl',
27
+ 'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders',
28
+ 'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread',
29
+ 'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach',
30
+ 'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket',
31
+ 'horse_buggy', 'bull', 'bulldog', 'bulldozer', 'bullet_train',
32
+ 'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed',
33
+ 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter',
34
+ 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet',
35
+ 'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder',
36
+ 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can',
37
+ 'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane',
38
+ 'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen',
39
+ 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino',
40
+ 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car',
41
+ 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship',
42
+ 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton',
43
+ 'cash_register', 'casserole', 'cassette', 'cast', 'cat',
44
+ 'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery',
45
+ 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue',
46
+ 'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard',
47
+ 'cherry', 'chessboard', 'chicken_(animal)', 'chickpea',
48
+ 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)',
49
+ 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk',
50
+ 'chocolate_mousse', 'choker', 'chopping_board', 'chopstick',
51
+ 'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette',
52
+ 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent',
53
+ 'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard',
54
+ 'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower',
55
+ 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat',
56
+ 'coat_hanger', 'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)',
57
+ 'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil',
58
+ 'coin', 'colander', 'coleslaw', 'coloring_material',
59
+ 'combination_lock', 'pacifier', 'comic_book', 'compass',
60
+ 'computer_keyboard', 'condiment', 'cone', 'control',
61
+ 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie',
62
+ 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)',
63
+ 'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet',
64
+ 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall',
65
+ 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker',
66
+ 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib',
67
+ 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown',
68
+ 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch',
69
+ 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup',
70
+ 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain',
71
+ 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard',
72
+ 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
73
+ 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table',
74
+ 'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
75
+ 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup',
76
+ 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin',
77
+ 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove',
78
+ 'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat',
79
+ 'dress_suit', 'dresser', 'drill', 'drone', 'dropper',
80
+ 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
81
+ 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle',
82
+ 'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg',
83
+ 'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair',
84
+ 'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot',
85
+ 'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret',
86
+ 'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine',
87
+ 'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine',
88
+ 'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug',
89
+ 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod',
90
+ 'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash',
91
+ 'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)',
92
+ 'flower_arrangement', 'flute_glass', 'foal', 'folding_chair',
93
+ 'food_processor', 'football_(American)', 'football_helmet',
94
+ 'footstool', 'fork', 'forklift', 'freight_car', 'French_toast',
95
+ 'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge',
96
+ 'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose',
97
+ 'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin',
98
+ 'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger',
99
+ 'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove',
100
+ 'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart',
101
+ 'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater',
102
+ 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
103
+ 'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun',
104
+ 'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham', 'hamburger',
105
+ 'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass',
106
+ 'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle',
107
+ 'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil',
108
+ 'headband', 'headboard', 'headlight', 'headscarf', 'headset',
109
+ 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet',
110
+ 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog',
111
+ 'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah',
112
+ 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
113
+ 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
114
+ 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
115
+ 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
116
+ 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey',
117
+ 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak',
118
+ 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono',
119
+ 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit',
120
+ 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)',
121
+ 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)',
122
+ 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard',
123
+ 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather',
124
+ 'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade',
125
+ 'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb',
126
+ 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor',
127
+ 'lizard', 'log', 'lollipop', 'speaker_(stereo_equipment)', 'loveseat',
128
+ 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)',
129
+ 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange',
130
+ 'manger', 'manhole', 'map', 'marker', 'martini', 'mascot',
131
+ 'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)',
132
+ 'matchbox', 'mattress', 'measuring_cup', 'measuring_stick',
133
+ 'meatball', 'medicine', 'melon', 'microphone', 'microscope',
134
+ 'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake',
135
+ 'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)',
136
+ 'money', 'monitor_(computer_equipment) computer_monitor', 'monkey',
137
+ 'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle',
138
+ 'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad',
139
+ 'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument',
140
+ 'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle',
141
+ 'nest', 'newspaper', 'newsstand', 'nightshirt',
142
+ 'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook',
143
+ 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)',
144
+ 'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion',
145
+ 'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven',
146
+ 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle',
147
+ 'padlock', 'paintbrush', 'painting', 'pajamas', 'palette',
148
+ 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose',
149
+ 'papaya', 'paper_plate', 'paper_towel', 'paperback_book',
150
+ 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
151
+ 'parasol', 'parchment', 'parka', 'parking_meter', 'parrot',
152
+ 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
153
+ 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
154
+ 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg',
155
+ 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box',
156
+ 'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)',
157
+ 'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet',
158
+ 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
159
+ 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
160
+ 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
161
+ 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
162
+ 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
163
+ 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)',
164
+ 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)',
165
+ 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)',
166
+ 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
167
+ 'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
168
+ 'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller',
169
+ 'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin',
170
+ 'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt',
171
+ 'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver',
172
+ 'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry',
173
+ 'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
174
+ 'recliner', 'record_player', 'reflector', 'remote_control',
175
+ 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map',
176
+ 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade',
177
+ 'rolling_pin', 'root_beer', 'router_(computer_equipment)',
178
+ 'rubber_band', 'runner_(carpet)', 'plastic_bag',
179
+ 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin',
180
+ 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)',
181
+ 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)',
182
+ 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse',
183
+ 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf',
184
+ 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver',
185
+ 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
186
+ 'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark',
187
+ 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl',
188
+ 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt',
189
+ 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass',
190
+ 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap',
191
+ 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink',
192
+ 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole',
193
+ 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)',
194
+ 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
195
+ 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball',
196
+ 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
197
+ 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
198
+ 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish',
199
+ 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)',
200
+ 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish',
201
+ 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel',
202
+ 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew',
203
+ 'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove',
204
+ 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
205
+ 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
206
+ 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
207
+ 'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants',
208
+ 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit',
209
+ 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table',
210
+ 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight',
211
+ 'tambourine', 'army_tank', 'tank_(storage_vessel)',
212
+ 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
213
+ 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
214
+ 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
215
+ 'telephone_pole', 'telephoto_lens', 'television_camera',
216
+ 'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
217
+ 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
218
+ 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
219
+ 'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster',
220
+ 'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs',
221
+ 'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover',
222
+ 'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
223
+ 'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike',
224
+ 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray',
225
+ 'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod',
226
+ 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban',
227
+ 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)',
228
+ 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn',
229
+ 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest',
230
+ 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture',
231
+ 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick',
232
+ 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe',
233
+ 'washbasin', 'automatic_washer', 'watch', 'water_bottle',
234
+ 'water_cooler', 'water_faucet', 'water_heater', 'water_jug',
235
+ 'water_gun', 'water_scooter', 'water_ski', 'water_tower',
236
+ 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake',
237
+ 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream',
238
+ 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
239
+ 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
240
+ 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
241
+ 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
242
+ 'yoke_(animal_equipment)', 'zebra', 'zucchini')
ext/meta/sam_meta.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Architecture hyper-parameters of each SAM image-encoder variant, keyed by
# variant name. Only the four encoder_* fields differ between variants; the
# remaining ("common") fields are the same for all of them.
meta_dict = {
    variant: dict(
        encoder_embed_dim=embed_dim,
        encoder_depth=depth,
        encoder_num_heads=num_heads,
        encoder_global_attn_indexes=global_attn_indexes,
        # common
        prompt_embed_dim=256,
        image_size=1024,
        vit_patch_size=16,
        image_embedding_size=64,
    )
    for variant, (embed_dim, depth, num_heads, global_attn_indexes) in (
        ('vit_h', (1280, 32, 16, [7, 15, 23, 31])),
        ('vit_l', (1024, 24, 16, [5, 11, 17, 23])),
        ('vit_b', (768, 12, 12, [2, 5, 8, 11])),
    )
}
36
+
37
# Download URL of the pretrained checkpoint for each variant in `meta_dict`.
checkpoint_dict = dict(
    vit_h='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    vit_l='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    vit_b='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
)
ext/open_clip/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .coca_model import CoCa
2
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
3
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
4
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
5
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
6
+ from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
7
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype
8
+ from .openai import load_openai_model, list_openai_models
9
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
10
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
11
+ from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
12
+ from .tokenizer import SimpleTokenizer, tokenize, decode
13
+ from .transform import image_transform, AugmentationCfg
14
+ from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
15
+ from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
ext/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
ext/open_clip/coca_model.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from dataclasses import dataclass
8
+
9
+ from .transformer import (
10
+ LayerNormFp32,
11
+ LayerNorm,
12
+ QuickGELU,
13
+ MultimodalTransformer,
14
+ )
15
+ from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
16
+
17
# `transformers` is an optional dependency: `CoCa.generate` asserts
# `_has_transformers` before touching any of the names imported here.
try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StoppingCriteriaList
    )
except ImportError:
    # Keep the same keys so the table can still be listed in error messages.
    _has_transformers = False
    GENERATION_TYPES = dict(
        top_k=None,
        top_p=None,
        beam_search="beam_search",
    )
else:
    _has_transformers = True
    GENERATION_TYPES = dict(
        top_k=TopKLogitsWarper,
        top_p=TopPLogitsWarper,
        beam_search="beam_search",
    )
42
+
43
+
44
@dataclass
class MultimodalCfg(CLIPTextCfg):
    """Configuration of the multimodal text decoder.

    Inherits every field of `CLIPTextCfg` and declares the five fields below
    with their defaults.
    """
    mlp_ratio: int = 4
    dim_head: int = 64
    # Forwarded as `heads` to MultimodalTransformer by _build_text_decoder_tower.
    heads: int = 8
    n_queries: int = 256
    attn_pooler_heads: int = 8
51
+
52
+
53
def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Build the `MultimodalTransformer` used as the text decoder.

    Args:
        embed_dim: Passed to the transformer as its `output_dim`.
        multimodal_cfg: A `MultimodalCfg`, or a dict of keyword arguments
            from which one is constructed.
        quick_gelu: Use `QuickGELU` instead of `nn.GELU` as the activation.
        cast_dtype: When float16 or bfloat16, `LayerNormFp32` is used as the
            normalisation layer; otherwise `LayerNorm`.

    Returns:
        The constructed `MultimodalTransformer`.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    reduced_precision = cast_dtype in (torch.float16, torch.bfloat16)

    return MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=QuickGELU if quick_gelu else nn.GELU,
        norm_layer=LayerNormFp32 if reduced_precision else LayerNorm,
    )
77
+
78
+
79
class CoCa(nn.Module):
    """Model combining a text tower, a vision tower and a multimodal text decoder.

    `forward` returns image/text latents together with decoder logits and
    the matching label tokens; `generate` produces token sequences for images
    by beam search or top-k / top-p sampling (requires `transformers`).
    """

    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
    ):
        """Build the three towers.

        Args:
            embed_dim: Embedding dimension handed to the text and vision towers.
            multimodal_cfg: `MultimodalCfg` (or dict of its fields) for the decoder.
            text_cfg: `CLIPTextCfg` (or dict of its fields) for the text tower.
            vision_cfg: `CLIPVisionCfg` (or dict of its fields) for the vision tower.
            quick_gelu: Forwarded to all three tower builders.
            cast_dtype: Forwarded to all three tower builders.
            pad_id: Default padding token id used by `generate`.
        """
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # The decoder's output dimension is the vocabulary size (its outputs
        # are used as per-token logits), not `embed_dim`.
        vocab_size = text_cfg.vocab_size

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.pad_id = pad_id

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Forward the gradient-checkpointing switch to all three towers."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize=True):
        """Return `(image_latent, token_embeddings)` from the vision tower.

        The latent is L2-normalised along the last dimension when `normalize`.
        """
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize=True, embed_cls=True):
        """Return `(text_latent, token_embeddings)` from the text tower.

        When `embed_cls` is true the last token column is dropped first; the
        latent is L2-normalised along the last dimension when `normalize`.
        """
        text = text[:, :-1] if embed_cls else text  # make space for CLS token
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize=True):
        """Return only the image latent of `_encode_image`."""
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize=True, embed_cls=True):
        """Return only the text latent of `_encode_text`."""
        text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
        return text_latent

    def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
        """Encode both modalities and run the text decoder.

        `image` is only encoded when `image_latent` or `image_embs` is
        missing, so callers that already hold both can skip the vision tower.

        Returns:
            dict with keys `image_features`, `text_features`, `logits`,
            `labels` and `logit_scale` (the exponentiated parameter).
        """
        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        # TODO: add assertion to avoid bugs?
        # Labels are the trailing tokens of `text`, one per token embedding.
        labels = text[:, -token_embs.shape[1]:]

        logits = self.text_decoder(image_embs, token_embs)
        return {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "labels": labels,
            "logit_scale": self.logit_scale.exp()
        }

    def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,  # keep tokens in the 1 - top_p quantile
        top_k=1,  # keeps the top_k most probable tokens
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False  # if True output.shape == (batch_size, seq_len)
    ):
        """Generate token sequences for `image`.

        Args:
            image: Batch of images; its device is used for all new tensors.
            text: Optional prompt tokens for the sampling modes (1-D or 2-D);
                defaults to a single start-of-text token per image. Not used
                by beam search.
            seq_len: Generation length; also the default stopping criterion.
            max_seq_len: Number of trailing tokens fed to the model per
                sampling step.
            temperature: Divides the filtered logits before softmax (sampling).
            generation_type: "beam_search", "top_p" or "top_k".
            top_p / top_k: Argument of the corresponding logits warper.
            pad_token_id: Padding token; defaults to `self.pad_id`.
            eos_token_id: End-of-text token; defaults to 49407.
            sot_token_id: Start-of-text token; defaults to 49406.
            num_beams, num_beam_groups: Beam-search settings.
            min_seq_len: Minimum length enforced via `MinLengthLogitsProcessor`.
            stopping_criteria: Optional list of criteria replacing the default
                `MaxLengthCriteria(seq_len)`.
            repetition_penalty: Argument of `RepetitionPenaltyLogitsProcessor`.
            fixed_output_length: Pad finished sequences up to `seq_len`.

        Returns:
            Tensor of generated token ids.

        Raises:
            AssertionError: if `transformers` is missing or
                `seq_len <= min_seq_len`.
            ValueError: for an unknown `generation_type`.
        """
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"

        with torch.no_grad():
            sot_token_id = 49406 if sot_token_id is None else sot_token_id
            eos_token_id = 49407 if eos_token_id is None else eos_token_id
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]

            stopping_criteria = StoppingCriteriaList(
                stopping_criteria
            )

            device = image.device

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    # Pad with the resolved `pad_token_id` (not `self.pad_id`)
                    # so a caller-supplied pad token is honoured here exactly
                    # as it is in the sampling path below.
                    pad_len = seq_len - output.shape[1]
                    padding = torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                    return torch.cat((output, padding), dim=1)
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            image_latent, image_embs = self._encode_image(image)

            if text is None:
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            try:
                while True:
                    x = out[:, -max_seq_len:]
                    cur_len = x.shape[1]
                    logits = self(
                        image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False
                    )["logits"][:, -1]
                    # Rows whose last token is EOS or PAD are finished.
                    mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                    sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                    if mask.all():
                        if not fixed_output_length:
                            break
                    else:
                        active = ~mask
                        logits = logits[active, :]
                        filtered_logits = logit_processor(x[active, :], logits)
                        filtered_logits = logit_warper(x[active, :], filtered_logits)
                        probs = F.softmax(filtered_logits / temperature, dim=-1)

                        if cur_len + 1 == seq_len:
                            # Last slot: force EOS on every unfinished row.
                            sample[active, :] = torch.ones(
                                (int(active.sum()), 1), device=device, dtype=torch.long
                            ) * eos_token_id
                        else:
                            sample[active, :] = torch.multinomial(probs, 1)

                    out = torch.cat((out, sample), dim=-1)

                    if stopping_criteria(out, None):
                        break
            finally:
                # Restore the caller's train/eval mode even if sampling fails.
                self.train(was_training)

            if num_dims == 1:
                out = out.squeeze(0)

            return out

    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        """Group beam search driven by `transformers.BeamSearchScorer`.

        Each image is repeated `num_beams` times and every beam starts from
        `sot_token_id`. `stopping_criteria` is called every step and its
        `max_length` is read at the end, so it must be provided (as
        `generate` does). `logit_warper` is accepted but not used.

        Returns:
            The `sequences` entry of `BeamSearchScorer.finalize`.

        Raises:
            ValueError: if the batch dimension of the beam inputs does not
                equal `num_beams * batch_size`.
        """
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce same tokens everytime.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                embed_cls=False,
                image_latent=image_latent,
                image_embs=image_embs
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                # split the flat top-k index into (beam within group, token id)
                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
                    + group_start_idx
                    + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                break

        # `next_tokens` / `next_indices` below are those of the last group
        # processed in the final step.
        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']
437
+
438
+
439
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    """Assemble the input dict for one decoding step.

    Args:
        input_ids: Token ids generated so far; reduced to the last column
            when `past` is truthy.
        image_inputs: Returned unchanged under the key "images".
        past: Returned unchanged under the key "past_key_values".
        **kwargs: May carry `attention_mask` and/or `position_ids`.

    Returns:
        dict with keys "text", "images", "past_key_values", "position_ids"
        and "attention_mask". `position_ids` are derived from
        `attention_mask` when only the mask is given; explicitly supplied
        `position_ids` are passed through; otherwise the entry is None.
    """
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        # masked-out positions get a fixed placeholder id of 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    # NOTE: there is deliberately no `else: position_ids = None` here; that
    # branch used to discard position ids the caller passed in explicitly.

    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
ext/open_clip/constants.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
# Dataset statistics, three values each.
# NOTE(review): presumably per-channel (RGB) mean / std used to normalise
# input images — confirm against the transform code that consumes them.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
ext/open_clip/factory.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional, Tuple, Union
9
+
10
+ import torch
11
+
12
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
14
+ resize_pos_embed, get_cast_dtype
15
+ from .coca_model import CoCa
16
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
17
+ from .openai import load_openai_model
18
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
19
+ list_pretrained_tags_by_model, download_pretrained_from_hf
20
+ from .transform import image_transform, AugmentationCfg
21
+ from .tokenizer import HFTokenizer, tokenize
22
+
23
+
24
# Model names starting with this prefix refer to a Hugging Face Hub repo id.
HF_HUB_PREFIX = 'hf-hub:'
# Files/directories scanned for model config JSON files; extended by add_model_config().
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
27
+
28
+
29
+ def _natural_key(string_):
30
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
31
+
32
+
33
def _rescan_model_configs():
    """Re-read every config source in `_MODEL_CONFIG_PATHS` into `_MODEL_CONFIGS`.

    Each path may be a single `.json` file or a directory whose `*.json`
    files are collected. A file is registered under its stem only if it
    contains the keys `embed_dim`, `vision_cfg` and `text_cfg`. Existing
    entries are kept (or overwritten by name) and the registry is rebound to
    a dict sorted by natural key order of the model names.
    """
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    required_keys = ('embed_dim', 'vision_cfg', 'text_cfg')

    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    for cf in config_files:
        # Read as UTF-8 explicitly (as create_model does for hub configs)
        # instead of depending on the platform's default encoding.
        with open(cf, 'r', encoding='utf-8') as f:
            model_cfg = json.load(f)
        if all(key in model_cfg for key in required_keys):
            _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda item: _natural_key(item[0])))


_rescan_model_configs()  # initial populate of model config registry
55
+
56
+
57
def list_models():
    """Return the names of all registered model configs as a new list."""
    return [*_MODEL_CONFIGS]
60
+
61
+
62
def add_model_config(path):
    """Register an extra config file or directory, then rescan the registry.

    `path` may be a `Path` or anything accepted by the `Path` constructor.
    """
    _MODEL_CONFIG_PATHS.append(path if isinstance(path, Path) else Path(path))
    _rescan_model_configs()
68
+
69
+
70
def get_model_config(model_name):
    """Return a deep copy of the config registered as `model_name`, or None.

    A copy is returned so callers can modify it without touching the registry.
    """
    if model_name not in _MODEL_CONFIGS:
        return None
    return deepcopy(_MODEL_CONFIGS[model_name])
75
+
76
+
77
def get_tokenizer(model_name):
    """Return the tokenizer that matches `model_name`.

    Args:
        model_name: Either `hf-hub:<repo id>` or the name of a registered
            model config. The legacy `/` spelling (e.g. `ViT-B/16`) is
            accepted for config names, mirroring `create_model`.

    Returns:
        An `HFTokenizer` for hub models and for configs whose `text_cfg`
        names an `hf_tokenizer_name`; otherwise the module's `tokenize`
        function.

    Raises:
        RuntimeError: if no model config is registered under `model_name`
            (previously this surfaced as a `TypeError` from subscripting
            `None`).
    """
    if model_name.startswith(HF_HUB_PREFIX):
        return HFTokenizer(model_name[len(HF_HUB_PREFIX):])

    model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
    config = get_model_config(model_name)
    if config is None:
        raise RuntimeError(f'Model config for {model_name} not found.')

    text_cfg = config['text_cfg']
    if 'hf_tokenizer_name' in text_cfg:
        return HFTokenizer(text_cfg['hf_tokenizer_name'])
    return tokenize
85
+
86
+
87
def load_state_dict(checkpoint_path: str, map_location='cpu'):
    """Load a checkpoint file and return its state dict.

    Args:
        checkpoint_path: Path handed to `torch.load`.
        map_location: Forwarded to `torch.load`.

    Returns:
        The `state_dict` entry of the checkpoint when it is a dict that has
        one, otherwise the checkpoint itself. If the first key carries the
        `module.` prefix, that prefix is removed from every key that has it;
        other keys are left untouched. An empty state dict is returned as is.
    """
    # SECURITY: torch.load unpickles the file; only load checkpoints from
    # sources you trust.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    prefix = 'module.'
    first_key = next(iter(state_dict.keys()), None)
    if first_key is not None and first_key.startswith(prefix):
        # Strip only a real `module.` prefix instead of blindly dropping the
        # first seven characters of every key.
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    return state_dict
96
+
97
+
98
def load_checkpoint(model, checkpoint_path, strict=True):
    """Load the weights stored at `checkpoint_path` into `model`.

    Returns whatever `model.load_state_dict` returns (the incompatible keys).
    """
    state_dict = load_state_dict(checkpoint_path)

    # detect old format and make compatible with new format
    uses_old_format = 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding')
    if uses_old_format:
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # NOTE(review): presumably adapts positional embeddings in `state_dict`
    # to the model's size in place — confirm in resize_pos_embed.
    resize_pos_embed(state_dict, model)
    return model.load_state_dict(state_dict, strict=strict)
106
+
107
+
108
def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        logger: logging.Logger = logging,
):
    """Build a CLIP-style model and optionally load pretrained weights.

    Args:
        model_name: Model config name, or an ``HF_HUB_PREFIX``-prefixed hub id
            (config and checkpoint are then downloaded from the hub).
        pretrained: ``'openai'``, a pretrained tag known for ``model_name``, or a
            path to a local checkpoint file.
        precision: ``'fp16'``/``'bf16'`` (manual mixed precision),
            ``'pure_fp16'``/``'pure_bf16'`` (whole model cast), anything else
            leaves the parameter dtype unchanged.
        device: Target device (a string is converted to ``torch.device``).
        jit: Script the final model with ``torch.jit.script``.
        force_quick_gelu: Set ``quick_gelu`` in the model config.
        force_custom_text: Build ``CustomTextCLIP`` (or ``CoCa``) instead of ``CLIP``.
        force_patch_dropout: Override ``vision_cfg['patch_dropout']``.
        force_image_size: Override ``vision_cfg['image_size']``.
        pretrained_image: Request pretrained timm image tower weights.
        pretrained_hf: Value stored as ``text_cfg['hf_model_pretrained']``.
        cache_dir: Download cache directory passed to the download helpers.
        output_dict: When truthy and supported, set ``model.output_dict = True``.
        require_pretrained: Raise if no pretrained weights were loaded.
        logger: Object providing ``info``/``warning``/``error`` logging calls.

    Returns:
        The constructed (and possibly scripted) model.

    Raises:
        RuntimeError: Unknown model config, pretrained weights not found, or
            weights required but not loaded.
        AssertionError: ``pretrained_image`` requested for a non-timm tower.
    """
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        pretrained_cfg = config['preprocess_cfg']
        model_cfg = config['model_cfg']
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        pretrained_cfg = {}
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logger.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            cache_dir=cache_dir,
        )
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logger.info(f'Loaded {model_name} model config.')
        else:
            logger.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
        if pretrained_image:
            if is_timm_model:
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                # Raised explicitly (same exception type as before) so the check
                # is not stripped when Python runs with -O.
                raise AssertionError('pretrained image towers currently only supported for timm models')

        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        if custom_text:
            if is_hf_model:
                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
            if "coca" in model_name:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        if precision in ("fp16", "bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            # manual mixed precision that matches original OpenAI behaviour
            if is_timm_model:
                # FIXME this is a bit janky, create timm based model in low-precision and
                # then cast only LayerNormFp32 instances back to float32 so they don't break.
                # Why? The convert_weights_to_lp fn only works with native models.
                model.to(device=device, dtype=dtype)
                from .transformer import LayerNormFp32

                def _convert_ln(m):
                    if isinstance(m, LayerNormFp32):
                        m.weight.data = m.weight.data.to(torch.float32)
                        m.bias.data = m.bias.data.to(torch.float32)

                model.apply(_convert_ln)
            else:
                model.to(device=device)
                convert_weights_to_lp(model, dtype=dtype)
        elif precision in ("pure_fp16", "pure_bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            model.to(device=device, dtype=dtype)
        else:
            model.to(device=device)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logger.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                # Two sentences, separated by a space, with balanced parentheses.
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}. '
                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
                logger.warning(error_str)
                raise RuntimeError(error_str)
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            # `pretrained` is empty in this branch, so report the downloaded file instead.
            logger.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
            load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

        # set image / mean metadata from pretrained_cfg if available, or use default
        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    return model
258
+
259
+
260
def create_loss(args):
    """Pick and instantiate the loss module matching the flags on ``args``.

    ``args.distill`` selects ``DistillClipLoss``; otherwise a model name
    containing ``"coca"`` selects ``CoCaLoss``; everything else gets ``ClipLoss``.
    """
    if args.distill:
        loss_cls = DistillClipLoss
        specific = {}
    elif "coca" in args.model.lower():
        loss_cls = CoCaLoss
        specific = dict(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
        )
    else:
        loss_cls = ClipLoss
        specific = {}

    # Options shared by every loss variant.
    return loss_cls(
        **specific,
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
    )
289
+
290
+
291
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        logger: logging.Logger = logging,
):
    """Create a model via ``create_model`` plus its train and eval image transforms.

    ``image_mean`` / ``image_std`` fall back to the attributes stored on
    ``model.visual`` when not given.

    Returns:
        ``(model, preprocess_train, preprocess_val)``.
    """
    model_options = dict(
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        logger=logger,
    )
    model = create_model(model_name, pretrained, **model_options)

    if not image_mean:
        image_mean = getattr(model.visual, 'image_mean', None)
    if not image_std:
        image_std = getattr(model.visual, 'image_std', None)

    preprocess_train = image_transform(
        model.visual.image_size,
        is_train=True,
        mean=image_mean,
        std=image_std,
        aug_cfg=aug_cfg,
    )
    preprocess_val = image_transform(
        model.visual.image_size,
        is_train=False,
        mean=image_mean,
        std=image_std,
    )
    return model, preprocess_train, preprocess_val
344
+
345
+
346
def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        return_transform: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
        logger: logging.Logger = logging,
):
    """Create a model that must have pretrained weights, optionally with its eval transform.

    ``create_model`` is called with ``require_pretrained=True``.

    Returns:
        The model alone when ``return_transform`` is false, otherwise
        ``(model, preprocess)``.
    """
    model_options = dict(
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        cache_dir=cache_dir,
        require_pretrained=True,
        logger=logger,
    )
    model = create_model(model_name, pretrained, **model_options)

    if return_transform:
        if not image_mean:
            image_mean = getattr(model.visual, 'image_mean', None)
        if not image_std:
            image_std = getattr(model.visual, 'image_std', None)
        preprocess = image_transform(
            model.visual.image_size,
            is_train=False,
            mean=image_mean,
            std=image_std,
        )
        return model, preprocess

    return model
ext/open_clip/generation_utils.py ADDED
File without changes
ext/open_clip/hf_configs.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# HF architecture dict:
# Maps a HuggingFace ``model_type`` to the config attribute names used for each
# generic property ("config_names") and to the default pooler name ("pooler").
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/bert
    "bert": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            # Needed by layer-wise freezing, which looks both names up for every
            # architecture; BERT uses the same attribute names as RoBERTa.
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "cls_pooler",
    },
}
ext/open_clip/hf_model.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
+ """
5
+ import re
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import TensorType
10
+
11
+ try:
12
+ import transformers
13
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
14
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
15
+ BaseModelOutputWithPoolingAndCrossAttentions
16
+ except ImportError as e:
17
+ transformers = None
18
+
19
+
20
+ class BaseModelOutput:
21
+ pass
22
+
23
+
24
+ class PretrainedConfig:
25
+ pass
26
+
27
+ from .hf_configs import arch_dict
28
+
29
+
30
+ # utils
31
+ def _camel2snake(s):
32
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
33
+
34
+
35
+ # TODO: ?last - for gpt-like models
36
+ _POOLERS = {}
37
+
38
+
39
def register_pooler(cls):
    """Class decorator: store ``cls`` in ``_POOLERS`` under its snake_case name."""
    registry_key = _camel2snake(cls.__name__)
    _POOLERS[registry_key] = cls
    return cls
43
+
44
+
45
@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Average ``x.last_hidden_state`` over dim 1, weighted by ``attention_mask``.

        Positions whose mask value is 0 contribute nothing to the sum.
        """
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        # Clamp the token count so a row with an all-zero mask yields zeros
        # instead of dividing by zero and producing NaN.
        token_counts = attention_mask.sum(-1, keepdim=True).clamp(min=1)
        return masked_output.sum(dim=1) / token_counts
52
+
53
+
54
@register_pooler
class MaxPooler(nn.Module):
    """Max pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Return the element-wise maximum of ``x.last_hidden_state`` over dim 1.

        Positions whose ``attention_mask`` value is 0 are excluded from the maximum.
        """
        # masked_fill needs a boolean mask, and the positions to blank out are the
        # ones with mask value 0 (the mask is non-zero for tokens to keep).
        ignored = attention_mask.unsqueeze(-1) == 0
        masked_output = x.last_hidden_state.masked_fill(ignored, -torch.inf)
        return masked_output.max(1).values
61
+
62
+
63
@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling"""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        self.use_pooler_output = use_pooler_output
        self.cls_token_position = 0

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Return ``x.pooler_output`` when enabled and present, else the CLS hidden state."""
        if self.use_pooler_output:
            pooled_output_types = (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)
            if isinstance(x, pooled_output_types) and x.pooler_output is not None:
                return x.pooler_output

        # Fall back to the hidden state at the CLS position.
        return x.last_hidden_state[:, self.cls_token_position, :]
80
+
81
+
82
@register_pooler
class ClsLastHiddenStatePooler(nn.Module):
    """CLS token pooling
    NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
    """

    def __init__(self):
        super().__init__()
        self.cls_token_position = 0

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Select the hidden state at ``cls_token_position``; the mask is unused."""
        hidden_states = x.last_hidden_state
        return hidden_states[:, self.cls_token_position, :]
94
+
95
+
96
class HFTextEncoder(nn.Module):
    """HuggingFace model adapter"""
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj: str = None,
            pretrained: bool = True,
            output_tokens: bool = False,
    ):
        """Wrap a HuggingFace transformer with a pooler and an output projection.

        Args:
            model_name_or_path: Name/path handed to ``AutoConfig``/``AutoModel``
                when ``config`` is not given.
            output_dim: Size of the projected output.
            config: Ready-made config; when given the model is built from it.
            pooler_type: Key into the pooler registry; defaults to the
                architecture's entry in ``arch_dict``.
            proj: ``None``, ``'linear'`` or ``'mlp'``.
            pretrained: Load pretrained weights (only when ``config`` is None).
            output_tokens: Make ``forward`` also return per-token states.

        Raises:
            RuntimeError: ``transformers`` is not installed.
            ValueError: No projection can be built for the given ``proj``.
        """
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
                AutoModel.from_config, self.config)
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = AutoModel.from_config(config)
        if pooler_type is None:  # get default arch pooler
            pooler_type = (arch_dict[self.config.model_type]["pooler"])

        # FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)

        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )
        else:
            # Fail at construction time rather than leaving `self.proj` undefined
            # and crashing with an AttributeError on the first forward pass.
            raise ValueError(
                f"Cannot build projection: proj={proj!r} with model width {d_model} "
                f"and output_dim {output_dim}; use 'linear' or 'mlp'.")

    def forward(self, x: TensorType):
        """Encode token ids ``x``.

        Returns the projected pooled output, or ``(projected, tokens)`` when
        ``output_tokens`` is set.
        """
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)

        if not self.output_tokens:
            return projected

        # Only build the token tensor when the caller asked for it.
        seq_len = out.last_hidden_state.shape[1]
        tokens = (
            # drop the CLS position, which is already represented by the pooled output
            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
            if type(self.pooler) == ClsPooler
            else out.last_hidden_state
        )
        return projected, tokens

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze transformer parameters, leaving the last ``unlocked_layers`` modules trainable.

        Parameters with a ``LayerNorm`` name component are frozen only when
        ``freeze_layer_norm`` is true.
        """
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Turn gradient checkpointing on the wrapped transformer on or off."""
        # Honour `enable`; previously checkpointing was switched on unconditionally.
        if enable:
            self.transformer.gradient_checkpointing_enable()
        else:
            self.transformer.gradient_checkpointing_disable()

    def init_parameters(self):
        """No-op."""
        pass
ext/open_clip/loss.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ try:
6
+ import torch.distributed.nn
7
+ from torch import distributed as dist
8
+
9
+ has_distributed = True
10
+ except ImportError:
11
+ has_distributed = False
12
+
13
+ try:
14
+ import horovod.torch as hvd
15
+ except ImportError:
16
+ hvd = None
17
+
18
+
19
def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    """Collect image and text features from all workers.

    Uses horovod when ``use_horovod`` is set, otherwise ``torch.distributed``.
    Returns ``(all_image_features, all_text_features)``.
    """
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'

    if not use_horovod:
        # We gather tensors from all gpus
        if gather_with_grad:
            merged_images = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            merged_texts = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
            return merged_images, merged_texts

        image_parts = [torch.zeros_like(image_features) for _ in range(world_size)]
        text_parts = [torch.zeros_like(text_features) for _ in range(world_size)]
        dist.all_gather(image_parts, image_features)
        dist.all_gather(text_parts, text_features)
        if not local_loss:
            # ensure grads for local rank when all_* features don't have a gradient
            image_parts[rank] = image_features
            text_parts[rank] = text_features
        merged_images = torch.cat(image_parts, dim=0)
        merged_texts = torch.cat(text_parts, dim=0)
        return merged_images, merged_texts

    assert hvd is not None, 'Please install horovod'
    if gather_with_grad:
        merged_images = hvd.allgather(image_features)
        merged_texts = hvd.allgather(text_features)
        return merged_images, merged_texts

    with torch.no_grad():
        merged_images = hvd.allgather(image_features)
        merged_texts = hvd.allgather(text_features)
    if not local_loss:
        # ensure grads for local rank when all_* features don't have a gradient
        image_parts = list(merged_images.chunk(world_size, dim=0))
        text_parts = list(merged_texts.chunk(world_size, dim=0))
        image_parts[rank] = image_features
        text_parts[rank] = text_features
        merged_images = torch.cat(image_parts, dim=0)
        merged_texts = torch.cat(text_parts, dim=0)
    return merged_images, merged_texts
64
+
65
+
66
class ClipLoss(nn.Module):
    """Symmetric image/text cross-entropy loss over scaled similarity logits."""

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        # distributed gathering options
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels

        # cache state
        self.labels = {}
        self.prev_num_logits = 0

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return target indices ``0..num_logits-1`` (offset by rank for local loss)."""
        cache_is_valid = self.prev_num_logits == num_logits and device in self.labels
        if cache_is_valid:
            return self.labels[device]

        # calculated ground-truth and cache if enabled
        targets = torch.arange(num_logits, device=device, dtype=torch.long)
        if self.world_size > 1 and self.local_loss:
            targets = targets + num_logits * self.rank
        if self.cache_labels:
            self.labels[device] = targets
            self.prev_num_logits = num_logits
        return targets

    def get_logits(self, image_features, text_features, logit_scale):
        """Compute ``(logits_per_image, logits_per_text)``, gathering across workers if needed."""
        if self.world_size <= 1:
            per_image = logit_scale * image_features @ text_features.T
            per_text = logit_scale * text_features @ image_features.T
            return per_image, per_text

        gathered_image, gathered_text = gather_features(
            image_features, text_features,
            self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

        if self.local_loss:
            per_image = logit_scale * image_features @ gathered_text.T
            per_text = logit_scale * text_features @ gathered_image.T
        else:
            per_image = logit_scale * gathered_image @ gathered_text.T
            per_text = per_image.T
        return per_image, per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        """Return the averaged two-way cross-entropy, as a tensor or a one-key dict."""
        device = image_features.device
        per_image, per_text = self.get_logits(image_features, text_features, logit_scale)
        targets = self.get_ground_truth(device, per_image.shape[0])

        image_term = F.cross_entropy(per_image, targets)
        text_term = F.cross_entropy(per_text, targets)
        total_loss = (image_term + text_term) / 2

        if output_dict:
            return {"contrastive_loss": total_loss}
        return total_loss
132
+
133
+
134
class CoCaLoss(ClipLoss):
    """Contrastive loss plus a captioning cross-entropy loss, each with its own weight."""

    def __init__(
            self,
            caption_loss_weight,
            clip_loss_weight,
            pad_id=0,  # pad_token for open_clip custom tokenizer
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod
        )

        self.clip_loss_weight = clip_loss_weight
        self.caption_loss_weight = caption_loss_weight
        # label positions equal to pad_id are ignored by the caption loss
        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
        """Return the weighted contrastive and caption losses.

        Returns ``(clip_loss, caption_loss)``, or a dict with keys
        ``"contrastive_loss"`` and ``"caption_loss"`` when ``output_dict`` is set.
        The contrastive term is a zero tensor when ``clip_loss_weight`` is falsy.
        """
        if self.clip_loss_weight:
            clip_loss = super().forward(image_features, text_features, logit_scale)
            clip_loss = self.clip_loss_weight * clip_loss
        else:
            # Keep the placeholder on the same device as the caption loss so
            # downstream code never has to mix devices.
            clip_loss = torch.tensor(0, device=logits.device)

        caption_loss = self.caption_loss(
            logits.permute(0, 2, 1),  # move the last dim to dim 1 for CrossEntropyLoss
            labels,
        )
        caption_loss = caption_loss * self.caption_loss_weight

        if output_dict:
            return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}

        return clip_loss, caption_loss
178
+
179
+
180
class DistillClipLoss(ClipLoss):
    """Contrastive loss plus a distillation term against a second set of features."""

    def dist_loss(self, teacher_logits, student_logits):
        """Cross-entropy of student log-probabilities under teacher probabilities (dim 1)."""
        teacher_probs = teacher_logits.softmax(dim=1)
        student_log_probs = student_logits.log_softmax(dim=1)
        return -(teacher_probs * student_log_probs).sum(dim=1).mean(dim=0)

    def forward(
            self,
            image_features,
            text_features,
            logit_scale,
            dist_image_features,
            dist_text_features,
            dist_logit_scale,
            output_dict=False,
    ):
        """Return ``(contrastive_loss, distill_loss)`` or the same pair as a dict."""
        student_per_image, student_per_text = self.get_logits(
            image_features, text_features, logit_scale)
        teacher_per_image, teacher_per_text = self.get_logits(
            dist_image_features, dist_text_features, dist_logit_scale)

        targets = self.get_ground_truth(image_features.device, student_per_image.shape[0])

        image_term = F.cross_entropy(student_per_image, targets)
        text_term = F.cross_entropy(student_per_text, targets)
        contrastive_loss = (image_term + text_term) / 2

        image_distill = self.dist_loss(teacher_per_image, student_per_image)
        text_distill = self.dist_loss(teacher_per_text, student_per_text)
        distill_loss = (image_distill + text_distill) / 2

        if not output_dict:
            return contrastive_loss, distill_loss
        return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
ext/open_clip/model.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import math
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from .hf_model import HFTextEncoder
17
+ from .modified_resnet import ModifiedResNet
18
+ from .timm_model import TimmModel
19
+ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
20
+ from .utils import to_2tuple
21
+
22
+
23
@dataclass
class CLIPVisionCfg:
    """Configuration of the vision tower."""
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    # layer scale initial value
    ls_init_value: Optional[float] = None
    # what fraction of patches to dropout during training (0 would mean disabled and no
    # patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    patch_dropout: float = 0.
    # whether to use dual patchnorm - would only apply the input layernorm on each patch,
    # as post-layernorm already exist in original clip vit design
    input_patchnorm: bool = False
    # whether to global average pool the last embedding layer, instead of using CLS token
    # (https://arxiv.org/abs/2205.01580)
    global_average_pool: bool = False
    # whether to use attentional pooler in the last embedding layer
    attentional_pool: bool = False
    # n_queries for attentional pooler
    n_queries: int = 256
    # n heads for attentional_pooling
    attn_pooler_heads: int = 8
    output_tokens: bool = False

    # a valid model name overrides layers, width, patch_size
    timm_model_name: Optional[str] = None
    # use (imagenet) pretrained weights for named model
    timm_model_pretrained: bool = False
    # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_pool: str = 'avg'
    # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj: str = 'linear'
    # enable bias final projection
    timm_proj_bias: bool = False
    # head dropout
    timm_drop: float = 0.
    # backbone stochastic depth
    timm_drop_path: Optional[float] = None
48
+
49
+
50
@dataclass
class CLIPTextCfg:
    """Configuration of the text tower."""
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    # layer scale initial value
    ls_init_value: Optional[float] = None
    # the two HuggingFace names default to None, so they are Optional
    hf_model_name: Optional[str] = None
    hf_tokenizer_name: Optional[str] = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'
    embed_cls: bool = False
    pad_id: int = 0
    output_tokens: bool = False
66
+
67
+
68
def get_cast_dtype(precision: str):
    """Return the torch dtype for ``'bf16'`` / ``'fp16'``, or ``None`` for anything else."""
    if precision == 'bf16':
        return torch.bfloat16
    if precision == 'fp16':
        return torch.float16
    return None
75
+
76
+
77
def get_input_dtype(precision: str):
    """Return the torch dtype for the bf16/fp16 modes (plain or ``pure_``), else ``None``."""
    dtype_by_modes = (
        (('bf16', 'pure_bf16'), torch.bfloat16),
        (('fp16', 'pure_fp16'), torch.float16),
    )
    for modes, dtype in dtype_by_modes:
        if precision in modes:
            return dtype
    return None
84
+
85
+
86
def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Build the vision tower described by ``vision_cfg``.

    A set ``timm_model_name`` selects ``TimmModel``; a tuple/list ``layers``
    selects ``ModifiedResNet``; otherwise a ``VisionTransformer`` is built.
    """
    cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

    if cfg.timm_model_name:
        # NOTE: timm models always use native GELU regardless of quick_gelu flag.
        return TimmModel(
            cfg.timm_model_name,
            pretrained=cfg.timm_model_pretrained,
            pool=cfg.timm_pool,
            proj=cfg.timm_proj,
            proj_bias=cfg.timm_proj_bias,
            drop=cfg.timm_drop,
            drop_path=cfg.timm_drop_path,
            patch_drop=cfg.patch_dropout if cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=cfg.image_size,
        )

    if isinstance(cfg.layers, (tuple, list)):
        resnet_heads = cfg.width * 32 // cfg.head_width
        return ModifiedResNet(
            layers=cfg.layers,
            output_dim=embed_dim,
            heads=resnet_heads,
            image_size=cfg.image_size,
            width=cfg.width,
        )

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    activation = QuickGELU if quick_gelu else nn.GELU
    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    normalization = LayerNormFp32 if low_precision else LayerNorm
    vit_heads = cfg.width // cfg.head_width
    return VisionTransformer(
        image_size=cfg.image_size,
        patch_size=cfg.patch_size,
        width=cfg.width,
        layers=cfg.layers,
        heads=vit_heads,
        mlp_ratio=cfg.mlp_ratio,
        ls_init_value=cfg.ls_init_value,
        patch_dropout=cfg.patch_dropout,
        input_patchnorm=cfg.input_patchnorm,
        global_average_pool=cfg.global_average_pool,
        attentional_pool=cfg.attentional_pool,
        n_queries=cfg.n_queries,
        attn_pooler_heads=cfg.attn_pooler_heads,
        output_tokens=cfg.output_tokens,
        output_dim=embed_dim,
        act_layer=activation,
        norm_layer=normalization,
    )
146
+
147
+
148
def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Instantiate the text encoder described by ``text_cfg``.

    An ``HFTextEncoder`` is built when ``text_cfg.hf_model_name`` is set;
    otherwise a ``TextTransformer`` is built.

    Args:
        embed_dim: Output embedding dimension of the tower.
        text_cfg: A ``CLIPTextCfg`` or a dict of its fields.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` as activation
            (only consulted for the ``TextTransformer`` branch).
        cast_dtype: When fp16 or bf16, the transformer uses ``LayerNormFp32``.

    Returns:
        The constructed text module.
    """
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        return HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj=text_cfg.proj,
            pooler_type=text_cfg.pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )

    activation = QuickGELU if quick_gelu else nn.GELU
    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    normalization = LayerNormFp32 if low_precision else LayerNorm
    return TextTransformer(
        context_length=text_cfg.context_length,
        vocab_size=text_cfg.vocab_size,
        width=text_cfg.width,
        heads=text_cfg.heads,
        layers=text_cfg.layers,
        ls_init_value=text_cfg.ls_init_value,
        output_dim=embed_dim,
        embed_cls=text_cfg.embed_cls,
        output_tokens=text_cfg.output_tokens,
        pad_id=text_cfg.pad_id,
        act_layer=activation,
        norm_layer=normalization,
    )
185
+
186
+
187
class CLIP(nn.Module):
    """Image/text dual encoder whose text-tower parts live directly on the module.

    The text tower built by ``_build_text_tower`` is unpacked: its transformer,
    embeddings, final norm and projection are assigned as attributes of this
    module rather than being kept under a single ``text`` attribute.
    """
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        """Build both towers and the learnable logit scale.

        Args:
            embed_dim: Shared embedding dimension of both towers.
            vision_cfg: Vision tower configuration (object or dict).
            text_cfg: Text tower configuration (object or dict).
            quick_gelu: Forwarded to the tower builders.
            cast_dtype: Forwarded to the tower builders.
            output_dict: When true, ``forward`` returns a dict instead of a tuple.
        """
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        # Unpack the text tower onto this module, piece by piece.
        text_tower = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text_tower.transformer
        self.context_length = text_tower.context_length
        self.vocab_size = text_tower.vocab_size
        self.token_embedding = text_tower.token_embedding
        self.positional_embedding = text_tower.positional_embedding
        self.ln_final = text_tower.ln_final
        self.text_projection = text_tower.text_projection
        # Non-persistent: the mask is not written to the state dict.
        self.register_buffer('attn_mask', text_tower.attn_mask, persistent=False)

        # Stored in log space; forward() exponentiates it.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate to the vision tower's ``lock`` (LiT - https://arxiv.org/abs/2111.07991)."""
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on the vision tower and the text transformer."""
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def encode_image(self, image, normalize: bool = False):
        """Run the vision tower; optionally L2-normalise along the last dim."""
        image_features = self.visual(image)
        if normalize:
            return F.normalize(image_features, dim=-1)
        return image_features

    def encode_text(self, text, normalize: bool = False):
        """Embed token ids and project the end-of-text token's features.

        Args:
            text: Token id tensor; the position of the largest id in each row
                is taken as the end-of-text position.
            normalize: L2-normalise the result along the last dim.
        """
        cast_dtype = self.transformer.get_cast_dtype()

        hidden = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        hidden = hidden + self.positional_embedding.to(cast_dtype)

        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        hidden = self.transformer(hidden, attn_mask=self.attn_mask)
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden)  # [batch_size, n_ctx, transformer.width]

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_index = torch.arange(hidden.shape[0])
        eot_index = text.argmax(dim=-1)
        text_features = hidden[batch_index, eot_index] @ self.text_projection
        if normalize:
            return F.normalize(text_features, dim=-1)
        return text_features

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode whichever inputs are given and return normalised features.

        Returns:
            ``(image_features, text_features, logit_scale)`` or, when
            ``output_dict`` is set, a dict with those three keys. A missing
            input yields ``None`` for its features.
        """
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None
        scale = self.logit_scale.exp()
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": scale
            }
        return image_features, text_features, scale
256
+
257
+
258
class CustomTextCLIP(nn.Module):
    """Image/text dual encoder that keeps the text tower under ``self.text``."""
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        """Build both towers and the learnable logit scale.

        Args:
            embed_dim: Shared embedding dimension of both towers.
            vision_cfg: Vision tower configuration (object or dict).
            text_cfg: Text tower configuration (object or dict).
            quick_gelu: Forwarded to the tower builders.
            cast_dtype: Forwarded to the tower builders.
            output_dict: When true, ``forward`` returns a dict instead of a tuple.
        """
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        # Mirror a couple of text-tower attributes on the top-level module.
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        # Stored in log space; forward() exponentiates it.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate to the vision tower's ``lock`` (LiT - https://arxiv.org/abs/2111.07991)."""
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Delegate to the text tower's ``lock``."""
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        """Run the vision tower; optionally L2-normalise along the last dim."""
        image_features = self.visual(image)
        if normalize:
            return F.normalize(image_features, dim=-1)
        return image_features

    def encode_text(self, text, normalize: bool = False):
        """Run the text tower; optionally L2-normalise along the last dim."""
        text_features = self.text(text)
        if normalize:
            return F.normalize(text_features, dim=-1)
        return text_features

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode whichever inputs are given and return normalised features.

        Returns:
            ``(image_features, text_features, logit_scale)`` or, when
            ``output_dict`` is set, a dict with those three keys. A missing
            input yields ``None`` for its features.
        """
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None
        scale = self.logit_scale.exp()
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": scale
            }
        return image_features, text_features, scale
312
+
313
+
314
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16).

    The conversion is applied in place to every sub-module via
    ``model.apply``. Affected tensors are:
      * weight and bias of ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Linear``,
      * the projection weights and biases of attention modules,
      * ``text_projection`` on ``CLIP`` / ``TextTransformer``,
      * ``proj`` on ``VisionTransformer``.

    Args:
        model: Module whose parameters are converted in place.
        dtype: Target dtype, ``torch.float16`` by default.
    """
    # Not every attention implementation defines all of these attributes
    # (e.g. separate q/k/v projection weights), so they are looked up with a
    # default instead of assuming they exist.
    attention_attrs = (
        "in_proj_weight",
        "q_proj_weight",
        "k_proj_weight",
        "v_proj_weight",
        "in_proj_bias",
        "bias_k",
        "bias_v",
    )

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in attention_attrs:
                # A missing attribute is treated like an unset (None) one;
                # a bare getattr would raise AttributeError here.
                tensor = getattr(l, attr, None)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        if isinstance(l, (CLIP, TextTransformer)):
            # convert text nn.Parameter projections
            attr = getattr(l, "text_projection", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

        if isinstance(l, VisionTransformer):
            # convert vision nn.Parameter projections
            attr = getattr(l, "proj", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
345
+
346
+
347
+ # used to maintain checkpoint compatibility
348
def convert_to_custom_text_state_dict(state_dict: dict):
    """Re-key an old-format state dict so text-tower entries sit under ``text.``.

    A state dict is considered old-format when it has a top-level
    ``'text_projection'`` key. In that case a new dict is returned in which
    every key starting with one of the text-tower prefixes gets ``'text.'``
    prepended; all other keys are copied unchanged. Otherwise the input dict
    is returned as is (same object).

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        The converted dict, or ``state_dict`` itself when no conversion applies.
    """
    if 'text_projection' not in state_dict:
        return state_dict

    # old format state_dict, move text tower -> .text
    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    return {
        ('text.' + name if name.startswith(text_prefixes) else name): value
        for name, value in state_dict.items()
    }
364
+
365
+
366
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Reconstruct a ``CLIP`` model from an OpenAI-format state dict.

    The architecture hyper-parameters are inferred from tensor shapes and key
    names in ``state_dict``; the presence of ``"visual.proj"`` selects the
    vision-transformer layout, otherwise the ResNet layout is assumed.

    Note: the keys ``"input_resolution"``, ``"context_length"`` and
    ``"vocab_size"`` are removed from the passed-in ``state_dict``.

    Args:
        state_dict: OpenAI CLIP state dict (modified in place, see note).
        quick_gelu: Forwarded to ``CLIP``.
        cast_dtype: Forwarded to ``CLIP``.

    Returns:
        The model with weights loaded, in eval mode.
    """
    if "visual.proj" in state_dict:
        # Vision transformer layout.
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([
            name for name in state_dict.keys()
            if name.startswith("visual.") and name.endswith(".attn.in_proj_weight")
        ])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # One position is subtracted before taking the square root of the grid.
        num_positions = state_dict["visual.positional_embedding"].shape[0]
        grid_size = round((num_positions - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        # ResNet layout: count distinct block indices in each of the 4 stages.
        stage_depths: list = [
            len(set(name.split(".")[2] for name in state_dict if name.startswith(f"visual.layer{stage}")))
            for stage in [1, 2, 3, 4]
        ]
        vision_layers = tuple(stage_depths)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(
        name.split(".")[2] for name in state_dict if name.startswith("transformer.resblocks")
    ))

    model = CLIP(
        embed_dim,
        vision_cfg=CLIPVisionCfg(
            layers=vision_layers,
            width=vision_width,
            patch_size=vision_patch_size,
            image_size=image_size,
        ),
        text_cfg=CLIPTextCfg(
            context_length=context_length,
            vocab_size=vocab_size,
            width=transformer_width,
            heads=transformer_heads,
            layers=transformer_layers,
        ),
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    # Drop these entries before load_state_dict is called.
    for meta_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(meta_key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
424
+
425
+
426
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """Trace ``forward``, ``encode_text`` and ``encode_image`` with dummy inputs.

    Args:
        model: Module exposing ``visual.image_size``, ``context_length`` and
            the three traced methods.
        batch_size: Batch dimension of the example inputs.
        device: Device on which the example inputs are created.

    Returns:
        The traced module, with ``visual.image_size`` set to the value read
        from the original model.
    """
    model.eval()
    image_size = model.visual.image_size

    image_shape = (batch_size, 3, image_size, image_size)
    text_shape = (batch_size, model.context_length)
    example_images = torch.ones(image_shape, device=device)
    example_text = torch.zeros(text_shape, dtype=torch.int, device=device)

    trace_inputs = {
        'forward': (example_images, example_text),
        'encode_text': (example_text,),
        'encode_image': (example_images,),
    }
    traced = torch.jit.trace_module(model, inputs=trace_inputs)

    # Re-attach the image size on the traced vision tower
    # (presumably not carried over by tracing -- TODO confirm).
    traced.visual.image_size = image_size
    return traced
440
+
441
+
442
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    """Rescale the grid of position embeddings when loading from state_dict.

    Reads ``state_dict['visual.positional_embedding']`` and, when its length
    does not match the model's ``visual.grid_size`` (plus one extra token),
    interpolates the grid part to the model's grid and writes the result back
    into ``state_dict`` in place. Does nothing when the key is absent, when
    the vision tower has no ``grid_size``, or when the lengths already match.

    Args:
        state_dict: Checkpoint mapping, updated in place.
        model: Model whose ``visual.grid_size`` defines the target grid.
        interpolation: Mode passed to ``F.interpolate``.
        antialias: Antialias flag passed to ``F.interpolate``.
    """
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return

    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    target_len = grid_size[0] * grid_size[1] + extra_tokens
    if target_len == old_pos_embed.shape[0]:
        return

    # Separate the leading extra token(s) from the per-grid-cell embeddings.
    if extra_tokens:
        pos_emb_tok = old_pos_embed[:extra_tokens]
        pos_emb_img = old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok = None
        pos_emb_img = old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)

    # (cells, dim) -> (1, dim, rows, cols) so the grid can be interpolated as an image.
    as_image = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    as_image = F.interpolate(
        as_image,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    # Back to (cells, dim) for the new grid.
    pos_emb_img = as_image.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]

    new_pos_embed = pos_emb_img if pos_emb_tok is None else torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    state_dict['visual.positional_embedding'] = new_pos_embed
ext/open_clip/model_configs/EVA01-g-14-plus.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva_giant_patch14_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA01-g-14.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva_giant_patch14_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 768,
14
+ "heads": 12,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA02-B-16.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_base_patch16_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA02-E-14-plus.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_enormous_patch14_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1280,
14
+ "heads": 20,
15
+ "layers": 32
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA02-E-14.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_enormous_patch14_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA02-L-14-336.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "timm_model_name": "eva02_large_patch14_clip_336",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 768,
14
+ "heads": 12,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA02-L-14.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_large_patch14_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 768,
14
+ "heads": 12,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 23,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
ext/open_clip/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 23,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
ext/open_clip/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 6,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
ext/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 6,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
ext/open_clip/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 384,
5
+ "layers": [
6
+ 6,
7
+ 8,
8
+ 18,
9
+ 8
10
+ ],
11
+ "width": 96,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 768,
18
+ "heads": 12,
19
+ "layers": 12
20
+ }
21
+ }
ext/open_clip/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 288,
5
+ "layers": [
6
+ 4,
7
+ 6,
8
+ 10,
9
+ 6
10
+ ],
11
+ "width": 80,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 640,
18
+ "heads": 10,
19
+ "layers": 12
20
+ }
21
+ }