jwyang committed
Commit 4121bec
1 parent: 520e34b

first commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. app.py +123 -0
  2. config.py +245 -0
  3. configs/Base-RCNN-C4.yaml +18 -0
  4. configs/Base-RCNN-FPN.yaml +42 -0
  5. configs/CLIP_fast_rcnn_R_50_C4.yaml +71 -0
  6. configs/CLIP_fast_rcnn_swin_base_C4.yaml +74 -0
  7. configs/mask_rcnn_CLIP_R_50_C4_1x.yaml +55 -0
  8. configs/mask_rcnn_R_50_C4_1x.yaml +23 -0
  9. configs/mask_rcnn_R_50_FPN_1x.yaml +23 -0
  10. datasets/README.md +140 -0
  11. datasets/custom_images/dog_and_cat.jfif +0 -0
  12. datasets/prepare_ade20k_sem_seg.py +26 -0
  13. datasets/prepare_cocofied_lvis.py +176 -0
  14. datasets/prepare_for_tests.sh +22 -0
  15. datasets/prepare_panoptic_fpn.py +116 -0
  16. detectron2/__init__.py +10 -0
  17. detectron2/__pycache__/__init__.cpython-39.pyc +0 -0
  18. detectron2/checkpoint/__init__.py +10 -0
  19. detectron2/checkpoint/__pycache__/__init__.cpython-39.pyc +0 -0
  20. detectron2/checkpoint/__pycache__/c2_model_loading.cpython-39.pyc +0 -0
  21. detectron2/checkpoint/__pycache__/catalog.cpython-39.pyc +0 -0
  22. detectron2/checkpoint/__pycache__/clip_model_loading.cpython-39.pyc +0 -0
  23. detectron2/checkpoint/__pycache__/detection_checkpoint.cpython-39.pyc +0 -0
  24. detectron2/checkpoint/c2_model_loading.py +407 -0
  25. detectron2/checkpoint/catalog.py +115 -0
  26. detectron2/checkpoint/clip_model_loading.py +415 -0
  27. detectron2/checkpoint/detection_checkpoint.py +134 -0
  28. detectron2/config/__init__.py +24 -0
  29. detectron2/config/__pycache__/__init__.cpython-39.pyc +0 -0
  30. detectron2/config/__pycache__/compat.cpython-39.pyc +0 -0
  31. detectron2/config/__pycache__/config.cpython-39.pyc +0 -0
  32. detectron2/config/__pycache__/defaults.cpython-39.pyc +0 -0
  33. detectron2/config/__pycache__/instantiate.cpython-39.pyc +0 -0
  34. detectron2/config/__pycache__/lazy.cpython-39.pyc +0 -0
  35. detectron2/config/compat.py +229 -0
  36. detectron2/config/config.py +249 -0
  37. detectron2/config/defaults.py +786 -0
  38. detectron2/config/instantiate.py +82 -0
  39. detectron2/config/lazy.py +370 -0
  40. detectron2/data/__init__.py +19 -0
  41. detectron2/data/__pycache__/__init__.cpython-39.pyc +0 -0
  42. detectron2/data/__pycache__/build.cpython-39.pyc +0 -0
  43. detectron2/data/__pycache__/catalog.cpython-39.pyc +0 -0
  44. detectron2/data/__pycache__/clip_build.cpython-39.pyc +0 -0
  45. detectron2/data/__pycache__/common.cpython-39.pyc +0 -0
  46. detectron2/data/__pycache__/dataset_mapper.cpython-39.pyc +0 -0
  47. detectron2/data/__pycache__/detection_utils.cpython-39.pyc +0 -0
  48. detectron2/data/build.py +536 -0
  49. detectron2/data/catalog.py +236 -0
  50. detectron2/data/clip_build.py +158 -0
app.py ADDED
@@ -0,0 +1,123 @@
+ import argparse
+ import requests
+ import logging
+ import os
+ import gradio as gr
+ import numpy as np
+ import cv2
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+ from torchvision import transforms
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from timm.data import create_transform
+ from config import get_config
+
+ from collections import OrderedDict
+
+ import detectron2.utils.comm as comm
+ from detectron2.checkpoint import DetectionCheckpointer
+ from detectron2.config import get_cfg
+ from detectron2.data import MetadataCatalog
+ from detectron2.engine import DefaultTrainer as Trainer
+ from detectron2.engine import default_argument_parser, default_setup, hooks, launch
+ from detectron2.evaluation import (
+     CityscapesInstanceEvaluator,
+     CityscapesSemSegEvaluator,
+     COCOEvaluator,
+     COCOPanopticEvaluator,
+     DatasetEvaluators,
+     LVISEvaluator,
+     PascalVOCDetectionEvaluator,
+     SemSegEvaluator,
+     verify_results,
+     FLICKR30KEvaluator,
+ )
+ from detectron2.modeling import GeneralizedRCNNWithTTA
+
+ def parse_option():
+     parser = argparse.ArgumentParser('RegionCLIP demo script', add_help=False)
+     parser.add_argument('--config-file', type=str, default="configs/CLIP_fast_rcnn_R_50_C4.yaml", metavar="FILE", help='path to config file', )
+     args, unparsed = parser.parse_known_args()
+
+     return args
+
+ def build_transforms(img_size, center_crop=True):
+     t = []
+     if center_crop:
+         size = int((256 / 224) * img_size)
+         t.append(
+             transforms.Resize(size)
+         )
+         t.append(
+             transforms.CenterCrop(img_size)
+         )
+     else:
+         t.append(
+             transforms.Resize(img_size)
+         )
+     t.append(transforms.ToTensor())
+     return transforms.Compose(t)
+
+ def setup(args):
+     """
+     Create configs and perform basic setups.
+     """
+     cfg = get_cfg()
+     cfg.merge_from_file(args.config_file)
+     cfg.freeze()
+     default_setup(cfg, args)
+     return cfg
+
+ '''
+ build model
+ '''
+ args = parse_option()
+ cfg = setup(args)
+
+ model = Trainer.build_model(cfg)
+ DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
+     cfg.MODEL.WEIGHTS, resume=False
+ )
+ if cfg.MODEL.META_ARCHITECTURE in ['CLIPRCNN', 'CLIPFastRCNN', 'PretrainFastRCNN'] \
+     and cfg.MODEL.CLIP.BB_RPN_WEIGHTS is not None \
+     and cfg.MODEL.CLIP.CROP_REGION_TYPE == 'RPN':  # load 2nd pretrained model
+     DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, bb_rpn_weights=True).resume_or_load(
+         cfg.MODEL.CLIP.BB_RPN_WEIGHTS, resume=False
+     )
+
+ '''
+ build data transform
+ '''
+ eval_transforms = build_transforms(800, center_crop=False)
+ # display_transforms = build_transforms4display(960, center_crop=False)
+
+ def localize_object(image, texts):
+     print(texts)
+     img_t = eval_transforms(Image.fromarray(image).convert("RGB")) * 255
+
+     print(img_t.shape)
+     model.eval()
+     with torch.no_grad():
+         print(img_t[0][:10, :10])
+         res = model(texts, [{"image": img_t}])
+
+     return res
+
+
+ image = gr.inputs.Image()
+
+ gr.Interface(
+     description="RegionCLIP for Open-Vocabulary Object Detection",
+     fn=localize_object,
+     inputs=["image", "text"],
+     outputs=[
+         gr.outputs.Image(
+             type="pil",
+             label="grounding results"),
+     ],
+     examples=[
+         ["./elephants.png", "an elephant"],
+         ["./apple_with_ipod.jpg", "an apple"],
+     ],
+ ).launch()
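
Note (editor's sketch, not part of the committed app.py): `localize_object` returns the raw model output, while the Gradio interface above declares a PIL image output, so a rendering step is implied. Below is a minimal sketch of that step, assuming the model follows the standard detectron2 convention of returning one dict per image with an `"instances"` field; the helper name `draw_predictions` is hypothetical.

```python
# Editor's sketch (assumption: outputs follow detectron2's standard format).
from detectron2.utils.visualizer import Visualizer

def draw_predictions(image_rgb, outputs):
    # image_rgb: HxWx3 uint8 RGB array; outputs: return value of model(texts, [{"image": img_t}])
    vis = Visualizer(image_rgb)
    instances = outputs[0]["instances"].to("cpu")
    return vis.draw_instance_predictions(instances).get_image()  # rendered HxWx3 array
```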
config.py ADDED
@@ -0,0 +1,245 @@
+ # --------------------------------------------------------
+ # Unified Contrastive Learning (UniCL)
+ # Copyright (c) 2022 Microsoft
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Jianwei Yang (jianwyan@microsoft.com)
+ # Based on Swin Transformer written by Zhe Liu
+ # --------------------------------------------------------
+
+ import os
+ import yaml
+ from yacs.config import CfgNode as CN
+
+ _C = CN()
+ _C.VERBOSE = False
+
+ # Base config files
+ _C.BASE = ['']
+
+ # -----------------------------------------------------------------------------
+ # Data settings
+ # -----------------------------------------------------------------------------
+ _C.DATA = CN()
+ # Batch size for a single GPU, could be overwritten by command line argument
+ _C.DATA.BATCH_SIZE = 128
+ # Path to dataset, could be overwritten by command line argument
+ _C.DATA.DATA_PATH = ''
+ # Dataset name
+ _C.DATA.DATASET = 'imagenet'
+ # Input image size
+ _C.DATA.IMG_SIZE = 224
+ # Interpolation to resize image (random, bilinear, bicubic)
+ _C.DATA.INTERPOLATION = 'bicubic'
+ # Use zipped dataset instead of folder dataset
+ # could be overwritten by command line argument
+ _C.DATA.ZIP_MODE = False
+ # Cache Data in Memory, could be overwritten by command line argument
+ _C.DATA.CACHE_MODE = 'part'
+ # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
+ _C.DATA.PIN_MEMORY = True
+ # Number of data loading threads
+ _C.DATA.NUM_WORKERS = 8
+
+ # -----------------------------------------------------------------------------
+ # Model settings
+ # -----------------------------------------------------------------------------
+ _C.MODEL = CN()
+ # Model name
+ _C.MODEL.NAME = ''
+ # Checkpoint to resume, could be overwritten by command line argument
+ _C.MODEL.RESUME = ''
+ # Number of classes, overwritten in data preparation
+ _C.MODEL.NUM_CLASSES = 0
+ # Label Smoothing
+ _C.MODEL.LABEL_SMOOTHING = 0.1
+ # Whether to load a pretrained model
+ _C.MODEL.PRETRAINED = ''
+ # Projection dimension
+ _C.MODEL.DIM_PROJECTION = 512
+ # Mode specific
+ _C.MODEL.SPEC = CN(new_allowed=True)
+ # -----------------------------------------------------------------------------
+ # Build Image Encoder
+ # -----------------------------------------------------------------------------
+ _C.MODEL.IMAGE_ENCODER = CN()
+ # Image encoder type
+ _C.MODEL.IMAGE_ENCODER.TYPE = 'swin'
+ # Input image size
+ _C.MODEL.IMAGE_ENCODER.IMG_SIZE = 224
+ # Dropout rate
+ _C.MODEL.IMAGE_ENCODER.DROP_RATE = 0.0
+ # Drop path rate
+ _C.MODEL.IMAGE_ENCODER.DROP_PATH_RATE = 0.1
+
+ # Swin Transformer parameters
+ _C.MODEL.IMAGE_ENCODER.SWIN = CN()
+ _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_SIZE = 4
+ _C.MODEL.IMAGE_ENCODER.SWIN.IN_CHANS = 3
+ _C.MODEL.IMAGE_ENCODER.SWIN.EMBED_DIM = 96
+ _C.MODEL.IMAGE_ENCODER.SWIN.DEPTHS = [2, 2, 6, 2]
+ _C.MODEL.IMAGE_ENCODER.SWIN.NUM_HEADS = [3, 6, 12, 24]
+ _C.MODEL.IMAGE_ENCODER.SWIN.WINDOW_SIZE = 7
+ _C.MODEL.IMAGE_ENCODER.SWIN.MLP_RATIO = 4.
+ _C.MODEL.IMAGE_ENCODER.SWIN.QKV_BIAS = True
+ _C.MODEL.IMAGE_ENCODER.SWIN.QK_SCALE = None
+ _C.MODEL.IMAGE_ENCODER.SWIN.APE = False
+ _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_NORM = True
+
+ # FocalNet parameters
+ _C.MODEL.IMAGE_ENCODER.FOCAL = CN()
+ _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_SIZE = 4
+ _C.MODEL.IMAGE_ENCODER.FOCAL.IN_CHANS = 3
+ _C.MODEL.IMAGE_ENCODER.FOCAL.EMBED_DIM = 96
+ _C.MODEL.IMAGE_ENCODER.FOCAL.DEPTHS = [2, 2, 6, 2]
+ _C.MODEL.IMAGE_ENCODER.FOCAL.MLP_RATIO = 4.
+ _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_NORM = True
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_LEVELS = [2, 2, 2, 2]
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_WINDOWS = [3, 3, 3, 3]
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_FACTORS = [2, 2, 2, 2]
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_CONV_EMBED = False
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_LAYERSCALE = False
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_POSTLN = False
+
+ # -----------------------------------------------------------------------------
+ # Build Text Encoder
+ # -----------------------------------------------------------------------------
+ _C.MODEL.TEXT_ENCODER = CN()
+
+ _C.MODEL.TEXT_ENCODER.NAME = 'transformer'
+ _C.MODEL.TEXT_ENCODER.LOAD_PRETRAINED = False
+ _C.MODEL.TEXT_ENCODER.PRETRAINED = ''
+ _C.MODEL.TEXT_ENCODER.TOKENIZER = 'clip'
+ _C.MODEL.TEXT_ENCODER.CONTEXT_LENGTH = 77
+ _C.MODEL.TEXT_ENCODER.WIDTH = 1024
+ _C.MODEL.TEXT_ENCODER.HEADS = 16
+ _C.MODEL.TEXT_ENCODER.LAYERS = 12
+ _C.MODEL.TEXT_ENCODER.AUTOGRESSIVE = True
+
+ # -----------------------------------------------------------------------------
+ # Training settings
+ # -----------------------------------------------------------------------------
+ _C.TRAIN = CN()
+ _C.TRAIN.START_EPOCH = 0
+ _C.TRAIN.EPOCHS = 32
+ _C.TRAIN.WARMUP_EPOCHS = 5
+ _C.TRAIN.WEIGHT_DECAY = 0.1
+ _C.TRAIN.BASE_LR = 5e-4
+ _C.TRAIN.WARMUP_LR = 5e-7
+ _C.TRAIN.MIN_LR = 5e-6
+ # Clip gradient norm
+ _C.TRAIN.CLIP_GRAD = 5.0
+ # Auto resume from latest checkpoint
+ _C.TRAIN.AUTO_RESUME = True
+ # Gradient accumulation steps
+ # could be overwritten by command line argument
+ _C.TRAIN.ACCUMULATION_STEPS = 0
+ # Whether to use gradient checkpointing to save memory
+ # could be overwritten by command line argument
+ _C.TRAIN.USE_CHECKPOINT = False
+
+ # LR scheduler
+ _C.TRAIN.LR_SCHEDULER = CN()
+ _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
+ # Epoch interval to decay LR, used in StepLRScheduler
+ _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
+ # LR decay rate, used in StepLRScheduler
+ _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
+
+ # Optimizer
+ _C.TRAIN.OPTIMIZER = CN()
+ _C.TRAIN.OPTIMIZER.NAME = 'adamw'
+ # Optimizer Epsilon
+ _C.TRAIN.OPTIMIZER.EPS = 1e-8
+ # Optimizer Betas
+ _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
+ # SGD momentum
+ _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
+
+ # -----------------------------------------------------------------------------
+ # Augmentation settings
+ # -----------------------------------------------------------------------------
+ _C.AUG = CN()
+ # Color jitter factor
+ _C.AUG.COLOR_JITTER = 0.4
+ # Use AutoAugment policy. "v0" or "original"
+ _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
+ # Random erase prob
+ _C.AUG.REPROB = 0.25
+ # Random erase mode
+ _C.AUG.REMODE = 'pixel'
+ # Random erase count
+ _C.AUG.RECOUNT = 1
+ # Mixup alpha, mixup enabled if > 0
+ _C.AUG.MIXUP = 0.8
+ # Cutmix alpha, cutmix enabled if > 0
+ _C.AUG.CUTMIX = 1.0
+ # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+ _C.AUG.CUTMIX_MINMAX = None
+ # Probability of performing mixup or cutmix when either/both is enabled
+ _C.AUG.MIXUP_PROB = 1.0
+ # Probability of switching to cutmix when both mixup and cutmix enabled
+ _C.AUG.MIXUP_SWITCH_PROB = 0.5
+ # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+ _C.AUG.MIXUP_MODE = 'batch'
+
+ # -----------------------------------------------------------------------------
+ # Testing settings
+ # -----------------------------------------------------------------------------
+ _C.TEST = CN()
+ # Whether to use center crop when testing
+ _C.TEST.CROP = True
+
+ # -----------------------------------------------------------------------------
+ # Misc
+ # -----------------------------------------------------------------------------
+ # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
+ # overwritten by command line argument
+ _C.AMP_OPT_LEVEL = ''
+ # Path to output folder, overwritten by command line argument
+ _C.OUTPUT = ''
+ # Tag of experiment, overwritten by command line argument
+ _C.TAG = 'default'
+ # Frequency to save checkpoint
+ _C.SAVE_FREQ = 1
+ # Frequency to log info
+ _C.PRINT_FREQ = 100
+ # Fixed random seed
+ _C.SEED = 0
+ # Perform evaluation only, overwritten by command line argument
+ _C.EVAL_MODE = False
+ # Test throughput only, overwritten by command line argument
+ _C.THROUGHPUT_MODE = False
+ # Debug mode: skip dataloader initialization, overwritten by command line argument
+ _C.DEBUG_MODE = False
+ # local rank for DistributedDataParallel, given by command line argument
+ _C.LOCAL_RANK = 0
+
+
+ def _update_config_from_file(config, cfg_file):
+     config.defrost()
+     with open(cfg_file, 'r') as f:
+         yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
+
+     for cfg in yaml_cfg.setdefault('BASE', ['']):
+         if cfg:
+             _update_config_from_file(
+                 config, os.path.join(os.path.dirname(cfg_file), cfg)
+             )
+     print('=> merge config from {}'.format(cfg_file))
+     config.merge_from_file(cfg_file)
+     config.freeze()
+
+
+ def update_config(config, args):
+     _update_config_from_file(config, args.cfg)
+     config.freeze()
+
+
+ def get_config(args):
+     """Get a yacs CfgNode object with default values."""
+     # Return a clone so that the defaults will not be altered
+     # This is for the "local variable" use pattern
+     config = _C.clone()
+     update_config(config, args)
+
+     return config
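
Note (editor's sketch, not part of the committed config.py): `get_config` expects an object with a `.cfg` attribute pointing at a YAML file; nested `BASE` files are resolved relative to that file. The config path below is hypothetical.

```python
# Editor's sketch: minimal use of get_config; the YAML path is hypothetical.
from types import SimpleNamespace

from config import get_config

args = SimpleNamespace(cfg="configs/example_unicl.yaml")  # hypothetical config file
cfg = get_config(args)  # clones the defaults above, merges the file, and freezes the result
print(cfg.MODEL.IMAGE_ENCODER.TYPE, cfg.DATA.IMG_SIZE)
```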
configs/Base-RCNN-C4.yaml ADDED
@@ -0,0 +1,18 @@
+ MODEL:
+   META_ARCHITECTURE: "GeneralizedRCNN"
+   RPN:
+     PRE_NMS_TOPK_TEST: 6000
+     POST_NMS_TOPK_TEST: 1000
+   ROI_HEADS:
+     NAME: "Res5ROIHeads"
+ DATASETS:
+   TRAIN: ("coco_2017_train",)
+   TEST: ("coco_2017_val",)
+ SOLVER:
+   IMS_PER_BATCH: 16
+   BASE_LR: 0.02
+   STEPS: (60000, 80000)
+   MAX_ITER: 90000
+ INPUT:
+   MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+ VERSION: 2
configs/Base-RCNN-FPN.yaml ADDED
@@ -0,0 +1,42 @@
+ MODEL:
+   META_ARCHITECTURE: "GeneralizedRCNN"
+   BACKBONE:
+     NAME: "build_resnet_fpn_backbone"
+   RESNETS:
+     OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+   FPN:
+     IN_FEATURES: ["res2", "res3", "res4", "res5"]
+   ANCHOR_GENERATOR:
+     SIZES: [[32], [64], [128], [256], [512]]  # One size for each in feature map
+     ASPECT_RATIOS: [[0.5, 1.0, 2.0]]  # Three aspect ratios (same for all in feature maps)
+   RPN:
+     IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
+     PRE_NMS_TOPK_TRAIN: 2000  # Per FPN level
+     PRE_NMS_TOPK_TEST: 1000  # Per FPN level
+     # Detectron1 uses 2000 proposals per-batch,
+     # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
+     # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
+     POST_NMS_TOPK_TRAIN: 1000
+     POST_NMS_TOPK_TEST: 1000
+   ROI_HEADS:
+     NAME: "StandardROIHeads"
+     IN_FEATURES: ["p2", "p3", "p4", "p5"]
+   ROI_BOX_HEAD:
+     NAME: "FastRCNNConvFCHead"
+     NUM_FC: 2
+     POOLER_RESOLUTION: 7
+   ROI_MASK_HEAD:
+     NAME: "MaskRCNNConvUpsampleHead"
+     NUM_CONV: 4
+     POOLER_RESOLUTION: 14
+ DATASETS:
+   TRAIN: ("coco_2017_train",)
+   TEST: ("coco_2017_val",)
+ SOLVER:
+   IMS_PER_BATCH: 16
+   BASE_LR: 0.02
+   STEPS: (60000, 80000)
+   MAX_ITER: 90000
+ INPUT:
+   MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+ VERSION: 2
configs/CLIP_fast_rcnn_R_50_C4.yaml ADDED
@@ -0,0 +1,71 @@
+ _BASE_: "./Base-RCNN-C4.yaml"
+ MODEL:
+   META_ARCHITECTURE: "CLIPFastRCNN"  # "CLIPRCNN" # "GeneralizedRCNN"
+   MASK_ON: False
+   WEIGHTS: "./model_final.pth"
+   BACKBONE:
+     NAME: "build_clip_resnet_backbone"  # "build_resnet_fpn_backbone"
+     FREEZE_AT: 2
+   TEXT_BACKBONE:
+     NAME: "build_clip_language_encoder"
+   CLIP:
+     CROP_REGION_TYPE: "RPN"
+     OFFLINE_RPN_CONFIG: "./configs/mask_rcnn_R_50_FPN_1x.yaml"
+     USE_TEXT_EMB_CLASSIFIER: True
+     TEXT_EMB_PATH: "./lvis_1203_cls_emb_notnorm_rn50x4.pth"
+     NO_BOX_DELTA: True
+     OFFLINE_RPN_NMS_THRESH: 0.7
+     CLSS_TEMP: 0.01
+     MULTIPLY_RPN_SCORE: True
+     TEXT_EMB_DIM: 640
+   RESNETS:
+     DEPTH: 200
+     OUT_FEATURES: ["res4"]
+     NORM: FrozenBN
+     STEM_OUT_CHANNELS: 64
+     RES2_OUT_CHANNELS: 256
+   RPN:
+     HEAD_NAME: StandardRPNHead
+     IN_FEATURES: ["res4"]
+     POST_NMS_TOPK_TEST: 1000
+     NMS_THRESH:
+   ROI_HEADS:
+     NAME: "CLIPRes5ROIHeads"  # "Res5ROIHeads" # "StandardROIHeads"
+     IN_FEATURES: ["res4"]
+     NUM_CLASSES: 1203
+     NMS_THRESH_TEST: 0.3
+     SCORE_THRESH_TEST: 0.0
+   ROI_BOX_HEAD:
+     NAME: ""
+     NUM_FC: 0
+     CLS_AGNOSTIC_BBOX_REG: True
+     POOLER_RESOLUTION: 18
+   ROI_MASK_HEAD:
+     NAME: "MaskRCNNConvUpsampleHead"
+     NUM_CONV: 0
+     POOLER_RESOLUTION: 14
+   PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
+   PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
+ INPUT:
+   MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+ DATASETS:
+   TRAIN: ("lvis_v1_train",)
+   TEST: ("lvis_v1_val",)
+ TEST:
+   DETECTIONS_PER_IMAGE: 300  # LVIS allows up to 300
+   EVAL_PERIOD: 25000
+ SOLVER:
+   IMS_PER_BATCH: 16
+   BASE_LR: 0.02
+   STEPS: (120000, 160000)
+   MAX_ITER: 180000  # 180000 * 16 / 100000 ~ 28.8 epochs
+ DATALOADER:
+   SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
+   REPEAT_THRESHOLD: 0.001
+ INPUT:
+   MIN_SIZE_TRAIN_SAMPLING: choice
+   MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+   MAX_SIZE_TRAIN: 1333
+   MIN_SIZE_TEST: 800
+   MAX_SIZE_TEST: 1333
+   FORMAT: "RGB"
configs/CLIP_fast_rcnn_swin_base_C4.yaml ADDED
@@ -0,0 +1,74 @@
+ _BASE_: "./Base-RCNN-C4.yaml"
+ MODEL:
+   META_ARCHITECTURE: "CLIPFastRCNN"  # "CLIPRCNN" # "GeneralizedRCNN"
+   BACKBONE:
+     NAME: "build_clip_swin"  # "build_resnet_fpn_backbone"
+     FREEZE_AT: 2
+   TEXT_BACKBONE:
+     NAME: "build_clip_swin_text_backbone"
+   SPEC:
+     EMBED_DIM: 512
+     VISION:
+       PATCH_SIZE: 4
+       IN_CHANS: 3
+       EMBED_DIM: 128
+       DEPTHS: [ 2, 2, 18, 2 ]
+       NUM_HEADS: [ 4, 8, 16, 32 ]
+       WINDOW_SIZE: 7
+       MLP_RATIO: 4.
+       QKV_BIAS: True
+       APE: False
+       PATCH_NORM: True
+       DROP_RATE: 0.0
+       DROP_PATH_RATE: 0.2
+       OUT_FEATURES: ["stage2", "stage3", "stage4", "stage5"]
+     TEXT:
+       NAME: 'transformer'
+       TOKENIZER: clip
+       CONTEXT_LENGTH: 77
+       WIDTH: 512
+       HEADS: 8
+       LAYERS: 12
+   WEIGHTS: ""  # "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+   MASK_ON: True
+   RPN:
+     HEAD_NAME: StandardRPNHead
+     IN_FEATURES: ["stage4"]
+   ROI_HEADS:
+     NAME: "CLIPSwinROIHeads"  # "Res5ROIHeads" # "StandardROIHeads"
+     IN_FEATURES: ["stage4"]
+     NUM_CLASSES: 1203
+     SCORE_THRESH_TEST: 0.0001
+   ROI_BOX_HEAD:
+     NAME: ""
+     NUM_FC: 0
+     POOLER_RESOLUTION: 14
+   ROI_MASK_HEAD:
+     NAME: "MaskRCNNConvUpsampleHead"
+     NUM_CONV: 0
+     POOLER_RESOLUTION: 14
+   PIXEL_MEAN: [0.485, 0.456, 0.406]
+   PIXEL_STD: [0.229, 0.224, 0.225]
+ INPUT:
+   MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+ DATASETS:
+   TRAIN: ("lvis_v1_train",)
+   TEST: ("lvis_v1_val",)
+ TEST:
+   DETECTIONS_PER_IMAGE: 300  # LVIS allows up to 300
+   EVAL_PERIOD: 25000
+ SOLVER:
+   IMS_PER_BATCH: 16
+   BASE_LR: 0.02
+   STEPS: (120000, 160000)
+   MAX_ITER: 180000  # 180000 * 16 / 100000 ~ 28.8 epochs
+ DATALOADER:
+   SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
+   REPEAT_THRESHOLD: 0.001
+ INPUT:
+   MIN_SIZE_TRAIN_SAMPLING: choice
+   MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+   MAX_SIZE_TRAIN: 1333
+   MIN_SIZE_TEST: 800
+   MAX_SIZE_TEST: 1333
+   FORMAT: "RGB"
configs/mask_rcnn_CLIP_R_50_C4_1x.yaml ADDED
@@ -0,0 +1,55 @@
+ _BASE_: "./Base-RCNN-C4.yaml"
+ MODEL:
+   META_ARCHITECTURE: "GeneralizedRCNN"
+   BACKBONE:
+     NAME: "build_clip_resnet_backbone"  # "build_clip_resnet_fpn_backbone" # "build_resnet_fpn_backbone"
+     FREEZE_AT: 2
+   WEIGHTS: ""  # "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+   MASK_ON: True
+   RESNETS:
+     DEPTH: 50
+     OUT_FEATURES: ["res4"]
+     NORM: FrozenBN
+     STEM_OUT_CHANNELS: 64
+     RES2_OUT_CHANNELS: 256
+   RPN:
+     HEAD_NAME: StandardRPNHead
+     IN_FEATURES: ["res4"]
+   ROI_HEADS:
+     NAME: "CLIPRes5ROIHeads"  # "Res5ROIHeads" # "StandardROIHeads"
+     IN_FEATURES: ["res4"]
+     NUM_CLASSES: 1203
+     SCORE_THRESH_TEST: 0.0001
+   ROI_BOX_HEAD:
+     NAME: ""
+     NUM_FC: 0
+     POOLER_RESOLUTION: 14
+   ROI_MASK_HEAD:
+     NAME: "MaskRCNNConvUpsampleHead"
+     NUM_CONV: 0
+     POOLER_RESOLUTION: 14
+   PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]  # [103.530, 116.280, 123.675]
+   PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]  # [1.0, 1.0, 1.0]
+ INPUT:
+   MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+ DATASETS:
+   TRAIN: ("lvis_v1_train",)
+   TEST: ("lvis_v1_val",)
+ TEST:
+   DETECTIONS_PER_IMAGE: 300  # LVIS allows up to 300
+   EVAL_PERIOD: 25000
+ SOLVER:
+   IMS_PER_BATCH: 16
+   BASE_LR: 0.02
+   STEPS: (120000, 160000)  # (140000,)
+   MAX_ITER: 180000  # 180000 * 16 / 100000 ~ 28.8 epochs
+ DATALOADER:
+   SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
+   REPEAT_THRESHOLD: 0.001
+ INPUT:
+   MIN_SIZE_TRAIN_SAMPLING: choice
+   MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+   MAX_SIZE_TRAIN: 1333
+   MIN_SIZE_TEST: 800
+   MAX_SIZE_TEST: 1333
+   FORMAT: "RGB"  # "BGR"
configs/mask_rcnn_R_50_C4_1x.yaml ADDED
@@ -0,0 +1,23 @@
+ _BASE_: "./Base-RCNN-C4.yaml"
+ MODEL:
+   WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+   MASK_ON: True
+   RESNETS:
+     DEPTH: 50
+   ROI_HEADS:
+     NUM_CLASSES: 1203
+     SCORE_THRESH_TEST: 0.0001
+ INPUT:
+   MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+ DATASETS:
+   TRAIN: ("lvis_v1_train",)
+   TEST: ("lvis_v1_val",)
+ TEST:
+   DETECTIONS_PER_IMAGE: 300  # LVIS allows up to 300
+   EVAL_PERIOD: 50000
+ SOLVER:
+   STEPS: (120000, 160000)
+   MAX_ITER: 180000  # 180000 * 16 / 100000 ~ 28.8 epochs
+ DATALOADER:
+   SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
+   REPEAT_THRESHOLD: 0.001
configs/mask_rcnn_R_50_FPN_1x.yaml ADDED
@@ -0,0 +1,23 @@
+ _BASE_: "./Base-RCNN-FPN.yaml"
+ MODEL:
+   WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+   MASK_ON: True
+   RESNETS:
+     DEPTH: 50
+   ROI_HEADS:
+     NUM_CLASSES: 1203
+     SCORE_THRESH_TEST: 0.0001
+ INPUT:
+   MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+ DATASETS:
+   TRAIN: ("lvis_v1_train",)
+   TEST: ("lvis_v1_val",)
+ TEST:
+   DETECTIONS_PER_IMAGE: 300  # LVIS allows up to 300
+   EVAL_PERIOD: 50000
+ SOLVER:
+   STEPS: (120000, 160000)
+   MAX_ITER: 180000  # 180000 * 16 / 100000 ~ 28.8 epochs
+ DATALOADER:
+   SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
+   REPEAT_THRESHOLD: 0.001
datasets/README.md ADDED
@@ -0,0 +1,140 @@
+ # Use Builtin Datasets
+
+ A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog)
+ for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc).
+ This document explains how to set up the builtin datasets so they can be used by the above APIs.
+ [Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`,
+ and how to add new datasets to them.
+
+ Detectron2 has builtin support for a few datasets.
+ The datasets are assumed to exist in a directory specified by the environment variable
+ `DETECTRON2_DATASETS`.
+ Under this directory, detectron2 will look for datasets in the structure described below, if needed.
+ ```
+ $DETECTRON2_DATASETS/
+   coco/
+   lvis/
+   cityscapes/
+   VOC20{07,12}/
+ ```
+
+ You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`.
+ If left unset, the default is `./datasets` relative to your current working directory.
+
+ The [model zoo](https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md)
+ contains configs and models that use these builtin datasets.
+
+ ## Expected dataset structure for [COCO instance/keypoint detection](https://cocodataset.org/#download):
+
+ ```
+ coco/
+   annotations/
+     instances_{train,val}2017.json
+     person_keypoints_{train,val}2017.json
+   {train,val}2017/
+     # image files that are mentioned in the corresponding json
+ ```
+
+ You can use the 2014 version of the dataset as well.
+
+ Some of the builtin tests (`dev/run_*_tests.sh`) use a tiny version of the COCO dataset,
+ which you can download with `./datasets/prepare_for_tests.sh`.
+
+ ## Expected dataset structure for PanopticFPN:
+
+ Extract panoptic annotations from the [COCO website](https://cocodataset.org/#download)
+ into the following structure:
+ ```
+ coco/
+   annotations/
+     panoptic_{train,val}2017.json
+     panoptic_{train,val}2017/  # png annotations
+     panoptic_stuff_{train,val}2017/  # generated by the script mentioned below
+ ```
+
+ Install panopticapi by:
+ ```
+ pip install git+https://github.com/cocodataset/panopticapi.git
+ ```
+ Then run `python datasets/prepare_panoptic_fpn.py` to extract semantic annotations from panoptic annotations.
+
+ ## Expected dataset structure for [LVIS instance segmentation](https://www.lvisdataset.org/dataset):
+ ```
+ coco/
+   {train,val,test}2017/
+ lvis/
+   lvis_v0.5_{train,val}.json
+   lvis_v0.5_image_info_test.json
+   lvis_v1_{train,val}.json
+   lvis_v1_image_info_test{,_challenge}.json
+ ```
+
+ Install lvis-api by:
+ ```
+ pip install git+https://github.com/lvis-dataset/lvis-api.git
+ ```
+
+ To evaluate models trained on the COCO dataset using LVIS annotations,
+ run `python datasets/prepare_cocofied_lvis.py` to prepare "cocofied" LVIS annotations.
+
+ ## Expected dataset structure for [cityscapes](https://www.cityscapes-dataset.com/downloads/):
+ ```
+ cityscapes/
+   gtFine/
+     train/
+       aachen/
+         color.png, instanceIds.png, labelIds.png, polygons.json,
+         labelTrainIds.png
+       ...
+     val/
+     test/
+     # below are generated Cityscapes panoptic annotations
+     cityscapes_panoptic_train.json
+     cityscapes_panoptic_train/
+     cityscapes_panoptic_val.json
+     cityscapes_panoptic_val/
+     cityscapes_panoptic_test.json
+     cityscapes_panoptic_test/
+   leftImg8bit/
+     train/
+     val/
+     test/
+ ```
+ Install cityscapes scripts by:
+ ```
+ pip install git+https://github.com/mcordts/cityscapesScripts.git
+ ```
+
+ Note: to create labelTrainIds.png, first prepare the above structure, then run the cityscapes script with:
+ ```
+ CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createTrainIdLabelImgs.py
+ ```
+ These files are not needed for instance segmentation.
+
+ Note: to generate the Cityscapes panoptic dataset, run the cityscapes script with:
+ ```
+ CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createPanopticImgs.py
+ ```
+ These files are not needed for semantic and instance segmentation.
+
+ ## Expected dataset structure for [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/index.html):
+ ```
+ VOC20{07,12}/
+   Annotations/
+   ImageSets/
+     Main/
+       trainval.txt
+       test.txt
+       # train.txt or val.txt, if you use these splits
+   JPEGImages/
+ ```
+
+ ## Expected dataset structure for [ADE20k Scene Parsing](http://sceneparsing.csail.mit.edu/):
+ ```
+ ADEChallengeData2016/
+   annotations/
+   annotations_detectron2/
+   images/
+   objectInfo150.txt
+ ```
+ The directory `annotations_detectron2` is generated by running `python datasets/prepare_ade20k_sem_seg.py`.
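
Note (editor's sketch, not part of the committed README): once the layout above is in place and `DETECTRON2_DATASETS` points at its parent directory, the builtin names can be read back through the catalogs. The dataset path below is hypothetical, and the example assumes the LVIS v1 files have been downloaded.

```python
# Editor's sketch: set the variable before importing detectron2.data, since
# builtin dataset paths are resolved when the datasets are registered.
import os
os.environ["DETECTRON2_DATASETS"] = "/path/to/datasets"  # hypothetical location

from detectron2.data import DatasetCatalog, MetadataCatalog

dataset_dicts = DatasetCatalog.get("lvis_v1_val")  # list of per-image dicts
metadata = MetadataCatalog.get("lvis_v1_val")      # class names and other metadata
print(len(dataset_dicts), len(metadata.thing_classes))
```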
datasets/custom_images/dog_and_cat.jfif ADDED
Binary file (121 kB).
datasets/prepare_ade20k_sem_seg.py ADDED
@@ -0,0 +1,26 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import numpy as np
+ import os
+ from pathlib import Path
+ import tqdm
+ from PIL import Image
+
+
+ def convert(input, output):
+     img = np.asarray(Image.open(input))
+     assert img.dtype == np.uint8
+     img = img - 1  # 0 (ignore) becomes 255. others are shifted by 1
+     Image.fromarray(img).save(output)
+
+
+ if __name__ == "__main__":
+     dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "ADEChallengeData2016"
+     for name in ["training", "validation"]:
+         annotation_dir = dataset_dir / "annotations" / name
+         output_dir = dataset_dir / "annotations_detectron2" / name
+         output_dir.mkdir(parents=True, exist_ok=True)
+         for file in tqdm.tqdm(list(annotation_dir.iterdir())):
+             output_file = output_dir / file.name
+             convert(file, output_file)
datasets/prepare_cocofied_lvis.py ADDED
@@ -0,0 +1,176 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ import copy
+ import json
+ import os
+ from collections import defaultdict
+
+ # This mapping is extracted from the official LVIS mapping:
+ # https://github.com/lvis-dataset/lvis-api/blob/master/data/coco_to_synset.json
+ COCO_SYNSET_CATEGORIES = [
+     {"synset": "person.n.01", "coco_cat_id": 1},
+     {"synset": "bicycle.n.01", "coco_cat_id": 2},
+     {"synset": "car.n.01", "coco_cat_id": 3},
+     {"synset": "motorcycle.n.01", "coco_cat_id": 4},
+     {"synset": "airplane.n.01", "coco_cat_id": 5},
+     {"synset": "bus.n.01", "coco_cat_id": 6},
+     {"synset": "train.n.01", "coco_cat_id": 7},
+     {"synset": "truck.n.01", "coco_cat_id": 8},
+     {"synset": "boat.n.01", "coco_cat_id": 9},
+     {"synset": "traffic_light.n.01", "coco_cat_id": 10},
+     {"synset": "fireplug.n.01", "coco_cat_id": 11},
+     {"synset": "stop_sign.n.01", "coco_cat_id": 13},
+     {"synset": "parking_meter.n.01", "coco_cat_id": 14},
+     {"synset": "bench.n.01", "coco_cat_id": 15},
+     {"synset": "bird.n.01", "coco_cat_id": 16},
+     {"synset": "cat.n.01", "coco_cat_id": 17},
+     {"synset": "dog.n.01", "coco_cat_id": 18},
+     {"synset": "horse.n.01", "coco_cat_id": 19},
+     {"synset": "sheep.n.01", "coco_cat_id": 20},
+     {"synset": "beef.n.01", "coco_cat_id": 21},
+     {"synset": "elephant.n.01", "coco_cat_id": 22},
+     {"synset": "bear.n.01", "coco_cat_id": 23},
+     {"synset": "zebra.n.01", "coco_cat_id": 24},
+     {"synset": "giraffe.n.01", "coco_cat_id": 25},
+     {"synset": "backpack.n.01", "coco_cat_id": 27},
+     {"synset": "umbrella.n.01", "coco_cat_id": 28},
+     {"synset": "bag.n.04", "coco_cat_id": 31},
+     {"synset": "necktie.n.01", "coco_cat_id": 32},
+     {"synset": "bag.n.06", "coco_cat_id": 33},
+     {"synset": "frisbee.n.01", "coco_cat_id": 34},
+     {"synset": "ski.n.01", "coco_cat_id": 35},
+     {"synset": "snowboard.n.01", "coco_cat_id": 36},
+     {"synset": "ball.n.06", "coco_cat_id": 37},
+     {"synset": "kite.n.03", "coco_cat_id": 38},
+     {"synset": "baseball_bat.n.01", "coco_cat_id": 39},
+     {"synset": "baseball_glove.n.01", "coco_cat_id": 40},
+     {"synset": "skateboard.n.01", "coco_cat_id": 41},
+     {"synset": "surfboard.n.01", "coco_cat_id": 42},
+     {"synset": "tennis_racket.n.01", "coco_cat_id": 43},
+     {"synset": "bottle.n.01", "coco_cat_id": 44},
+     {"synset": "wineglass.n.01", "coco_cat_id": 46},
+     {"synset": "cup.n.01", "coco_cat_id": 47},
+     {"synset": "fork.n.01", "coco_cat_id": 48},
+     {"synset": "knife.n.01", "coco_cat_id": 49},
+     {"synset": "spoon.n.01", "coco_cat_id": 50},
+     {"synset": "bowl.n.03", "coco_cat_id": 51},
+     {"synset": "banana.n.02", "coco_cat_id": 52},
+     {"synset": "apple.n.01", "coco_cat_id": 53},
+     {"synset": "sandwich.n.01", "coco_cat_id": 54},
+     {"synset": "orange.n.01", "coco_cat_id": 55},
+     {"synset": "broccoli.n.01", "coco_cat_id": 56},
+     {"synset": "carrot.n.01", "coco_cat_id": 57},
+     {"synset": "frank.n.02", "coco_cat_id": 58},
+     {"synset": "pizza.n.01", "coco_cat_id": 59},
+     {"synset": "doughnut.n.02", "coco_cat_id": 60},
+     {"synset": "cake.n.03", "coco_cat_id": 61},
+     {"synset": "chair.n.01", "coco_cat_id": 62},
+     {"synset": "sofa.n.01", "coco_cat_id": 63},
+     {"synset": "pot.n.04", "coco_cat_id": 64},
+     {"synset": "bed.n.01", "coco_cat_id": 65},
+     {"synset": "dining_table.n.01", "coco_cat_id": 67},
+     {"synset": "toilet.n.02", "coco_cat_id": 70},
+     {"synset": "television_receiver.n.01", "coco_cat_id": 72},
+     {"synset": "laptop.n.01", "coco_cat_id": 73},
+     {"synset": "mouse.n.04", "coco_cat_id": 74},
+     {"synset": "remote_control.n.01", "coco_cat_id": 75},
+     {"synset": "computer_keyboard.n.01", "coco_cat_id": 76},
+     {"synset": "cellular_telephone.n.01", "coco_cat_id": 77},
+     {"synset": "microwave.n.02", "coco_cat_id": 78},
+     {"synset": "oven.n.01", "coco_cat_id": 79},
+     {"synset": "toaster.n.02", "coco_cat_id": 80},
+     {"synset": "sink.n.01", "coco_cat_id": 81},
+     {"synset": "electric_refrigerator.n.01", "coco_cat_id": 82},
+     {"synset": "book.n.01", "coco_cat_id": 84},
+     {"synset": "clock.n.01", "coco_cat_id": 85},
+     {"synset": "vase.n.01", "coco_cat_id": 86},
+     {"synset": "scissors.n.01", "coco_cat_id": 87},
+     {"synset": "teddy.n.01", "coco_cat_id": 88},
+     {"synset": "hand_blower.n.01", "coco_cat_id": 89},
+     {"synset": "toothbrush.n.01", "coco_cat_id": 90},
+ ]
+
+
+ def cocofy_lvis(input_filename, output_filename):
+     """
+     Filter LVIS instance segmentation annotations to remove all categories that are not included in
+     COCO. The new json files can be used to evaluate COCO AP using `lvis-api`. The category ids in
+     the output json are the incontiguous COCO dataset ids.
+
+     Args:
+         input_filename (str): path to the LVIS json file.
+         output_filename (str): path to the COCOfied json file.
+     """
+
+     with open(input_filename, "r") as f:
+         lvis_json = json.load(f)
+
+     lvis_annos = lvis_json.pop("annotations")
+     cocofied_lvis = copy.deepcopy(lvis_json)
+     lvis_json["annotations"] = lvis_annos
+
+     # Mapping from lvis cat id to coco cat id via synset
+     lvis_cat_id_to_synset = {cat["id"]: cat["synset"] for cat in lvis_json["categories"]}
+     synset_to_coco_cat_id = {x["synset"]: x["coco_cat_id"] for x in COCO_SYNSET_CATEGORIES}
+     # Synsets that we will keep in the dataset
+     synsets_to_keep = set(synset_to_coco_cat_id.keys())
+     coco_cat_id_with_instances = defaultdict(int)
+
+     new_annos = []
+     ann_id = 1
+     for ann in lvis_annos:
+         lvis_cat_id = ann["category_id"]
+         synset = lvis_cat_id_to_synset[lvis_cat_id]
+         if synset not in synsets_to_keep:
+             continue
+         coco_cat_id = synset_to_coco_cat_id[synset]
+         new_ann = copy.deepcopy(ann)
+         new_ann["category_id"] = coco_cat_id
+         new_ann["id"] = ann_id
+         ann_id += 1
+         new_annos.append(new_ann)
+         coco_cat_id_with_instances[coco_cat_id] += 1
+     cocofied_lvis["annotations"] = new_annos
+
+     for image in cocofied_lvis["images"]:
+         for key in ["not_exhaustive_category_ids", "neg_category_ids"]:
+             new_category_list = []
+             for lvis_cat_id in image[key]:
+                 synset = lvis_cat_id_to_synset[lvis_cat_id]
+                 if synset not in synsets_to_keep:
+                     continue
+                 coco_cat_id = synset_to_coco_cat_id[synset]
+                 new_category_list.append(coco_cat_id)
+                 coco_cat_id_with_instances[coco_cat_id] += 1
+             image[key] = new_category_list
+
+     coco_cat_id_with_instances = set(coco_cat_id_with_instances.keys())
+
+     new_categories = []
+     for cat in lvis_json["categories"]:
+         synset = cat["synset"]
+         if synset not in synsets_to_keep:
+             continue
+         coco_cat_id = synset_to_coco_cat_id[synset]
+         if coco_cat_id not in coco_cat_id_with_instances:
+             continue
+         new_cat = copy.deepcopy(cat)
+         new_cat["id"] = coco_cat_id
+         new_categories.append(new_cat)
+     cocofied_lvis["categories"] = new_categories
+
+     with open(output_filename, "w") as f:
+         json.dump(cocofied_lvis, f)
+     print("{} is COCOfied and stored in {}.".format(input_filename, output_filename))
+
+
+ if __name__ == "__main__":
+     dataset_dir = os.path.join(os.getenv("DETECTRON2_DATASETS", "datasets"), "lvis")
+     for s in ["lvis_v0.5_train", "lvis_v0.5_val"]:
+         print("Start COCOfing {}.".format(s))
+         cocofy_lvis(
+             os.path.join(dataset_dir, "{}.json".format(s)),
+             os.path.join(dataset_dir, "{}_cocofied.json".format(s)),
+         )
datasets/prepare_for_tests.sh ADDED
@@ -0,0 +1,22 @@
+ #!/bin/bash -e
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # Download some files needed for running tests.
+
+ cd "${0%/*}"
+
+ BASE=https://dl.fbaipublicfiles.com/detectron2
+ mkdir -p coco/annotations
+
+ for anno in instances_val2017_100 \
+   person_keypoints_val2017_100 \
+   instances_minival2014_100 \
+   person_keypoints_minival2014_100; do
+
+   dest=coco/annotations/$anno.json
+   [[ -s $dest ]] && {
+     echo "$dest exists. Skipping ..."
+   } || {
+     wget $BASE/annotations/coco/$anno.json -O $dest
+   }
+ done
datasets/prepare_panoptic_fpn.py ADDED
@@ -0,0 +1,116 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ import functools
+ import json
+ import multiprocessing as mp
+ import numpy as np
+ import os
+ import time
+ from fvcore.common.download import download
+ from panopticapi.utils import rgb2id
+ from PIL import Image
+
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+
+
+ def _process_panoptic_to_semantic(input_panoptic, output_semantic, segments, id_map):
+     panoptic = np.asarray(Image.open(input_panoptic), dtype=np.uint32)
+     panoptic = rgb2id(panoptic)
+     output = np.zeros_like(panoptic, dtype=np.uint8) + 255
+     for seg in segments:
+         cat_id = seg["category_id"]
+         new_cat_id = id_map[cat_id]
+         output[panoptic == seg["id"]] = new_cat_id
+     Image.fromarray(output).save(output_semantic)
+
+
+ def separate_coco_semantic_from_panoptic(panoptic_json, panoptic_root, sem_seg_root, categories):
+     """
+     Create semantic segmentation annotations from panoptic segmentation
+     annotations, to be used by PanopticFPN.
+
+     It maps all thing categories to class 0, and maps all unlabeled pixels to class 255.
+     It maps all stuff categories to contiguous ids starting from 1.
+
+     Args:
+         panoptic_json (str): path to the panoptic json file, in COCO's format.
+         panoptic_root (str): a directory with panoptic annotation files, in COCO's format.
+         sem_seg_root (str): a directory to output semantic annotation files
+         categories (list[dict]): category metadata. Each dict needs to have:
+             "id": corresponds to the "category_id" in the json annotations
+             "isthing": 0 or 1
+     """
+     os.makedirs(sem_seg_root, exist_ok=True)
+
+     stuff_ids = [k["id"] for k in categories if k["isthing"] == 0]
+     thing_ids = [k["id"] for k in categories if k["isthing"] == 1]
+     id_map = {}  # map from category id to id in the output semantic annotation
+     assert len(stuff_ids) <= 254
+     for i, stuff_id in enumerate(stuff_ids):
+         id_map[stuff_id] = i + 1
+     for thing_id in thing_ids:
+         id_map[thing_id] = 0
+     id_map[0] = 255
+
+     with open(panoptic_json) as f:
+         obj = json.load(f)
+
+     pool = mp.Pool(processes=max(mp.cpu_count() // 2, 4))
+
+     def iter_annotations():
+         for anno in obj["annotations"]:
+             file_name = anno["file_name"]
+             segments = anno["segments_info"]
+             input = os.path.join(panoptic_root, file_name)
+             output = os.path.join(sem_seg_root, file_name)
+             yield input, output, segments
+
+     print("Start writing to {} ...".format(sem_seg_root))
+     start = time.time()
+     pool.starmap(
+         functools.partial(_process_panoptic_to_semantic, id_map=id_map),
+         iter_annotations(),
+         chunksize=100,
+     )
+     print("Finished. time: {:.2f}s".format(time.time() - start))
+
+
+ if __name__ == "__main__":
+     dataset_dir = os.path.join(os.getenv("DETECTRON2_DATASETS", "datasets"), "coco")
+     for s in ["val2017", "train2017"]:
+         separate_coco_semantic_from_panoptic(
+             os.path.join(dataset_dir, "annotations/panoptic_{}.json".format(s)),
+             os.path.join(dataset_dir, "panoptic_{}".format(s)),
+             os.path.join(dataset_dir, "panoptic_stuff_{}".format(s)),
+             COCO_CATEGORIES,
+         )
+
+     # Prepare val2017_100 for quick testing:
+
+     dest_dir = os.path.join(dataset_dir, "annotations/")
+     URL_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/"
+     download(URL_PREFIX + "annotations/coco/panoptic_val2017_100.json", dest_dir)
+     with open(os.path.join(dest_dir, "panoptic_val2017_100.json")) as f:
+         obj = json.load(f)
+
+     def link_val100(dir_full, dir_100):
+         print("Creating " + dir_100 + " ...")
+         os.makedirs(dir_100, exist_ok=True)
+         for img in obj["images"]:
+             basename = os.path.splitext(img["file_name"])[0]
+             src = os.path.join(dir_full, basename + ".png")
+             dst = os.path.join(dir_100, basename + ".png")
+             src = os.path.relpath(src, start=dir_100)
+             os.symlink(src, dst)
+
+     link_val100(
+         os.path.join(dataset_dir, "panoptic_val2017"),
+         os.path.join(dataset_dir, "panoptic_val2017_100"),
+     )
+
+     link_val100(
+         os.path.join(dataset_dir, "panoptic_stuff_val2017"),
+         os.path.join(dataset_dir, "panoptic_stuff_val2017_100"),
+     )
detectron2/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ from .utils.env import setup_environment
+
+ setup_environment()
+
+
+ # This line will be programmatically read/written by setup.py.
+ # Leave it at the bottom of this file and don't touch it.
+ __version__ = "0.4"
detectron2/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (223 Bytes).
detectron2/checkpoint/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # File:
+
+
+ from . import catalog as _UNUSED  # register the handler
+ from .detection_checkpoint import DetectionCheckpointer
+ from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
+
+ __all__ = ["Checkpointer", "PeriodicCheckpointer", "DetectionCheckpointer"]
detectron2/checkpoint/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (391 Bytes).
detectron2/checkpoint/__pycache__/c2_model_loading.cpython-39.pyc ADDED
Binary file (16.6 kB).
detectron2/checkpoint/__pycache__/catalog.cpython-39.pyc ADDED
Binary file (4.82 kB).
detectron2/checkpoint/__pycache__/clip_model_loading.cpython-39.pyc ADDED
Binary file (14.5 kB).
detectron2/checkpoint/__pycache__/detection_checkpoint.cpython-39.pyc ADDED
Binary file (3.88 kB).
detectron2/checkpoint/c2_model_loading.py ADDED
@@ -0,0 +1,407 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+ import re
5
+ from typing import Dict, List
6
+ import torch
7
+ from tabulate import tabulate
8
+
9
+
10
+ def convert_basic_c2_names(original_keys):
11
+ """
12
+ Apply some basic name conversion to names in C2 weights.
13
+ It only deals with typical backbone models.
14
+
15
+ Args:
16
+ original_keys (list[str]):
17
+ Returns:
18
+ list[str]: The same number of strings matching those in original_keys.
19
+ """
20
+ layer_keys = copy.deepcopy(original_keys)
21
+ layer_keys = [
22
+ {"pred_b": "linear_b", "pred_w": "linear_w"}.get(k, k) for k in layer_keys
23
+ ] # some hard-coded mappings
24
+
25
+ layer_keys = [k.replace("_", ".") for k in layer_keys]
26
+ layer_keys = [re.sub("\\.b$", ".bias", k) for k in layer_keys]
27
+ layer_keys = [re.sub("\\.w$", ".weight", k) for k in layer_keys]
28
+ # Uniform both bn and gn names to "norm"
29
+ layer_keys = [re.sub("bn\\.s$", "norm.weight", k) for k in layer_keys]
30
+ layer_keys = [re.sub("bn\\.bias$", "norm.bias", k) for k in layer_keys]
31
+ layer_keys = [re.sub("bn\\.rm", "norm.running_mean", k) for k in layer_keys]
32
+ layer_keys = [re.sub("bn\\.running.mean$", "norm.running_mean", k) for k in layer_keys]
33
+ layer_keys = [re.sub("bn\\.riv$", "norm.running_var", k) for k in layer_keys]
34
+ layer_keys = [re.sub("bn\\.running.var$", "norm.running_var", k) for k in layer_keys]
35
+ layer_keys = [re.sub("bn\\.gamma$", "norm.weight", k) for k in layer_keys]
36
+ layer_keys = [re.sub("bn\\.beta$", "norm.bias", k) for k in layer_keys]
37
+ layer_keys = [re.sub("gn\\.s$", "norm.weight", k) for k in layer_keys]
38
+ layer_keys = [re.sub("gn\\.bias$", "norm.bias", k) for k in layer_keys]
39
+
40
+ # stem
41
+ layer_keys = [re.sub("^res\\.conv1\\.norm\\.", "conv1.norm.", k) for k in layer_keys]
42
+ # to avoid mis-matching with "conv1" in other components (e.g. detection head)
43
+ layer_keys = [re.sub("^conv1\\.", "stem.conv1.", k) for k in layer_keys]
44
+
45
+ # layer1-4 is used by torchvision, however we follow the C2 naming strategy (res2-5)
46
+ # layer_keys = [re.sub("^res2.", "layer1.", k) for k in layer_keys]
47
+ # layer_keys = [re.sub("^res3.", "layer2.", k) for k in layer_keys]
48
+ # layer_keys = [re.sub("^res4.", "layer3.", k) for k in layer_keys]
49
+ # layer_keys = [re.sub("^res5.", "layer4.", k) for k in layer_keys]
50
+
51
+ # blocks
52
+ layer_keys = [k.replace(".branch1.", ".shortcut.") for k in layer_keys]
53
+ layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
54
+ layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
55
+ layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
56
+
57
+ # DensePose substitutions
58
+ layer_keys = [re.sub("^body.conv.fcn", "body_conv_fcn", k) for k in layer_keys]
59
+ layer_keys = [k.replace("AnnIndex.lowres", "ann_index_lowres") for k in layer_keys]
60
+ layer_keys = [k.replace("Index.UV.lowres", "index_uv_lowres") for k in layer_keys]
61
+ layer_keys = [k.replace("U.lowres", "u_lowres") for k in layer_keys]
62
+ layer_keys = [k.replace("V.lowres", "v_lowres") for k in layer_keys]
63
+ return layer_keys
64
+
65
+
66
+ def convert_c2_detectron_names(weights):
67
+ """
68
+ Map Caffe2 Detectron weight names to Detectron2 names.
69
+
70
+ Args:
71
+ weights (dict): name -> tensor
72
+
73
+ Returns:
74
+ dict: detectron2 names -> tensor
75
+ dict: detectron2 names -> C2 names
76
+ """
77
+ logger = logging.getLogger(__name__)
78
+ logger.info("Renaming Caffe2 weights ......")
79
+ original_keys = sorted(weights.keys())
80
+ layer_keys = copy.deepcopy(original_keys)
81
+
82
+ layer_keys = convert_basic_c2_names(layer_keys)
83
+
84
+ # --------------------------------------------------------------------------
85
+ # RPN hidden representation conv
86
+ # --------------------------------------------------------------------------
87
+ # FPN case
88
+ # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
89
+ # shared for all other levels, hence the appearance of "fpn2"
90
+ layer_keys = [
91
+ k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
92
+ ]
93
+ # Non-FPN case
94
+ layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]
95
+
96
+ # --------------------------------------------------------------------------
97
+ # RPN box transformation conv
98
+ # --------------------------------------------------------------------------
99
+ # FPN case (see note above about "fpn2")
100
+ layer_keys = [
101
+ k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
102
+ for k in layer_keys
103
+ ]
104
+ layer_keys = [
105
+ k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
106
+ for k in layer_keys
107
+ ]
108
+ # Non-FPN case
109
+ layer_keys = [
110
+ k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
111
+ ]
112
+ layer_keys = [
113
+ k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
114
+ for k in layer_keys
115
+ ]
116
+
117
+ # --------------------------------------------------------------------------
118
+ # Fast R-CNN box head
119
+ # --------------------------------------------------------------------------
120
+ layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
121
+ layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
122
+ layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
123
+ layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
124
+ # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
125
+ layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]
126
+
127
+ # --------------------------------------------------------------------------
128
+ # FPN lateral and output convolutions
129
+ # --------------------------------------------------------------------------
130
+ def fpn_map(name):
131
+ """
132
+ Look for keys with the following patterns:
133
+ 1) Starts with "fpn.inner."
134
+ Example: "fpn.inner.res2.2.sum.lateral.weight"
135
+ Meaning: These are lateral pathway convolutions
136
+ 2) Starts with "fpn.res"
137
+ Example: "fpn.res2.2.sum.weight"
138
+ Meaning: These are FPN output convolutions
139
+ """
140
+ splits = name.split(".")
141
+ norm = ".norm" if "norm" in splits else ""
142
+ if name.startswith("fpn.inner."):
143
+ # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
144
+ stage = int(splits[2][len("res") :])
145
+ return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
146
+ elif name.startswith("fpn.res"):
147
+ # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
148
+ stage = int(splits[1][len("res") :])
149
+ return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
150
+ return name
151
+
152
+ layer_keys = [fpn_map(k) for k in layer_keys]
153
+
154
+ # --------------------------------------------------------------------------
155
+ # Mask R-CNN mask head
156
+ # --------------------------------------------------------------------------
157
+ # roi_heads.StandardROIHeads case
158
+ layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
159
+ layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
160
+ layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
161
+ # roi_heads.Res5ROIHeads case
162
+ layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]
163
+
164
+ # --------------------------------------------------------------------------
165
+ # Keypoint R-CNN head
166
+ # --------------------------------------------------------------------------
167
+ # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
168
+ layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
169
+ layer_keys = [
170
+ k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
171
+ ]
172
+ layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]
173
+
174
+ # --------------------------------------------------------------------------
175
+ # Done with replacements
176
+ # --------------------------------------------------------------------------
177
+ assert len(set(layer_keys)) == len(layer_keys)
178
+ assert len(original_keys) == len(layer_keys)
179
+
180
+ new_weights = {}
181
+ new_keys_to_original_keys = {}
182
+ for orig, renamed in zip(original_keys, layer_keys):
183
+ new_keys_to_original_keys[renamed] = orig
184
+ if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
185
+ # remove the meaningless prediction weight for background class
186
+ new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
187
+ new_weights[renamed] = weights[orig][new_start_idx:]
188
+ logger.info(
189
+ "Remove prediction weight for background class in {}. The shape changes from "
190
+ "{} to {}.".format(
191
+ renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
192
+ )
193
+ )
194
+ elif renamed.startswith("cls_score."):
195
+ # move weights of bg class from original index 0 to last index
196
+ logger.info(
197
+ "Move classification weights for background class in {} from index 0 to "
198
+ "index {}.".format(renamed, weights[orig].shape[0] - 1)
199
+ )
200
+ new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
201
+ else:
202
+ new_weights[renamed] = weights[orig]
203
+
204
+ return new_weights, new_keys_to_original_keys
205
+
206
+
207
+ # Note the current matching is not symmetric.
208
+ # It assumes model_state_dict will have longer names.
209
+ def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
210
+ """
211
+ Match names between the two state dicts, and return a new ckpt_state_dict with names
212
+ converted to match model_state_dict using heuristics. The returned dict can later be
213
+ loaded with fvcore checkpointer.
214
+ If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
215
+ model and will be renamed at first.
216
+
217
+ Strategy: suppose that the models that we will create will have prefixes appended
218
+ to each of its keys, for example due to an extra level of nesting that the original
219
+ pre-trained weights from ImageNet won't contain. For example, model.state_dict()
220
+ might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
221
+ res2.conv1.weight. We thus want to match both parameters together.
222
+ For that, we look for each model weight, look among all loaded keys if there is one
223
+ that is a suffix of the current weight name, and use it if that's the case.
224
+ If multiple matches exist, take the one with longest size
225
+ of the corresponding name. For example, for the same model as before, the pretrained
226
+ weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
227
+ we want to match backbone[0].body.conv1.weight to conv1.weight, and
228
+ backbone[0].body.res2.conv1.weight to res2.conv1.weight.
229
+ """
230
+ model_keys = sorted(model_state_dict.keys())
231
+ if c2_conversion:
232
+ ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
233
+ # original_keys: the name in the original dict (before renaming)
234
+ else:
235
+ original_keys = {x: x for x in ckpt_state_dict.keys()}
236
+ ckpt_keys = sorted(ckpt_state_dict.keys())
237
+
238
+ def match(a, b):
239
+ # Matched ckpt_key should be a complete (starts with '.') suffix.
240
+ # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
241
+ # but matches whatever_conv1 or mesh_head.whatever_conv1.
242
+ return a == b or a.endswith("." + b)
243
+
244
+ # get a matrix of string matches, where each (i, j) entry corresponds to the size of the
245
+ # ckpt_key string, if it matches
246
+ match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
247
+ match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
248
+ # use the matched one with longest size in case of multiple matches
249
+ max_match_size, idxs = match_matrix.max(1)
250
+ # remove indices that correspond to no-match
251
+ idxs[max_match_size == 0] = -1
252
+
253
+ logger = logging.getLogger(__name__)
254
+ # matched_pairs (matched checkpoint key --> matched model key)
255
+ matched_keys = {}
256
+ result_state_dict = {}
257
+ for idx_model, idx_ckpt in enumerate(idxs.tolist()):
258
+ if idx_ckpt == -1:
259
+ continue
260
+ key_model = model_keys[idx_model]
261
+ key_ckpt = ckpt_keys[idx_ckpt]
262
+ value_ckpt = ckpt_state_dict[key_ckpt]
263
+ shape_in_model = model_state_dict[key_model].shape
264
+
265
+ if shape_in_model != value_ckpt.shape:
266
+ logger.warning(
267
+ "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
268
+ key_ckpt, value_ckpt.shape, key_model, shape_in_model
269
+ )
270
+ )
271
+ logger.warning(
272
+ "{} will not be loaded. Please double check and see if this is desired.".format(
273
+ key_ckpt
274
+ )
275
+ )
276
+ continue
277
+
278
+ assert key_model not in result_state_dict
279
+ result_state_dict[key_model] = value_ckpt
280
+ if key_ckpt in matched_keys: # already added to matched_keys
281
+ logger.error(
282
+ "Ambiguity found for {} in checkpoint!"
283
+ "It matches at least two keys in the model ({} and {}).".format(
284
+ key_ckpt, key_model, matched_keys[key_ckpt]
285
+ )
286
+ )
287
+ raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")
288
+
289
+ matched_keys[key_ckpt] = key_model
290
+
291
+ # logging:
292
+ matched_model_keys = sorted(matched_keys.values())
293
+ if len(matched_model_keys) == 0:
294
+ logger.warning("No weights in checkpoint matched with model.")
295
+ return ckpt_state_dict
296
+ common_prefix = _longest_common_prefix(matched_model_keys)
297
+ rev_matched_keys = {v: k for k, v in matched_keys.items()}
298
+ original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}
299
+
300
+ model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
301
+ table = []
302
+ memo = set()
303
+ for key_model in matched_model_keys:
304
+ if key_model in memo:
305
+ continue
306
+ if key_model in model_key_groups:
307
+ group = model_key_groups[key_model]
308
+ memo |= set(group)
309
+ shapes = [tuple(model_state_dict[k].shape) for k in group]
310
+ table.append(
311
+ (
312
+ _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
313
+ _group_str([original_keys[k] for k in group]),
314
+ " ".join([str(x).replace(" ", "") for x in shapes]),
315
+ )
316
+ )
317
+ else:
318
+ key_checkpoint = original_keys[key_model]
319
+ shape = str(tuple(model_state_dict[key_model].shape))
320
+ table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
321
+ table_str = tabulate(
322
+ table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"]
323
+ )
324
+ logger.info(
325
+ "Following weights matched with "
326
+ + (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
327
+ + ":\n"
328
+ + table_str
329
+ )
330
+
331
+ unmatched_ckpt_keys = [k for k in ckpt_keys if k not in set(matched_keys.keys())]
332
+ for k in unmatched_ckpt_keys:
333
+ result_state_dict[k] = ckpt_state_dict[k]
334
+ return result_state_dict
335
+
336
+
337
+ def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
338
+ """
339
+ Params in the same submodule are grouped together.
340
+
341
+ Args:
342
+ keys: names of all parameters
343
+ original_names: mapping from parameter name to their name in the checkpoint
344
+
345
+ Returns:
346
+ dict[name -> all other names in the same group]
347
+ """
348
+
349
+ def _submodule_name(key):
350
+ pos = key.rfind(".")
351
+ if pos < 0:
352
+ return None
353
+ prefix = key[: pos + 1]
354
+ return prefix
355
+
356
+ all_submodules = [_submodule_name(k) for k in keys]
357
+ all_submodules = [x for x in all_submodules if x]
358
+ all_submodules = sorted(all_submodules, key=len)
359
+
360
+ ret = {}
361
+ for prefix in all_submodules:
362
+ group = [k for k in keys if k.startswith(prefix)]
363
+ if len(group) <= 1:
364
+ continue
365
+ original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
366
+ if len(original_name_lcp) == 0:
367
+ # don't group weights if original names don't share prefix
368
+ continue
369
+
370
+ for k in group:
371
+ if k in ret:
372
+ continue
373
+ ret[k] = group
374
+ return ret
375
+
376
+
377
+ def _longest_common_prefix(names: List[str]) -> str:
378
+ """
379
+ ["abc.zfg", "abc.zef"] -> "abc."
380
+ """
381
+ names = [n.split(".") for n in names]
382
+ m1, m2 = min(names), max(names)
383
+ ret = [a for a, b in zip(m1, m2) if a == b]
384
+ ret = ".".join(ret) + "." if len(ret) else ""
385
+ return ret
386
+
387
+
388
+ def _longest_common_prefix_str(names: List[str]) -> str:
389
+ m1, m2 = min(names), max(names)
390
+ lcp = [a for a, b in zip(m1, m2) if a == b]
391
+ lcp = "".join(lcp)
392
+ return lcp
393
+
394
+
395
+ def _group_str(names: List[str]) -> str:
396
+ """
397
+ Turn "common1", "common2", "common3" into "common{1,2,3}"
398
+ """
399
+ lcp = _longest_common_prefix_str(names)
400
+ rest = [x[len(lcp) :] for x in names]
401
+ rest = "{" + ",".join(rest) + "}"
402
+ ret = lcp + rest
403
+
404
+ # add some simplification for BN specifically
405
+ ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*")
406
+ ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*")
407
+ return ret
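
The suffix-matching strategy described in the docstring of align_and_update_state_dicts above can be illustrated with a small standalone sketch (toy key names only; this is not part of the commit and does not import detectron2):

# Minimal sketch of the suffix-matching heuristic: a checkpoint key matches a
# model key when it is a complete '.'-separated suffix, and the longest match wins.
def toy_align(model_keys, ckpt_keys):
    def match(model_key, ckpt_key):
        return model_key == ckpt_key or model_key.endswith("." + ckpt_key)

    mapping = {}
    for mk in model_keys:
        candidates = [ck for ck in ckpt_keys if match(mk, ck)]
        if candidates:
            mapping[mk] = max(candidates, key=len)  # longest = most specific
    return mapping

if __name__ == "__main__":
    model_keys = ["backbone.body.conv1.weight", "backbone.body.res2.conv1.weight"]
    ckpt_keys = ["conv1.weight", "res2.conv1.weight"]
    print(toy_align(model_keys, ckpt_keys))
    # {'backbone.body.conv1.weight': 'conv1.weight',
    #  'backbone.body.res2.conv1.weight': 'res2.conv1.weight'}
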
detectron2/checkpoint/catalog.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+
4
+ from detectron2.utils.file_io import PathHandler, PathManager
5
+
6
+
7
+ class ModelCatalog(object):
8
+ """
9
+ Store mappings from names to third-party models.
10
+ """
11
+
12
+ S3_C2_DETECTRON_PREFIX = "https://dl.fbaipublicfiles.com/detectron"
13
+
14
+ # MSRA models have STRIDE_IN_1X1=True. False otherwise.
15
+ # NOTE: all BN models here have fused BN into an affine layer.
16
+ # As a result, you should only load them to a model with "FrozenBN".
17
+ # Loading them to a model with regular BN or SyncBN is wrong.
18
+ # Even when loaded to FrozenBN, it is still different from affine by an epsilon,
19
+ # which should be negligible for training.
20
+ # NOTE: all models here use PIXEL_STD=[1,1,1]
21
+ # NOTE: Most of the BN models here are no longer used. We use the
22
+ # re-converted pre-trained models under detectron2 model zoo instead.
23
+ C2_IMAGENET_MODELS = {
24
+ "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl",
25
+ "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl",
26
+ "FAIR/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl",
27
+ "FAIR/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl",
28
+ "FAIR/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl",
29
+ "FAIR/X-101-64x4d": "ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl",
30
+ "FAIR/X-152-32x8d-IN5k": "ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl",
31
+ }
32
+
33
+ C2_DETECTRON_PATH_FORMAT = (
34
+ "{prefix}/{url}/output/train/{dataset}/{type}/model_final.pkl" # noqa B950
35
+ )
36
+
37
+ C2_DATASET_COCO = "coco_2014_train%3Acoco_2014_valminusminival"
38
+ C2_DATASET_COCO_KEYPOINTS = "keypoints_coco_2014_train%3Akeypoints_coco_2014_valminusminival"
39
+
40
+ # format: {model_name} -> part of the url
41
+ C2_DETECTRON_MODELS = {
42
+ "35857197/e2e_faster_rcnn_R-50-C4_1x": "35857197/12_2017_baselines/e2e_faster_rcnn_R-50-C4_1x.yaml.01_33_49.iAX0mXvW", # noqa B950
43
+ "35857345/e2e_faster_rcnn_R-50-FPN_1x": "35857345/12_2017_baselines/e2e_faster_rcnn_R-50-FPN_1x.yaml.01_36_30.cUF7QR7I", # noqa B950
44
+ "35857890/e2e_faster_rcnn_R-101-FPN_1x": "35857890/12_2017_baselines/e2e_faster_rcnn_R-101-FPN_1x.yaml.01_38_50.sNxI7sX7", # noqa B950
45
+ "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "36761737/12_2017_baselines/e2e_faster_rcnn_X-101-32x8d-FPN_1x.yaml.06_31_39.5MIHi1fZ", # noqa B950
46
+ "35858791/e2e_mask_rcnn_R-50-C4_1x": "35858791/12_2017_baselines/e2e_mask_rcnn_R-50-C4_1x.yaml.01_45_57.ZgkA7hPB", # noqa B950
47
+ "35858933/e2e_mask_rcnn_R-50-FPN_1x": "35858933/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml.01_48_14.DzEQe4wC", # noqa B950
48
+ "35861795/e2e_mask_rcnn_R-101-FPN_1x": "35861795/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_1x.yaml.02_31_37.KqyEK4tT", # noqa B950
49
+ "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "36761843/12_2017_baselines/e2e_mask_rcnn_X-101-32x8d-FPN_1x.yaml.06_35_59.RZotkLKI", # noqa B950
50
+ "48616381/e2e_mask_rcnn_R-50-FPN_2x_gn": "GN/48616381/04_2018_gn_baselines/e2e_mask_rcnn_R-50-FPN_2x_gn_0416.13_23_38.bTlTI97Q", # noqa B950
51
+ "37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "37697547/12_2017_baselines/e2e_keypoint_rcnn_R-50-FPN_1x.yaml.08_42_54.kdzV35ao", # noqa B950
52
+ "35998355/rpn_R-50-C4_1x": "35998355/12_2017_baselines/rpn_R-50-C4_1x.yaml.08_00_43.njH5oD9L", # noqa B950
53
+ "35998814/rpn_R-50-FPN_1x": "35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179", # noqa B950
54
+ "36225147/fast_R-50-FPN_1x": "36225147/12_2017_baselines/fast_rcnn_R-50-FPN_1x.yaml.08_39_09.L3obSdQ2", # noqa B950
55
+ }
56
+
57
+ @staticmethod
58
+ def get(name):
59
+ if name.startswith("Caffe2Detectron/COCO"):
60
+ return ModelCatalog._get_c2_detectron_baseline(name)
61
+ if name.startswith("ImageNetPretrained/"):
62
+ return ModelCatalog._get_c2_imagenet_pretrained(name)
63
+ raise RuntimeError("model not present in the catalog: {}".format(name))
64
+
65
+ @staticmethod
66
+ def _get_c2_imagenet_pretrained(name):
67
+ prefix = ModelCatalog.S3_C2_DETECTRON_PREFIX
68
+ name = name[len("ImageNetPretrained/") :]
69
+ name = ModelCatalog.C2_IMAGENET_MODELS[name]
70
+ url = "/".join([prefix, name])
71
+ return url
72
+
73
+ @staticmethod
74
+ def _get_c2_detectron_baseline(name):
75
+ name = name[len("Caffe2Detectron/COCO/") :]
76
+ url = ModelCatalog.C2_DETECTRON_MODELS[name]
77
+ if "keypoint_rcnn" in name:
78
+ dataset = ModelCatalog.C2_DATASET_COCO_KEYPOINTS
79
+ else:
80
+ dataset = ModelCatalog.C2_DATASET_COCO
81
+
82
+ if "35998355/rpn_R-50-C4_1x" in name:
83
+ # this one model is somehow different from others ..
84
+ type = "rpn"
85
+ else:
86
+ type = "generalized_rcnn"
87
+
88
+ # Detectron C2 models are stored in the structure defined in `C2_DETECTRON_PATH_FORMAT`.
89
+ url = ModelCatalog.C2_DETECTRON_PATH_FORMAT.format(
90
+ prefix=ModelCatalog.S3_C2_DETECTRON_PREFIX, url=url, type=type, dataset=dataset
91
+ )
92
+ return url
93
+
94
+
95
+ class ModelCatalogHandler(PathHandler):
96
+ """
97
+ Resolve URL like catalog://.
98
+ """
99
+
100
+ PREFIX = "catalog://"
101
+
102
+ def _get_supported_prefixes(self):
103
+ return [self.PREFIX]
104
+
105
+ def _get_local_path(self, path, **kwargs):
106
+ logger = logging.getLogger(__name__)
107
+ catalog_path = ModelCatalog.get(path[len(self.PREFIX) :])
108
+ logger.info("Catalog entry {} points to {}".format(path, catalog_path))
109
+ return PathManager.get_local_path(catalog_path, **kwargs)
110
+
111
+ def _open(self, path, mode="r", **kwargs):
112
+ return PathManager.open(self._get_local_path(path), mode, **kwargs)
113
+
114
+
115
+ PathManager.register_handler(ModelCatalogHandler())
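
For reference, a rough standalone sketch of how a catalog:// path from this file ends up as a download URL; the one-entry dictionary below stands in for ModelCatalog.C2_IMAGENET_MODELS, and the logic mirrors ModelCatalog.get / ModelCatalogHandler._get_local_path without any file I/O:

S3_PREFIX = "https://dl.fbaipublicfiles.com/detectron"
IMAGENET_MODELS = {"MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl"}

def resolve_catalog_url(path):
    # strip the "catalog://" scheme, then look the name up like ModelCatalog.get
    assert path.startswith("catalog://")
    name = path[len("catalog://"):]
    if name.startswith("ImageNetPretrained/"):
        rel = IMAGENET_MODELS[name[len("ImageNetPretrained/"):]]
        return "/".join([S3_PREFIX, rel])
    raise RuntimeError("model not present in the catalog: {}".format(name))

print(resolve_catalog_url("catalog://ImageNetPretrained/MSRA/R-50"))
# https://dl.fbaipublicfiles.com/detectron/ImageNetPretrained/MSRA/R-50.pkl
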
detectron2/checkpoint/clip_model_loading.py ADDED
@@ -0,0 +1,415 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+ import re
5
+ from typing import Dict, List
6
+ import torch
7
+ from tabulate import tabulate
8
+
9
+
10
+ def convert_basic_clip_names(original_keys, add_backbone_prefix=False, use_whole_clip=False, use_fpn_arch=False, regionclip=False):
11
+ """
12
+ Apply some basic name conversion to names in CLIP weights.
13
+ It only deals with typical backbone models.
14
+
15
+ Args:
16
+ original_keys (list[str]):
17
+ Returns:
18
+ list[str]: The same number of strings matching those in original_keys.
19
+ """
20
+ layer_keys = copy.deepcopy(original_keys)
21
+
22
+ vit = False
23
+ for l_k in layer_keys:
24
+ if 'visual.transformer' in l_k:
25
+ vit = True
26
+
27
+ # load pretrained oai clip
28
+ if not vit: # resnet
29
+ if add_backbone_prefix: # CLIPRCNN or CLIPFastRCNN
30
+ if use_whole_clip: # CLIPRCNN
31
+ layer_keys = [k.replace("visual.", "clip_backbone.visual.") for k in layer_keys]
32
+ else: # CLIPFastRCNN
33
+ if use_fpn_arch: # FPN
34
+ layer_keys = [k.replace("visual.", "backbone.bottom_up.") for k in layer_keys]
35
+ else: # C4
36
+ layer_keys = [k.replace("visual.", "backbone.") for k in layer_keys]
37
+ else: # GeneralizedRCNN or ProposalNetwork
38
+ #layer_keys = [k.replace("visual.", "backbone.bottom_up.") for k in layer_keys] #
39
+ layer_keys = [k.replace("visual.", "") for k in layer_keys] #
40
+ #layer_keys = [k.replace("visual.", "backbone.visual.") for k in layer_keys] #
41
+ else: # vit
42
+ pass
43
+
44
+ return layer_keys, vit
45
+
46
+
47
+ def convert_clip_names(weights, add_backbone_prefix=False, use_whole_clip=False, use_fpn_arch=False, regionclip=False):
48
+ """
49
+ Map CLIP weight names to Detectron2 names.
50
+
51
+ Args:
52
+ weights (dict): name -> tensor
53
+
54
+ Returns:
55
+ dict: detectron2 names -> tensor
56
+ dict: detectron2 names -> original CLIP names
57
+ """
58
+ logger = logging.getLogger(__name__)
59
+ logger.info("Renaming CLIP weights ......")
60
+ original_keys = sorted(weights.keys())
61
+ layer_keys = copy.deepcopy(original_keys)
62
+
63
+ layer_keys, use_vit = convert_basic_clip_names(layer_keys, add_backbone_prefix, use_whole_clip, use_fpn_arch, regionclip)
64
+
65
+ # --------------------------------------------------------------------------
66
+ # RPN hidden representation conv
67
+ # --------------------------------------------------------------------------
68
+ # FPN case
69
+ # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
70
+ # shared for all other levels, hence the appearance of "fpn2"
71
+ layer_keys = [
72
+ k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
73
+ ]
74
+ # Non-FPN case
75
+ layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]
76
+
77
+ # --------------------------------------------------------------------------
78
+ # RPN box transformation conv
79
+ # --------------------------------------------------------------------------
80
+ # FPN case (see note above about "fpn2")
81
+ layer_keys = [
82
+ k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
83
+ for k in layer_keys
84
+ ]
85
+ layer_keys = [
86
+ k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
87
+ for k in layer_keys
88
+ ]
89
+ # Non-FPN case
90
+ layer_keys = [
91
+ k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
92
+ ]
93
+ layer_keys = [
94
+ k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
95
+ for k in layer_keys
96
+ ]
97
+
98
+ # --------------------------------------------------------------------------
99
+ # Fast R-CNN box head
100
+ # --------------------------------------------------------------------------
101
+ layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
102
+ layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
103
+ layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
104
+ layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
105
+ # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
106
+ layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]
107
+
108
+ # --------------------------------------------------------------------------
109
+ # FPN lateral and output convolutions
110
+ # --------------------------------------------------------------------------
111
+ def fpn_map(name):
112
+ """
113
+ Look for keys with the following patterns:
114
+ 1) Starts with "fpn.inner."
115
+ Example: "fpn.inner.res2.2.sum.lateral.weight"
116
+ Meaning: These are lateral pathway convolutions
117
+ 2) Starts with "fpn.res"
118
+ Example: "fpn.res2.2.sum.weight"
119
+ Meaning: These are FPN output convolutions
120
+ """
121
+ splits = name.split(".")
122
+ norm = ".norm" if "norm" in splits else ""
123
+ if name.startswith("fpn.inner."):
124
+ # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
125
+ stage = int(splits[2][len("res") :])
126
+ return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
127
+ elif name.startswith("fpn.res"):
128
+ # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
129
+ stage = int(splits[1][len("res") :])
130
+ return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
131
+ return name
132
+
133
+ layer_keys = [fpn_map(k) for k in layer_keys]
134
+
135
+ # --------------------------------------------------------------------------
136
+ # Mask R-CNN mask head
137
+ # --------------------------------------------------------------------------
138
+ # roi_heads.StandardROIHeads case
139
+ layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
140
+ layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
141
+ layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
142
+ # roi_heads.Res5ROIHeads case
143
+ layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]
144
+
145
+ # --------------------------------------------------------------------------
146
+ # Keypoint R-CNN head
147
+ # --------------------------------------------------------------------------
148
+ # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
149
+ layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
150
+ layer_keys = [
151
+ k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
152
+ ]
153
+ layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]
154
+
155
+ # --------------------------------------------------------------------------
156
+ # Done with replacements
157
+ # --------------------------------------------------------------------------
158
+ assert len(set(layer_keys)) == len(layer_keys)
159
+ assert len(original_keys) == len(layer_keys)
160
+
161
+ new_weights = {}
162
+ new_keys_to_original_keys = {}
163
+ for orig, renamed in zip(original_keys, layer_keys):
164
+ new_keys_to_original_keys[renamed] = orig
165
+ if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
166
+ # remove the meaningless prediction weight for background class
167
+ new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
168
+ new_weights[renamed] = weights[orig][new_start_idx:]
169
+ logger.info(
170
+ "Remove prediction weight for background class in {}. The shape changes from "
171
+ "{} to {}.".format(
172
+ renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
173
+ )
174
+ )
175
+ elif renamed.startswith("cls_score."):
176
+ # move weights of bg class from original index 0 to last index
177
+ logger.info(
178
+ "Move classification weights for background class in {} from index 0 to "
179
+ "index {}.".format(renamed, weights[orig].shape[0] - 1)
180
+ )
181
+ new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
182
+ else:
183
+ new_weights[renamed] = weights[orig]
184
+
185
+ return new_weights, new_keys_to_original_keys, use_vit
186
+
187
+
188
+ # Note the current matching is not symmetric.
189
+ # It assumes model_state_dict will have longer names.
190
+ def align_and_update_state_dicts_for_CLIP(model_state_dict, ckpt_state_dict, c2_conversion=True, bb_rpn_weights=False, regionclip=False):
191
+ """
192
+ Extended from ./c2_model_loading.py
193
+ Match names between the two state dicts, and return a new ckpt_state_dict with names
194
+ converted to match model_state_dict using heuristics. The returned dict can later be
195
+ loaded with fvcore checkpointer.
196
+ If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
197
+ model and will be renamed at first.
198
+
199
+ Strategy: suppose that the models that we will create will have prefixes appended
200
+ to each of its keys, for example due to an extra level of nesting that the original
201
+ pre-trained weights from ImageNet won't contain. For example, model.state_dict()
202
+ might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
203
+ res2.conv1.weight. We thus want to match both parameters together.
204
+ For that, we look for each model weight, look among all loaded keys if there is one
205
+ that is a suffix of the current weight name, and use it if that's the case.
206
+ If multiple matches exist, take the one with longest size
207
+ of the corresponding name. For example, for the same model as before, the pretrained
208
+ weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
209
+ we want to match backbone[0].body.conv1.weight to conv1.weight, and
210
+ backbone[0].body.res2.conv1.weight to res2.conv1.weight.
211
+ """
212
+ model_keys = sorted(model_state_dict.keys())
213
+ use_whole_clip = False # whether to use the whole CLIP (text & visual encoders), typically in the CLIPRCNN meta arch
214
+ add_backbone_prefix = False # convert to 'backbone.' prefix, typically in the CLIPFastRCNN meta arch
215
+ use_fpn_arch = False # if using an FPN arch, convert to `bottom_up`, typically in the CLIPFastRCNN meta arch with an FPN backbone
216
+ if bb_rpn_weights: # a 2nd set of pretrained weights to load, for the offline backbone & RPN; convert the ckpt key names and keep only the ones we need
217
+ new_ckpt_state_dict = {}
218
+ for original_k in ckpt_state_dict:
219
+ if 'backbone' in original_k:
220
+ new_key = original_k.replace('backbone', 'offline_backbone')
221
+ new_ckpt_state_dict[new_key] = ckpt_state_dict[original_k]
222
+ if 'proposal_generator' in original_k:
223
+ new_key = original_k.replace('proposal_generator', 'offline_proposal_generator')
224
+ new_ckpt_state_dict[new_key] = ckpt_state_dict[original_k]
225
+ new_ckpt_state_dict['ignore_others'] = torch.tensor([1]) # ignore other model weights (not 'offline_*') in batch_norm.py
226
+ ckpt_state_dict = new_ckpt_state_dict
227
+ else: # the 1st pretrained weights to load
228
+ for model_key in model_keys: # if using the whole CLIP, convert ckpt 'visual.' names to 'clip_backbone.visual.'
229
+ if 'clip_backbone' in model_key:
230
+ use_whole_clip = True
231
+ for model_key in model_keys: # if there are backbone & offline_backbone, then convert the ckpt 'visual.' names to 'backbone.' to avoid ambiguity
232
+ if 'offline_backbone' in model_key:
233
+ add_backbone_prefix = True
234
+ if 'fpn' in model_key:
235
+ use_fpn_arch = True
236
+ # original_keys: the name in the original dict (before renaming)
237
+ ckpt_state_dict, original_keys, use_vit = convert_clip_names(ckpt_state_dict, add_backbone_prefix, use_whole_clip, use_fpn_arch, regionclip)
238
+ ckpt_keys = sorted(ckpt_state_dict.keys())
239
+
240
+ def match(a, b):
241
+ # Matched ckpt_key should be a complete (starts with '.') suffix.
242
+ # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
243
+ # but matches whatever_conv1 or mesh_head.whatever_conv1.
244
+ return a == b or a.endswith("." + b)
245
+
246
+ # get a matrix of string matches, where each (i, j) entry corresponds to the size of the
247
+ # ckpt_key string, if it matches
248
+ match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
249
+ match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
250
+ # use the matched one with longest size in case of multiple matches
251
+ max_match_size, idxs = match_matrix.max(1)
252
+ # remove indices that correspond to no-match
253
+ idxs[max_match_size == 0] = -1
254
+
255
+ logger = logging.getLogger(__name__)
256
+ # matched_pairs (matched checkpoint key --> matched model key)
257
+ matched_keys = {}
258
+ result_state_dict = {}
259
+ for idx_model, idx_ckpt in enumerate(idxs.tolist()):
260
+ if idx_ckpt == -1:
261
+ continue
262
+ key_model = model_keys[idx_model]
263
+ key_ckpt = ckpt_keys[idx_ckpt]
264
+ value_ckpt = ckpt_state_dict[key_ckpt]
265
+ shape_in_model = model_state_dict[key_model].shape
266
+
267
+ if shape_in_model != value_ckpt.shape:
268
+ logger.warning(
269
+ "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
270
+ key_ckpt, value_ckpt.shape, key_model, shape_in_model
271
+ )
272
+ )
273
+ logger.warning(
274
+ "{} will not be loaded. Please double check and see if this is desired.".format(
275
+ key_ckpt
276
+ )
277
+ )
278
+ continue
279
+
280
+ assert key_model not in result_state_dict
281
+ result_state_dict[key_model] = value_ckpt
282
+ if key_ckpt in matched_keys: # already added to matched_keys
283
+ logger.error(
284
+ "Ambiguity found for {} in checkpoint!"
285
+ "It matches at least two keys in the model ({} and {}).".format(
286
+ key_ckpt, key_model, matched_keys[key_ckpt]
287
+ )
288
+ )
289
+ raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")
290
+
291
+ matched_keys[key_ckpt] = key_model
292
+
293
+ # logging:
294
+ matched_model_keys = sorted(matched_keys.values())
295
+ mmk_list = "The following model parameters are loaded from checkpoints:\n"
296
+ for mmk in matched_model_keys:
297
+ mmk_list += mmk + "\n"
298
+ if len(matched_model_keys) == 0:
299
+ logger.warning("No weights in checkpoint matched with model.")
300
+ return ckpt_state_dict
301
+ common_prefix = _longest_common_prefix(matched_model_keys)
302
+ rev_matched_keys = {v: k for k, v in matched_keys.items()}
303
+ original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}
304
+
305
+ model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
306
+ table = []
307
+ memo = set()
308
+ for key_model in matched_model_keys:
309
+ if key_model in memo:
310
+ continue
311
+ if key_model in model_key_groups:
312
+ group = model_key_groups[key_model]
313
+ memo |= set(group)
314
+ shapes = [tuple(model_state_dict[k].shape) for k in group]
315
+ table.append(
316
+ (
317
+ _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
318
+ _group_str([original_keys[k] for k in group]),
319
+ " ".join([str(x).replace(" ", "") for x in shapes]),
320
+ )
321
+ )
322
+ else:
323
+ key_checkpoint = original_keys[key_model]
324
+ shape = str(tuple(model_state_dict[key_model].shape))
325
+ table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
326
+ table_str = tabulate(
327
+ table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"]
328
+ )
329
+ if len(table) != 1 and not use_vit: # do this for now; the table function has some bugs when the whole CLIP is loaded
330
+ logger.info(
331
+ "Following weights matched with "
332
+ + (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
333
+ + ":\n"
334
+ + table_str
335
+ )
336
+ else:
337
+ logger.info(mmk_list)
338
+
339
+ unmatched_ckpt_keys = [k for k in ckpt_keys if k not in set(matched_keys.keys())]
340
+ for k in unmatched_ckpt_keys:
341
+ result_state_dict[k] = ckpt_state_dict[k]
342
+ return result_state_dict
343
+
344
+
345
+ def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
346
+ """
347
+ Params in the same submodule are grouped together.
348
+
349
+ Args:
350
+ keys: names of all parameters
351
+ original_names: mapping from parameter name to their name in the checkpoint
352
+
353
+ Returns:
354
+ dict[name -> all other names in the same group]
355
+ """
356
+
357
+ def _submodule_name(key):
358
+ pos = key.rfind(".")
359
+ if pos < 0:
360
+ return None
361
+ prefix = key[: pos + 1]
362
+ return prefix
363
+
364
+ all_submodules = [_submodule_name(k) for k in keys]
365
+ all_submodules = [x for x in all_submodules if x]
366
+ all_submodules = sorted(all_submodules, key=len)
367
+
368
+ ret = {}
369
+ for prefix in all_submodules:
370
+ group = [k for k in keys if k.startswith(prefix)]
371
+ if len(group) <= 1:
372
+ continue
373
+ original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
374
+ if len(original_name_lcp) == 0:
375
+ # don't group weights if original names don't share prefix
376
+ continue
377
+
378
+ for k in group:
379
+ if k in ret:
380
+ continue
381
+ ret[k] = group
382
+ return ret
383
+
384
+
385
+ def _longest_common_prefix(names: List[str]) -> str:
386
+ """
387
+ ["abc.zfg", "abc.zef"] -> "abc."
388
+ """
389
+ names = [n.split(".") for n in names]
390
+ m1, m2 = min(names), max(names)
391
+ ret = [a for a, b in zip(m1, m2) if a == b]
392
+ ret = ".".join(ret) + "." if len(ret) else ""
393
+ return ret
394
+
395
+
396
+ def _longest_common_prefix_str(names: List[str]) -> str:
397
+ m1, m2 = min(names), max(names)
398
+ lcp = [a for a, b in zip(m1, m2) if a == b]
399
+ lcp = "".join(lcp)
400
+ return lcp
401
+
402
+
403
+ def _group_str(names: List[str]) -> str:
404
+ """
405
+ Turn "common1", "common2", "common3" into "common{1,2,3}"
406
+ """
407
+ lcp = _longest_common_prefix_str(names)
408
+ rest = [x[len(lcp) :] for x in names]
409
+ rest = "{" + ",".join(rest) + "}"
410
+ ret = lcp + rest
411
+
412
+ # add some simplification for BN specifically
413
+ ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*")
414
+ ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*")
415
+ return ret
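
A toy illustration of the two renaming passes implemented above (assumed key names; the real functions also handle the FPN, whole-CLIP, and ViT cases):

def rename_clip_keys(keys, add_backbone_prefix=True, use_fpn_arch=False):
    # pass 1: CLIP visual-encoder keys gain a detectron2 backbone prefix
    prefix = ("backbone.bottom_up." if use_fpn_arch else "backbone.") if add_backbone_prefix else ""
    return [k.replace("visual.", prefix) for k in keys]

def rename_offline_keys(state):
    # pass 2: a second checkpoint's backbone/RPN keys gain an "offline_" prefix
    out = {}
    for k, v in state.items():
        if "backbone" in k:
            out[k.replace("backbone", "offline_backbone")] = v
        if "proposal_generator" in k:
            out[k.replace("proposal_generator", "offline_proposal_generator")] = v
    return out

print(rename_clip_keys(["visual.layer1.0.conv1.weight"]))
# ['backbone.layer1.0.conv1.weight']
print(rename_offline_keys({"backbone.stem.conv1.weight": 0}))
# {'offline_backbone.stem.conv1.weight': 0}
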
detectron2/checkpoint/detection_checkpoint.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import os
4
+ import pickle
5
+ import torch
6
+ from fvcore.common.checkpoint import Checkpointer
7
+ from torch.nn.parallel import DistributedDataParallel
8
+
9
+ import detectron2.utils.comm as comm
10
+ from detectron2.utils.env import TORCH_VERSION
11
+ from detectron2.utils.file_io import PathManager
12
+
13
+ from .c2_model_loading import align_and_update_state_dicts
14
+ from .clip_model_loading import align_and_update_state_dicts_for_CLIP
15
+
16
+ class DetectionCheckpointer(Checkpointer):
17
+ """
18
+ Same as :class:`Checkpointer`, but is able to:
19
+ 1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models.
20
+ 2. correctly load checkpoints that are only available on the master worker
21
+ """
22
+
23
+ def __init__(self, model, save_dir="", *, save_to_disk=None, bb_rpn_weights=False, **checkpointables):
24
+ is_main_process = comm.is_main_process()
25
+ super().__init__(
26
+ model,
27
+ save_dir,
28
+ save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
29
+ **checkpointables,
30
+ )
31
+ self.path_manager = PathManager
32
+ self.bb_rpn_weights = bb_rpn_weights
33
+
34
+ def load(self, path, *args, **kwargs):
35
+ need_sync = False
36
+
37
+ if path and isinstance(self.model, DistributedDataParallel):
38
+ logger = logging.getLogger(__name__)
39
+ path = self.path_manager.get_local_path(path)
40
+ has_file = os.path.isfile(path)
41
+ all_has_file = comm.all_gather(has_file)
42
+ if not all_has_file[0]:
43
+ raise OSError(f"File {path} not found on main worker.")
44
+ if not all(all_has_file):
45
+ logger.warning(
46
+ f"Not all workers can read checkpoint {path}. "
47
+ "Training may fail to fully resume."
48
+ )
49
+ # TODO: broadcast the checkpoint file contents from main
50
+ # worker, and load from it instead.
51
+ need_sync = True
52
+ if not has_file:
53
+ path = None # don't load if not readable
54
+ ret = super().load(path, *args, **kwargs)
55
+
56
+ if need_sync:
57
+ logger.info("Broadcasting model states from main worker ...")
58
+ if TORCH_VERSION >= (1, 7):
59
+ self.model._sync_params_and_buffers()
60
+ return ret
61
+
62
+ def _load_file(self, filename):
63
+ if filename.endswith(".pkl"):
64
+ with PathManager.open(filename, "rb") as f:
65
+ data = pickle.load(f, encoding="latin1")
66
+ if "model" in data and "__author__" in data:
67
+ # file is in Detectron2 model zoo format
68
+ self.logger.info("Reading a file from '{}'".format(data["__author__"]))
69
+ return data
70
+ else:
71
+ # assume file is from Caffe2 / Detectron1 model zoo
72
+ if "blobs" in data:
73
+ # Detection models have "blobs", but ImageNet models don't
74
+ data = data["blobs"]
75
+ data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
76
+ return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
77
+ elif filename.endswith(".pyth"):
78
+ # assume file is from pycls; no one else seems to use the ".pyth" extension
79
+ with PathManager.open(filename, "rb") as f:
80
+ data = torch.load(f)
81
+ assert (
82
+ "model_state" in data
83
+ ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'."
84
+ model_state = {
85
+ k: v
86
+ for k, v in data["model_state"].items()
87
+ if not k.endswith("num_batches_tracked")
88
+ }
89
+ return {"model": model_state, "__author__": "pycls", "matching_heuristics": True}
90
+ elif "OAI_CLIP" in filename:
91
+ # assume file is from OpenAI CLIP pre-trained model
92
+ loaded = super()._load_file(filename) # load native pth checkpoint
93
+ if "model" not in loaded:
94
+ loaded = {"model": loaded}
95
+ return {"model": loaded["model"], "__author__": "OAI_CLIP", "matching_heuristics": True}
96
+
97
+ loaded = super()._load_file(filename) # load native pth checkpoint
98
+ if "model" not in loaded:
99
+ loaded = {"model": loaded}
100
+ return loaded
101
+
102
+ def _load_model(self, checkpoint):
103
+ # if checkpoint.get("matching_heuristics", False) or self.bb_rpn_weights:
104
+ # self._convert_ndarray_to_tensor(checkpoint["model"])
105
+ # # convert weights by name-matching heuristics
106
+ # if checkpoint.get("__author__", "NA") == "OAI_CLIP" or self.bb_rpn_weights: # for OAI_CLIP or 2nd ckpt (offline modules)
107
+ # checkpoint["model"] = align_and_update_state_dicts_for_CLIP(
108
+ # self.model.state_dict(),
109
+ # checkpoint["model"],
110
+ # bb_rpn_weights=self.bb_rpn_weights,
111
+ # )
112
+ # else: # default loading
113
+ # checkpoint["model"] = align_and_update_state_dicts(
114
+ # self.model.state_dict(),
115
+ # checkpoint["model"],
116
+ # c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
117
+ # )
118
+ # for non-caffe2 models, use standard ways to load it
119
+ # if not self.bb_rpn_weights:
120
+ # checkpoint = {'model': {'backbone.' + key: val for key, val in checkpoint['model'].items()}}
121
+ incompatible = super()._load_model(checkpoint)
122
+ del checkpoint # try saving memory
123
+
124
+ model_buffers = dict(self.model.named_buffers(recurse=False))
125
+ for k in ["pixel_mean", "pixel_std"]:
126
+ # Ignore missing key message about pixel_mean/std.
127
+ # Though they may be missing in old checkpoints, they will be correctly
128
+ # initialized from config anyway.
129
+ if k in model_buffers:
130
+ try:
131
+ incompatible.missing_keys.remove(k)
132
+ except ValueError:
133
+ pass
134
+ return incompatible
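
A typical way to drive this checkpointer is sketched below (not part of the commit; the config file is one shipped in this repo, but MODEL.WEIGHTS must point at an actual checkpoint such as an OAI_CLIP .pth or a Detectron .pkl):

from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

cfg = get_cfg()
cfg.merge_from_file("configs/CLIP_fast_rcnn_R_50_C4.yaml")
model = build_model(cfg)
# .pkl files go through the Caffe2/CLIP renaming heuristics above;
# native .pth checkpoints are loaded directly.
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
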
detectron2/config/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .compat import downgrade_config, upgrade_config
3
+ from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable
4
+ from .instantiate import instantiate
5
+ from .lazy import LazyCall, LazyConfig
6
+
7
+ __all__ = [
8
+ "CfgNode",
9
+ "get_cfg",
10
+ "global_cfg",
11
+ "set_global_cfg",
12
+ "downgrade_config",
13
+ "upgrade_config",
14
+ "configurable",
15
+ "instantiate",
16
+ "LazyCall",
17
+ "LazyConfig",
18
+ ]
19
+
20
+
21
+ from detectron2.utils.env import fixup_module_metadata
22
+
23
+ fixup_module_metadata(__name__, globals(), __all__)
24
+ del fixup_module_metadata
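
Besides the yacs-style CfgNode, the exports above include the lazy-config API; a minimal sketch of how LazyCall and instantiate fit together (the Linear layer is just an arbitrary example target):

import torch.nn as nn
from detectron2.config import LazyCall, instantiate

layer_cfg = LazyCall(nn.Linear)(in_features=16, out_features=4)  # records the call, builds nothing yet
layer = instantiate(layer_cfg)                                   # constructs the nn.Linear from the recorded call
print(layer)  # Linear(in_features=16, out_features=4, bias=True)
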
detectron2/config/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (617 Bytes).
 
detectron2/config/__pycache__/compat.cpython-39.pyc ADDED
Binary file (7.7 kB).
 
detectron2/config/__pycache__/config.cpython-39.pyc ADDED
Binary file (7.57 kB).
 
detectron2/config/__pycache__/defaults.cpython-39.pyc ADDED
Binary file (9.22 kB).
 
detectron2/config/__pycache__/instantiate.cpython-39.pyc ADDED
Binary file (2.58 kB).
 
detectron2/config/__pycache__/lazy.cpython-39.pyc ADDED
Binary file (11.5 kB).
 
detectron2/config/compat.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ """
3
+ Backward compatibility of configs.
4
+
5
+ Instructions to bump version:
6
+ + It's not needed to bump version if new keys are added.
7
+ It's only needed when backward-incompatible changes happen
8
+ (i.e., some existing keys disappear, or the meaning of a key changes)
9
+ + To bump version, do the following:
10
+ 1. Increment _C.VERSION in defaults.py
11
+ 2. Add a converter in this file.
12
+
13
+ Each ConverterVX has a function "upgrade" which in-place upgrades config from X-1 to X,
14
+ and a function "downgrade" which in-place downgrades config from X to X-1
15
+
16
+ In each function, VERSION is left unchanged.
17
+
18
+ Each converter assumes that its input has the relevant keys
19
+ (i.e., the input is not a partial config).
20
+ 3. Run the tests (test_config.py) to make sure the upgrade & downgrade
21
+ functions are consistent.
22
+ """
23
+
24
+ import logging
25
+ from typing import List, Optional, Tuple
26
+
27
+ from .config import CfgNode as CN
28
+ from .defaults import _C
29
+
30
+ __all__ = ["upgrade_config", "downgrade_config"]
31
+
32
+
33
+ def upgrade_config(cfg: CN, to_version: Optional[int] = None) -> CN:
34
+ """
35
+ Upgrade a config from its current version to a newer version.
36
+
37
+ Args:
38
+ cfg (CfgNode):
39
+ to_version (int): defaults to the latest version.
40
+ """
41
+ cfg = cfg.clone()
42
+ if to_version is None:
43
+ to_version = _C.VERSION
44
+
45
+ assert cfg.VERSION <= to_version, "Cannot upgrade from v{} to v{}!".format(
46
+ cfg.VERSION, to_version
47
+ )
48
+ for k in range(cfg.VERSION, to_version):
49
+ converter = globals()["ConverterV" + str(k + 1)]
50
+ converter.upgrade(cfg)
51
+ cfg.VERSION = k + 1
52
+ return cfg
53
+
54
+
55
+ def downgrade_config(cfg: CN, to_version: int) -> CN:
56
+ """
57
+ Downgrade a config from its current version to an older version.
58
+
59
+ Args:
60
+ cfg (CfgNode):
61
+ to_version (int):
62
+
63
+ Note:
64
+ A general downgrade of arbitrary configs is not always possible due to the
65
+ different functionalities in different versions.
66
+ The purpose of downgrade is only to recover the defaults in old versions,
67
+ allowing it to load an old partial yaml config.
68
+ Therefore, the implementation only needs to fill in the default values
69
+ in the old version when a general downgrade is not possible.
70
+ """
71
+ cfg = cfg.clone()
72
+ assert cfg.VERSION >= to_version, "Cannot downgrade from v{} to v{}!".format(
73
+ cfg.VERSION, to_version
74
+ )
75
+ for k in range(cfg.VERSION, to_version, -1):
76
+ converter = globals()["ConverterV" + str(k)]
77
+ converter.downgrade(cfg)
78
+ cfg.VERSION = k - 1
79
+ return cfg
80
+
81
+
82
+ def guess_version(cfg: CN, filename: str) -> int:
83
+ """
84
+ Guess the version of a partial config where the VERSION field is not specified.
85
+ Returns the version, or the latest if cannot make a guess.
86
+
87
+ This makes it easier for users to migrate.
88
+ """
89
+ logger = logging.getLogger(__name__)
90
+
91
+ def _has(name: str) -> bool:
92
+ cur = cfg
93
+ for n in name.split("."):
94
+ if n not in cur:
95
+ return False
96
+ cur = cur[n]
97
+ return True
98
+
99
+ # Most users' partial configs have "MODEL.WEIGHT", so guess on it
100
+ ret = None
101
+ if _has("MODEL.WEIGHT") or _has("TEST.AUG_ON"):
102
+ ret = 1
103
+
104
+ if ret is not None:
105
+ logger.warning("Config '{}' has no VERSION. Assuming it to be v{}.".format(filename, ret))
106
+ else:
107
+ ret = _C.VERSION
108
+ logger.warning(
109
+ "Config '{}' has no VERSION. Assuming it to be compatible with latest v{}.".format(
110
+ filename, ret
111
+ )
112
+ )
113
+ return ret
114
+
115
+
116
+ def _rename(cfg: CN, old: str, new: str) -> None:
117
+ old_keys = old.split(".")
118
+ new_keys = new.split(".")
119
+
120
+ def _set(key_seq: List[str], val: str) -> None:
121
+ cur = cfg
122
+ for k in key_seq[:-1]:
123
+ if k not in cur:
124
+ cur[k] = CN()
125
+ cur = cur[k]
126
+ cur[key_seq[-1]] = val
127
+
128
+ def _get(key_seq: List[str]) -> CN:
129
+ cur = cfg
130
+ for k in key_seq:
131
+ cur = cur[k]
132
+ return cur
133
+
134
+ def _del(key_seq: List[str]) -> None:
135
+ cur = cfg
136
+ for k in key_seq[:-1]:
137
+ cur = cur[k]
138
+ del cur[key_seq[-1]]
139
+ if len(cur) == 0 and len(key_seq) > 1:
140
+ _del(key_seq[:-1])
141
+
142
+ _set(new_keys, _get(old_keys))
143
+ _del(old_keys)
144
+
145
+
146
+ class _RenameConverter:
147
+ """
148
+ A converter that handles simple rename.
149
+ """
150
+
151
+ RENAME: List[Tuple[str, str]] = [] # list of tuples of (old name, new name)
152
+
153
+ @classmethod
154
+ def upgrade(cls, cfg: CN) -> None:
155
+ for old, new in cls.RENAME:
156
+ _rename(cfg, old, new)
157
+
158
+ @classmethod
159
+ def downgrade(cls, cfg: CN) -> None:
160
+ for old, new in cls.RENAME[::-1]:
161
+ _rename(cfg, new, old)
162
+
163
+
164
+ class ConverterV1(_RenameConverter):
165
+ RENAME = [("MODEL.RPN_HEAD.NAME", "MODEL.RPN.HEAD_NAME")]
166
+
167
+
168
+ class ConverterV2(_RenameConverter):
169
+ """
170
+ A large bulk of rename, before public release.
171
+ """
172
+
173
+ RENAME = [
174
+ ("MODEL.WEIGHT", "MODEL.WEIGHTS"),
175
+ ("MODEL.PANOPTIC_FPN.SEMANTIC_LOSS_SCALE", "MODEL.SEM_SEG_HEAD.LOSS_WEIGHT"),
176
+ ("MODEL.PANOPTIC_FPN.RPN_LOSS_SCALE", "MODEL.RPN.LOSS_WEIGHT"),
177
+ ("MODEL.PANOPTIC_FPN.INSTANCE_LOSS_SCALE", "MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT"),
178
+ ("MODEL.PANOPTIC_FPN.COMBINE_ON", "MODEL.PANOPTIC_FPN.COMBINE.ENABLED"),
179
+ (
180
+ "MODEL.PANOPTIC_FPN.COMBINE_OVERLAP_THRESHOLD",
181
+ "MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH",
182
+ ),
183
+ (
184
+ "MODEL.PANOPTIC_FPN.COMBINE_STUFF_AREA_LIMIT",
185
+ "MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT",
186
+ ),
187
+ (
188
+ "MODEL.PANOPTIC_FPN.COMBINE_INSTANCES_CONFIDENCE_THRESHOLD",
189
+ "MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH",
190
+ ),
191
+ ("MODEL.ROI_HEADS.SCORE_THRESH", "MODEL.ROI_HEADS.SCORE_THRESH_TEST"),
192
+ ("MODEL.ROI_HEADS.NMS", "MODEL.ROI_HEADS.NMS_THRESH_TEST"),
193
+ ("MODEL.RETINANET.INFERENCE_SCORE_THRESHOLD", "MODEL.RETINANET.SCORE_THRESH_TEST"),
194
+ ("MODEL.RETINANET.INFERENCE_TOPK_CANDIDATES", "MODEL.RETINANET.TOPK_CANDIDATES_TEST"),
195
+ ("MODEL.RETINANET.INFERENCE_NMS_THRESHOLD", "MODEL.RETINANET.NMS_THRESH_TEST"),
196
+ ("TEST.DETECTIONS_PER_IMG", "TEST.DETECTIONS_PER_IMAGE"),
197
+ ("TEST.AUG_ON", "TEST.AUG.ENABLED"),
198
+ ("TEST.AUG_MIN_SIZES", "TEST.AUG.MIN_SIZES"),
199
+ ("TEST.AUG_MAX_SIZE", "TEST.AUG.MAX_SIZE"),
200
+ ("TEST.AUG_FLIP", "TEST.AUG.FLIP"),
201
+ ]
202
+
203
+ @classmethod
204
+ def upgrade(cls, cfg: CN) -> None:
205
+ super().upgrade(cfg)
206
+
207
+ if cfg.MODEL.META_ARCHITECTURE == "RetinaNet":
208
+ _rename(
209
+ cfg, "MODEL.RETINANET.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS"
210
+ )
211
+ _rename(cfg, "MODEL.RETINANET.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
212
+ del cfg["MODEL"]["RPN"]["ANCHOR_SIZES"]
213
+ del cfg["MODEL"]["RPN"]["ANCHOR_ASPECT_RATIOS"]
214
+ else:
215
+ _rename(cfg, "MODEL.RPN.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS")
216
+ _rename(cfg, "MODEL.RPN.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
217
+ del cfg["MODEL"]["RETINANET"]["ANCHOR_SIZES"]
218
+ del cfg["MODEL"]["RETINANET"]["ANCHOR_ASPECT_RATIOS"]
219
+ del cfg["MODEL"]["RETINANET"]["ANCHOR_STRIDES"]
220
+
221
+ @classmethod
222
+ def downgrade(cls, cfg: CN) -> None:
223
+ super().downgrade(cfg)
224
+
225
+ _rename(cfg, "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS", "MODEL.RPN.ANCHOR_ASPECT_RATIOS")
226
+ _rename(cfg, "MODEL.ANCHOR_GENERATOR.SIZES", "MODEL.RPN.ANCHOR_SIZES")
227
+ cfg.MODEL.RETINANET.ANCHOR_ASPECT_RATIOS = cfg.MODEL.RPN.ANCHOR_ASPECT_RATIOS
228
+ cfg.MODEL.RETINANET.ANCHOR_SIZES = cfg.MODEL.RPN.ANCHOR_SIZES
229
+ cfg.MODEL.RETINANET.ANCHOR_STRIDES = [] # this is not used anywhere in any version
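
The converters above are all built on the _rename helper; here is a standalone sketch of the same dotted-path rename pattern using plain dicts in place of CfgNode (it skips the empty-parent cleanup that _del performs):

def rename(cfg, old, new):
    def get(d, path):
        for p in path.split("."):
            d = d[p]
        return d
    def put(d, path, val):
        keys = path.split(".")
        for p in keys[:-1]:
            d = d.setdefault(p, {})
        d[keys[-1]] = val
    def delete(d, path):
        keys = path.split(".")
        for p in keys[:-1]:
            d = d[p]
        del d[keys[-1]]
    put(cfg, new, get(cfg, old))
    delete(cfg, old)

cfg = {"MODEL": {"WEIGHT": "x.pkl"}}
rename(cfg, "MODEL.WEIGHT", "MODEL.WEIGHTS")  # the v1 -> v2 rename from ConverterV2
print(cfg)  # {'MODEL': {'WEIGHTS': 'x.pkl'}}
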
detectron2/config/config.py ADDED
@@ -0,0 +1,249 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ import functools
5
+ import inspect
6
+ import logging
7
+ from fvcore.common.config import CfgNode as _CfgNode
8
+
9
+ from detectron2.utils.file_io import PathManager
10
+
11
+
12
+ class CfgNode(_CfgNode):
13
+ """
14
+ The same as `fvcore.common.config.CfgNode`, but different in:
15
+
16
+ 1. Use unsafe yaml loading by default.
17
+ Note that this may lead to arbitrary code execution: you must not
18
+ load a config file from untrusted sources before manually inspecting
19
+ the content of the file.
20
+ 2. Support config versioning.
21
+ When attempting to merge an old config, it will convert the old config automatically.
22
+ """
23
+
24
+ @classmethod
25
+ def _open_cfg(cls, filename):
26
+ return PathManager.open(filename, "r")
27
+
28
+ # Note that the default value of allow_unsafe is changed to True
29
+ def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None:
30
+ assert PathManager.isfile(cfg_filename), f"Config file '{cfg_filename}' does not exist!"
31
+ loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
32
+ loaded_cfg = type(self)(loaded_cfg)
33
+
34
+ # defaults.py needs to import CfgNode
35
+ from .defaults import _C
36
+
37
+ latest_ver = _C.VERSION
38
+ assert (
39
+ latest_ver == self.VERSION
40
+ ), "CfgNode.merge_from_file is only allowed on a config object of latest version!"
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+ loaded_ver = loaded_cfg.get("VERSION", None)
45
+ if loaded_ver is None:
46
+ from .compat import guess_version
47
+
48
+ loaded_ver = guess_version(loaded_cfg, cfg_filename)
49
+ assert loaded_ver <= self.VERSION, "Cannot merge a v{} config into a v{} config.".format(
50
+ loaded_ver, self.VERSION
51
+ )
52
+
53
+ if loaded_ver == self.VERSION:
54
+ self.merge_from_other_cfg(loaded_cfg)
55
+ else:
56
+ # compat.py needs to import CfgNode
57
+ from .compat import upgrade_config, downgrade_config
58
+
59
+ logger.warning(
60
+ "Loading an old v{} config file '{}' by automatically upgrading to v{}. "
61
+ "See docs/CHANGELOG.md for instructions to update your files.".format(
62
+ loaded_ver, cfg_filename, self.VERSION
63
+ )
64
+ )
65
+ # To convert, first obtain a full config at an old version
66
+ old_self = downgrade_config(self, to_version=loaded_ver)
67
+ old_self.merge_from_other_cfg(loaded_cfg)
68
+ new_config = upgrade_config(old_self)
69
+ self.clear()
70
+ self.update(new_config)
71
+
72
+ def dump(self, *args, **kwargs):
73
+ """
74
+ Returns:
75
+ str: a yaml string representation of the config
76
+ """
77
+ # to make it show up in docs
78
+ return super().dump(*args, **kwargs)
79
+
80
+
81
+ global_cfg = CfgNode()
82
+
83
+
84
+ def get_cfg() -> CfgNode:
85
+ """
86
+ Get a copy of the default config.
87
+
88
+ Returns:
89
+ a detectron2 CfgNode instance.
90
+ """
91
+ from .defaults import _C
92
+
93
+ return _C.clone()
94
+
95
+
96
+ def set_global_cfg(cfg: CfgNode) -> None:
97
+ """
98
+ Let the global config point to the given cfg.
99
+
100
+ Assume that the given "cfg" has the key "KEY", after calling
101
+ `set_global_cfg(cfg)`, the key can be accessed by:
102
+ ::
103
+ from detectron2.config import global_cfg
104
+ print(global_cfg.KEY)
105
+
106
+ By using a hacky global config, you can access these configs anywhere,
107
+ without having to pass the config object or the values deep into the code.
108
+ This is a hacky feature introduced for quick prototyping / research exploration.
109
+ """
110
+ global global_cfg
111
+ global_cfg.clear()
112
+ global_cfg.update(cfg)
113
+
114
+
115
+ def configurable(init_func=None, *, from_config=None):
116
+ """
117
+ Decorate a function or a class's __init__ method so that it can be called
118
+ with a :class:`CfgNode` object using a :func:`from_config` function that translates
119
+ :class:`CfgNode` to arguments.
120
+
121
+ Examples:
122
+ ::
123
+ # Usage 1: Decorator on __init__:
124
+ class A:
125
+ @configurable
126
+ def __init__(self, a, b=2, c=3):
127
+ pass
128
+
129
+ @classmethod
130
+ def from_config(cls, cfg): # 'cfg' must be the first argument
131
+ # Returns kwargs to be passed to __init__
132
+ return {"a": cfg.A, "b": cfg.B}
133
+
134
+ a1 = A(a=1, b=2) # regular construction
135
+ a2 = A(cfg) # construct with a cfg
136
+ a3 = A(cfg, b=3, c=4) # construct with extra overwrite
137
+
138
+ # Usage 2: Decorator on any function. Needs an extra from_config argument:
139
+ @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
140
+ def a_func(a, b=2, c=3):
141
+ pass
142
+
143
+ a1 = a_func(a=1, b=2) # regular call
144
+ a2 = a_func(cfg) # call with a cfg
145
+ a3 = a_func(cfg, b=3, c=4) # call with extra overwrite
146
+
147
+ Args:
148
+ init_func (callable): a class's ``__init__`` method in usage 1. The
149
+ class must have a ``from_config`` classmethod which takes `cfg` as
150
+ the first argument.
151
+ from_config (callable): the from_config function in usage 2. It must take `cfg`
152
+ as its first argument.
153
+ """
154
+
155
+ if init_func is not None:
156
+ assert (
157
+ inspect.isfunction(init_func)
158
+ and from_config is None
159
+ and init_func.__name__ == "__init__"
160
+ ), "Incorrect use of @configurable. Check API documentation for examples."
161
+
162
+ @functools.wraps(init_func)
163
+ def wrapped(self, *args, **kwargs):
164
+ try:
165
+ from_config_func = type(self).from_config
166
+ except AttributeError as e:
167
+ raise AttributeError(
168
+ "Class with @configurable must have a 'from_config' classmethod."
169
+ ) from e
170
+ if not inspect.ismethod(from_config_func):
171
+ raise TypeError("Class with @configurable must have a 'from_config' classmethod.")
172
+
173
+ if _called_with_cfg(*args, **kwargs):
174
+ explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
175
+ init_func(self, **explicit_args)
176
+ else:
177
+ init_func(self, *args, **kwargs)
178
+
179
+ return wrapped
180
+
181
+ else:
182
+ if from_config is None:
183
+ return configurable # @configurable() is made equivalent to @configurable
184
+ assert inspect.isfunction(
185
+ from_config
186
+ ), "from_config argument of configurable must be a function!"
187
+
188
+ def wrapper(orig_func):
189
+ @functools.wraps(orig_func)
190
+ def wrapped(*args, **kwargs):
191
+ if _called_with_cfg(*args, **kwargs):
192
+ explicit_args = _get_args_from_config(from_config, *args, **kwargs)
193
+ return orig_func(**explicit_args)
194
+ else:
195
+ return orig_func(*args, **kwargs)
196
+
197
+ return wrapped
198
+
199
+ return wrapper
200
+
201
+
202
+ def _get_args_from_config(from_config_func, *args, **kwargs):
203
+ """
204
+ Use `from_config` to obtain explicit arguments.
205
+
206
+ Returns:
207
+ dict: arguments to be used for cls.__init__
208
+ """
209
+ signature = inspect.signature(from_config_func)
210
+ if list(signature.parameters.keys())[0] != "cfg":
211
+ if inspect.isfunction(from_config_func):
212
+ name = from_config_func.__name__
213
+ else:
214
+ name = f"{from_config_func.__self__}.from_config"
215
+ raise TypeError(f"{name} must take 'cfg' as the first argument!")
216
+ support_var_arg = any(
217
+ param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
218
+ for param in signature.parameters.values()
219
+ )
220
+ if support_var_arg: # forward all arguments to from_config, if from_config accepts them
221
+ ret = from_config_func(*args, **kwargs)
222
+ else:
223
+ # forward supported arguments to from_config
224
+ supported_arg_names = set(signature.parameters.keys())
225
+ extra_kwargs = {}
226
+ for name in list(kwargs.keys()):
227
+ if name not in supported_arg_names:
228
+ extra_kwargs[name] = kwargs.pop(name)
229
+ ret = from_config_func(*args, **kwargs)
230
+ # forward the other arguments to __init__
231
+ ret.update(extra_kwargs)
232
+ return ret
233
+
234
+
235
+ def _called_with_cfg(*args, **kwargs):
236
+ """
237
+ Returns:
238
+ bool: whether the arguments contain CfgNode and should be considered
239
+ forwarded to from_config.
240
+ """
241
+ from omegaconf import DictConfig
242
+
243
+ if len(args) and isinstance(args[0], (_CfgNode, DictConfig)):
244
+ return True
245
+ if isinstance(kwargs.pop("cfg", None), (_CfgNode, DictConfig)):
246
+ return True
247
+ # `from_config`'s first argument is forced to be "cfg".
248
+ # So the above check covers all cases.
249
+ return False
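A minimal usage sketch of the config API added in this file (not part of the commit; the `Dummy` class and the `MODEL.DEVICE` override are illustrative placeholders, and the imports assume the package re-exports these names as upstream detectron2 does):

    from detectron2.config import get_cfg, set_global_cfg, configurable

    cfg = get_cfg()            # clone of the defaults declared in defaults.py below
    cfg.MODEL.DEVICE = "cpu"   # override one default value
    set_global_cfg(cfg)        # optionally expose it as detectron2.config.global_cfg

    class Dummy:
        @configurable
        def __init__(self, device="cuda"):
            self.device = device

        @classmethod
        def from_config(cls, cfg):
            # translate the CfgNode into __init__ keyword arguments
            return {"device": cfg.MODEL.DEVICE}

    d = Dummy(cfg)             # dispatched through from_config because a CfgNode is passed
    assert d.device == "cpu"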
detectron2/config/defaults.py ADDED
@@ -0,0 +1,786 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .config import CfgNode as CN
3
+
4
+ # -----------------------------------------------------------------------------
5
+ # Convention about Training / Test specific parameters
6
+ # -----------------------------------------------------------------------------
7
+ # Whenever an argument can be either used for training or for testing, the
8
+ # corresponding name will be post-fixed by a _TRAIN for a training parameter,
9
+ # or _TEST for a test-specific parameter.
10
+ # For example, the number of images during training will be
11
+ # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
12
+ # IMAGES_PER_BATCH_TEST
13
+
14
+ # -----------------------------------------------------------------------------
15
+ # Config definition
16
+ # -----------------------------------------------------------------------------
17
+
18
+ _C = CN()
19
+
20
+ # The version number, to upgrade from old configs to new ones if any
21
+ # changes happen. It's recommended to keep a VERSION in your config file.
22
+ _C.VERSION = 2
23
+
24
+ _C.MODEL = CN()
25
+ _C.MODEL.LOAD_PROPOSALS = False
26
+ _C.MODEL.MASK_ON = False
27
+ _C.MODEL.KEYPOINT_ON = False
28
+ _C.MODEL.DEVICE = "cuda"
29
+ _C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
30
+
31
+ # Path (a file path, or URL like detectron2://.., https://..) to a checkpoint file
32
+ # to be loaded to the model. You can find available models in the model zoo.
33
+ _C.MODEL.WEIGHTS = ""
34
+
35
+ # Values to be used for image normalization (BGR order, since INPUT.FORMAT defaults to BGR).
36
+ # To train on images of different number of channels, just set different mean & std.
37
+ # Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
38
+ _C.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675]
39
+ # When using pre-trained models in Detectron1 or any MSRA models,
40
+ # std has been absorbed into its conv1 weights, so the std needs to be set 1.
41
+ # Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
42
+ _C.MODEL.PIXEL_STD = [1.0, 1.0, 1.0]
43
+
44
+
45
+ # -----------------------------------------------------------------------------
46
+ # INPUT
47
+ # -----------------------------------------------------------------------------
48
+ _C.INPUT = CN()
49
+ # Size of the smallest side of the image during training
50
+ _C.INPUT.MIN_SIZE_TRAIN = (800,)
51
+ # Sample size of smallest side by choice or random selection from range given by
52
+ # INPUT.MIN_SIZE_TRAIN
53
+ _C.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice"
54
+ # Maximum size of the side of the image during training
55
+ _C.INPUT.MAX_SIZE_TRAIN = 1333
56
+ # Size of the smallest side of the image during testing. Set to zero to disable resize in testing.
57
+ _C.INPUT.MIN_SIZE_TEST = 800
58
+ # Maximum size of the side of the image during testing
59
+ _C.INPUT.MAX_SIZE_TEST = 1333
60
+ # Mode for flipping images used in data augmentation during training
61
+ # choose one of ["horizontal", "vertical", "none"]
62
+ _C.INPUT.RANDOM_FLIP = "horizontal"
63
+
64
+ # `True` if cropping is used for data augmentation during training
65
+ _C.INPUT.CROP = CN({"ENABLED": False})
66
+ # Cropping type. See documentation of `detectron2.data.transforms.RandomCrop` for explanation.
67
+ _C.INPUT.CROP.TYPE = "relative_range"
68
+ # Size of crop in range (0, 1] if CROP.TYPE is "relative" or "relative_range" and in number of
69
+ # pixels if CROP.TYPE is "absolute"
70
+ _C.INPUT.CROP.SIZE = [0.9, 0.9]
71
+
72
+
73
+ # Whether the model needs RGB, YUV, HSV etc.
74
+ # Should be one of the modes defined here, as we use PIL to read the image:
75
+ # https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes
76
+ # with BGR being the one exception. One can set image format to BGR, we will
77
+ # internally use RGB for conversion and flip the channels over
78
+ _C.INPUT.FORMAT = "BGR"
79
+ # The ground truth mask format that the model will use.
80
+ # Mask R-CNN supports either "polygon" or "bitmask" as ground truth.
81
+ _C.INPUT.MASK_FORMAT = "polygon" # alternative: "bitmask"
82
+
83
+ ################### Text Tokenizer from MSR-CLIP ##################
84
+ _C.INPUT.TEXT_TOKENIZER = "openai_bpe" # "bert-base-cased"
85
+
86
+ ################## Data Augmentation from MSR-CLIP ##################
87
+ _C.AUG = CN()
88
+ _C.AUG.SCALE = (0.08, 1.0)
89
+ _C.AUG.RATIO = (3.0/4.0, 4.0/3.0)
90
+ _C.AUG.COLOR_JITTER = [0.4, 0.4, 0.4, 0.1, 0.0]
91
+ _C.AUG.GRAY_SCALE = 0.0
92
+ _C.AUG.GAUSSIAN_BLUR = 0.0
93
+ _C.AUG.DROPBLOCK_LAYERS = [3, 4]
94
+ _C.AUG.DROPBLOCK_KEEP_PROB = 1.0
95
+ _C.AUG.DROPBLOCK_BLOCK_SIZE = 7
96
+ _C.AUG.MIXUP_PROB = 0.0
97
+ _C.AUG.MIXUP = 0.0
98
+ _C.AUG.MIXCUT = 0.0
99
+ _C.AUG.MIXCUT_MINMAX = []
100
+ _C.AUG.MIXUP_SWITCH_PROB = 0.5
101
+ _C.AUG.MIXUP_MODE = 'batch'
102
+ _C.AUG.MIXCUT_AND_MIXUP = False
103
+ _C.AUG.INTERPOLATION = 3
104
+ _C.AUG.USE_TIMM = False
105
+ _C.AUG.TIMM_AUG = CN(new_allowed=True)
106
+ _C.AUG.TIMM_AUG.USE_LOADER = False
107
+ _C.AUG.TIMM_AUG.USE_TRANSFORM = False
108
+
109
+ _C.AUG.TRAIN = CN()
110
+ _C.AUG.TRAIN.IMAGE_SIZE = [224, 224] # width * height, ex: 192 * 256
111
+ _C.AUG.TRAIN.MAX_SIZE = None # the maximum size for longer edge after resizing
112
+ _C.AUG.TEST = CN()
113
+ _C.AUG.TEST.IMAGE_SIZE = [224, 224] # width * height, ex: 192 * 256
114
+ _C.AUG.TEST.MAX_SIZE = None # the maximum size for longer edge after resizing
115
+ _C.AUG.TEST.CENTER_CROP = False
116
+ _C.AUG.TEST.INTERPOLATION = 3
117
+
118
+
119
+ # -----------------------------------------------------------------------------
120
+ # Dataset
121
+ # -----------------------------------------------------------------------------
122
+ _C.DATASETS = CN()
123
+ # List of the dataset names for training. Must be registered in DatasetCatalog
124
+ # Samples from these datasets will be merged and used as one dataset.
125
+ _C.DATASETS.TRAIN = ()
126
+ # List of the pre-computed proposal files for training, which must be consistent
127
+ # with datasets listed in DATASETS.TRAIN.
128
+ _C.DATASETS.PROPOSAL_FILES_TRAIN = ()
129
+ # Number of top scoring precomputed proposals to keep for training
130
+ _C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN = 2000
131
+ # List of the dataset names for testing. Must be registered in DatasetCatalog
132
+ _C.DATASETS.TEST = ()
133
+ # List of the pre-computed proposal files for test, which must be consistent
134
+ # with datasets listed in DATASETS.TEST.
135
+ _C.DATASETS.PROPOSAL_FILES_TEST = ()
136
+ # Number of top scoring precomputed proposals to keep for test
137
+ _C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST = 1000
138
+ ################## Data Loading from MSR-CLIP ##################
139
+ # List of dataset class names for training
140
+ _C.DATASETS.FACTORY_TRAIN = ()
141
+ # List of dataset folder for training
142
+ _C.DATASETS.PATH_TRAIN = ()
143
+ # List of the dataset names for auxiliary training, as present in paths_catalog.py
144
+ _C.DATASETS.AUX = ()
145
+ # List of dataset class names for auxiliary training
146
+ _C.DATASETS.FACTORY_AUX = ()
147
+ # List of dataset folder for auxiliary training
148
+ _C.DATASETS.PATH_AUX = ()
149
+ # List of dataset class names for testing
150
+ _C.DATASETS.FACTORY_TEST = ()
151
+ # List of dataset folder for testing
152
+ _C.DATASETS.PATH_TEST = ()
153
+ # Labelmap file to convert to tsv or for demo purpose
154
+ _C.DATASETS.LABELMAP_FILE = ''
155
+ _C.DATASETS.ATTR_LABELMAP_FILE = ''
156
+ _C.DATASETS.FILTERED_CLASSIFICATION_DATASETS = ''
157
+ # hierarchy file for test time score aggregation (developed on OpenImages)
158
+ _C.DATASETS.HIERARCHY_FILE = ''
159
+ # List of box extra fields for training/testing
160
+ # If given, will not infer from the other cfgs.
161
+ _C.DATASETS.BOX_EXTRA_FIELDS = ()
162
+
163
+ _C.DATASETS.NUM_CLASSES = 0
164
+ _C.DATASETS.ROOT = ''
165
+ _C.DATASETS.TRAIN_SET = 'train'
166
+ _C.DATASETS.VAL_SET = ''
167
+ _C.DATASETS.TEST_SET = 'val'
168
+
169
+ # The maximum total input sequence length after WordPiece tokenization
170
+ # Sequences longer than this will be truncated, and sequences shorter than this will be padded.
171
+ _C.DATASETS.MAX_SEQ_LENGTH = 35
172
+
173
+ # -----------------------------------------------------------------------------
174
+ # DataLoader
175
+ # -----------------------------------------------------------------------------
176
+ _C.DATALOADER = CN()
177
+ # Number of data loading threads
178
+ _C.DATALOADER.NUM_WORKERS = 4
179
+ # If True, each batch should contain only images for which the aspect ratio
180
+ # is compatible. This groups portrait images together, and landscape images
181
+ # are not batched with portrait images.
182
+ _C.DATALOADER.ASPECT_RATIO_GROUPING = True
183
+ # Options: TrainingSampler, RepeatFactorTrainingSampler
184
+ _C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"
185
+ # Repeat threshold for RepeatFactorTrainingSampler
186
+ _C.DATALOADER.REPEAT_THRESHOLD = 0.0
187
+ # If True, when working on datasets that have instance annotations, the
188
+ # training dataloader will filter out images without associated annotations
189
+ _C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True
190
+
191
+ # ---------------------------------------------------------------------------- #
192
+ # CLIP options
193
+ # ---------------------------------------------------------------------------- #
194
+ _C.MODEL.CLIP = CN()
195
+
196
+ _C.MODEL.CLIP.CROP_REGION_TYPE = "" # options: "GT", "RPN"
197
+ _C.MODEL.CLIP.BB_RPN_WEIGHTS = None # the weights of pretrained MaskRCNN
198
+ _C.MODEL.CLIP.IMS_PER_BATCH_TEST = 8 # the #images during inference per batch
199
+
200
+ _C.MODEL.CLIP.USE_TEXT_EMB_CLASSIFIER = False # if True, use the CLIP text embedding as the classifier's weights
201
+ _C.MODEL.CLIP.TEXT_EMB_PATH = None # "/mnt/output_storage/trained_models/lvis_cls_emb/lvis_1203_cls_emb.pth"
202
+ _C.MODEL.CLIP.OFFLINE_RPN_CONFIG = None # option: all configs of pretrained RPN
203
+ _C.MODEL.CLIP.NO_BOX_DELTA = False # if True, during inference, no box delta will be applied to region proposals
204
+
205
+ _C.MODEL.CLIP.BG_CLS_LOSS_WEIGHT = None # if not None, it is the loss weight for bg regions
206
+ _C.MODEL.CLIP.ONLY_SAMPLE_FG_PROPOSALS = False # if True, during training, ignore all bg proposals and only sample fg proposals
207
+ _C.MODEL.CLIP.MULTIPLY_RPN_SCORE = False # if True, during inference, multiply RPN scores with classification scores
208
+
209
+ _C.MODEL.CLIP.OPENSET_TEST_NUM_CLASSES = None # if an integer, it is #all_cls in test
210
+ _C.MODEL.CLIP.OPENSET_TEST_TEXT_EMB_PATH = None # if not None, enables the openset/zero-shot training, the category embeddings during test
211
+
212
+ _C.MODEL.CLIP.CLSS_TEMP = None # if None, dot product wo normalization & temperature; if float, normalization plus temperature
213
+ _C.MODEL.CLIP.RUN_CVPR_OVR = False # if True, train CVPR OVR model with their text embeddings
214
+ _C.MODEL.CLIP.FOCAL_SCALED_LOSS = None # if not None (float value for gamma), apply focal loss scaling idea to standard cross-entropy loss
215
+
216
+ _C.MODEL.CLIP.OFFLINE_RPN_NMS_THRESH = None # the threshold of NMS in offline RPN
217
+ _C.MODEL.CLIP.PRETRAIN_IMG_TXT_LEVEL = True # if True, pretrain model using image-text level matching
218
+ _C.MODEL.CLIP.PRETRAIN_ONLY_EOT = False # if True, use end-of-token emb to match region features, in image-text level matching
219
+ _C.MODEL.CLIP.PRETRAIN_RPN_REGIONS = None # if not None, the number of RPN regions per image during pretraining
220
+ _C.MODEL.CLIP.PRETRAIN_SAMPLE_REGIONS = None # if not None, the number of regions per image during pretraining after sampling, to avoid overfitting
221
+ _C.MODEL.CLIP.GATHER_GPUS = False # if True, gather tensors across GPUS to increase batch size
222
+ _C.MODEL.CLIP.GRID_REGIONS = False # if True, use grid boxes to extract grid features, instead of object proposals
223
+ _C.MODEL.CLIP.CONCEPT_POOL_EMB = None # if not None, it provides the file path of embs of concept pool and thus enables region-concept matching
224
+ _C.MODEL.CLIP.CONCEPT_THRES = None # if not None, the threshold to filter out the regions with low matching score with concept embs, dependent on temp (default: 0.01)
225
+
226
+ _C.MODEL.CLIP.OFFLINE_RPN_LSJ_PRETRAINED = False # if True, use large-scale jittering (LSJ) pretrained RPN
227
+ _C.MODEL.CLIP.TEACHER_RESNETS_DEPTH = 50 # the type of visual encoder of teacher model, such as ResNet 50, 101, 200 (a flag for 50x4)
228
+ _C.MODEL.CLIP.TEACHER_CONCEPT_POOL_EMB = None # if not None, it uses the same concept embedding as student model; otherwise, uses a separate embedding of teacher model
229
+ _C.MODEL.CLIP.TEACHER_POOLER_RESOLUTION = 14 # RoIpooling resolution of teacher model
230
+
231
+ _C.MODEL.CLIP.TEXT_EMB_DIM = 1024 # the dimension of precomputed class embeddings
232
+
233
+ # ---------------------------------------------------------------------------- #
234
+ # Backbone options
235
+ # ---------------------------------------------------------------------------- #
236
+ _C.MODEL.BACKBONE = CN()
237
+
238
+ _C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
239
+ # Freeze the first several stages so they are not trained.
240
+ # There are 5 stages in ResNet. The first is a convolution, and the following
241
+ # stages are each group of residual blocks.
242
+ _C.MODEL.BACKBONE.FREEZE_AT = 2
243
+
244
+ _C.MODEL.TEXT_BACKBONE = CN()
245
+ _C.MODEL.TEXT_BACKBONE.NAME = "build_clip_swin_text_backbone"
246
+
247
+
248
+ # ---------------------------------------------------------------------------- #
249
+ # FPN options
250
+ # ---------------------------------------------------------------------------- #
251
+ _C.MODEL.FPN = CN()
252
+ # Names of the input feature maps to be used by FPN
253
+ # They must have contiguous power of 2 strides
254
+ # e.g., ["res2", "res3", "res4", "res5"]
255
+ _C.MODEL.FPN.IN_FEATURES = []
256
+ _C.MODEL.FPN.OUT_CHANNELS = 256
257
+
258
+ # Options: "" (no norm), "GN"
259
+ _C.MODEL.FPN.NORM = ""
260
+
261
+ # Types for fusing the FPN top-down and lateral features. Can be either "sum" or "avg"
262
+ _C.MODEL.FPN.FUSE_TYPE = "sum"
263
+
264
+
265
+ # ---------------------------------------------------------------------------- #
266
+ # Proposal generator options
267
+ # ---------------------------------------------------------------------------- #
268
+ _C.MODEL.PROPOSAL_GENERATOR = CN()
269
+ # Current proposal generators include "RPN", "RRPN" and "PrecomputedProposals"
270
+ _C.MODEL.PROPOSAL_GENERATOR.NAME = "RPN"
271
+ # Proposal height and width both need to be greater than MIN_SIZE
272
+ # (at the scale used during training or inference)
273
+ _C.MODEL.PROPOSAL_GENERATOR.MIN_SIZE = 0
274
+
275
+
276
+ # ---------------------------------------------------------------------------- #
277
+ # Anchor generator options
278
+ # ---------------------------------------------------------------------------- #
279
+ _C.MODEL.ANCHOR_GENERATOR = CN()
280
+ # The generator can be any name in the ANCHOR_GENERATOR registry
281
+ _C.MODEL.ANCHOR_GENERATOR.NAME = "DefaultAnchorGenerator"
282
+ # Anchor sizes (i.e. sqrt of area) in absolute pixels w.r.t. the network input.
283
+ # Format: list[list[float]]. SIZES[i] specifies the list of sizes to use for
284
+ # IN_FEATURES[i]; len(SIZES) must be equal to len(IN_FEATURES) or 1.
285
+ # When len(SIZES) == 1, SIZES[0] is used for all IN_FEATURES.
286
+ _C.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256, 512]]
287
+ # Anchor aspect ratios. For each area given in `SIZES`, anchors with different aspect
288
+ # ratios are generated by an anchor generator.
289
+ # Format: list[list[float]]. ASPECT_RATIOS[i] specifies the list of aspect ratios (H/W)
290
+ # to use for IN_FEATURES[i]; len(ASPECT_RATIOS) == len(IN_FEATURES) must be true,
291
+ # or len(ASPECT_RATIOS) == 1 is true and aspect ratio list ASPECT_RATIOS[0] is used
292
+ # for all IN_FEATURES.
293
+ _C.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.5, 1.0, 2.0]]
294
+ # Anchor angles.
295
+ # list[list[float]], the angle in degrees, for each input feature map.
296
+ # ANGLES[i] specifies the list of angles for IN_FEATURES[i].
297
+ _C.MODEL.ANCHOR_GENERATOR.ANGLES = [[-90, 0, 90]]
298
+ # Relative offset between the center of the first anchor and the top-left corner of the image
299
+ # Value has to be in [0, 1). Recommend to use 0.5, which means half stride.
300
+ # The value is not expected to affect model accuracy.
301
+ _C.MODEL.ANCHOR_GENERATOR.OFFSET = 0.0
302
+
303
+ # ---------------------------------------------------------------------------- #
304
+ # RPN options
305
+ # ---------------------------------------------------------------------------- #
306
+ _C.MODEL.RPN = CN()
307
+ _C.MODEL.RPN.HEAD_NAME = "StandardRPNHead" # used by RPN_HEAD_REGISTRY
308
+
309
+ # Names of the input feature maps to be used by RPN
310
+ # e.g., ["p2", "p3", "p4", "p5", "p6"] for FPN
311
+ _C.MODEL.RPN.IN_FEATURES = ["res4"]
312
+ # Remove RPN anchors that go outside the image by BOUNDARY_THRESH pixels
313
+ # Set to -1 or a large value, e.g. 100000, to disable pruning anchors
314
+ _C.MODEL.RPN.BOUNDARY_THRESH = -1
315
+ # IOU overlap ratios [BG_IOU_THRESHOLD, FG_IOU_THRESHOLD]
316
+ # Minimum overlap required between an anchor and ground-truth box for the
317
+ # (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD
318
+ # ==> positive RPN example: 1)
319
+ # Maximum overlap allowed between an anchor and ground-truth box for the
320
+ # (anchor, gt box) pair to be a negative examples (IoU < BG_IOU_THRESHOLD
321
+ # ==> negative RPN example: 0)
322
+ # Anchors with overlap in between (BG_IOU_THRESHOLD <= IoU < FG_IOU_THRESHOLD)
323
+ # are ignored (-1)
324
+ _C.MODEL.RPN.IOU_THRESHOLDS = [0.3, 0.7]
325
+ _C.MODEL.RPN.IOU_LABELS = [0, -1, 1]
326
+ # Number of regions per image used to train RPN
327
+ _C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256
328
+ # Target fraction of foreground (positive) examples per RPN minibatch
329
+ _C.MODEL.RPN.POSITIVE_FRACTION = 0.5
330
+ # Options are: "smooth_l1", "giou"
331
+ _C.MODEL.RPN.BBOX_REG_LOSS_TYPE = "smooth_l1"
332
+ _C.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = 1.0
333
+ # Weights on (dx, dy, dw, dh) for normalizing RPN anchor regression targets
334
+ _C.MODEL.RPN.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
335
+ # The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
336
+ _C.MODEL.RPN.SMOOTH_L1_BETA = 0.0
337
+ _C.MODEL.RPN.LOSS_WEIGHT = 1.0
338
+ # Number of top scoring RPN proposals to keep before applying NMS
339
+ # When FPN is used, this is *per FPN level* (not total)
340
+ _C.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 12000
341
+ _C.MODEL.RPN.PRE_NMS_TOPK_TEST = 6000
342
+ # Number of top scoring RPN proposals to keep after applying NMS
343
+ # When FPN is used, this limit is applied per level and then again to the union
344
+ # of proposals from all levels
345
+ # NOTE: When FPN is used, the meaning of this config is different from Detectron1.
346
+ # It means per-batch topk in Detectron1, but per-image topk here.
347
+ # See the "find_top_rpn_proposals" function for details.
348
+ _C.MODEL.RPN.POST_NMS_TOPK_TRAIN = 2000
349
+ _C.MODEL.RPN.POST_NMS_TOPK_TEST = 1000
350
+ # NMS threshold used on RPN proposals
351
+ _C.MODEL.RPN.NMS_THRESH = 0.7
352
+ # Set this to -1 to use the same number of output channels as input channels.
353
+ _C.MODEL.RPN.CONV_DIMS = [-1]
354
+
355
+ # ---------------------------------------------------------------------------- #
356
+ # ROI HEADS options
357
+ # ---------------------------------------------------------------------------- #
358
+ _C.MODEL.ROI_HEADS = CN()
359
+ _C.MODEL.ROI_HEADS.NAME = "Res5ROIHeads"
360
+ # Number of foreground classes
361
+ _C.MODEL.ROI_HEADS.NUM_CLASSES = 80
362
+ # Names of the input feature maps to be used by ROI heads
363
+ # Currently all heads (box, mask, ...) use the same input feature map list
364
+ # e.g., ["p2", "p3", "p4", "p5"] is commonly used for FPN
365
+ _C.MODEL.ROI_HEADS.IN_FEATURES = ["res4"]
366
+ # IOU overlap ratios [IOU_THRESHOLD]
367
+ # Overlap threshold for an RoI to be considered background (if < IOU_THRESHOLD)
368
+ # Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD)
369
+ _C.MODEL.ROI_HEADS.IOU_THRESHOLDS = [0.5]
370
+ _C.MODEL.ROI_HEADS.IOU_LABELS = [0, 1]
371
+ # RoI minibatch size *per image* (number of regions of interest [ROIs])
372
+ # Total number of RoIs per training minibatch =
373
+ # ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH
374
+ # E.g., a common configuration is: 512 * 16 = 8192
375
+ _C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
376
+ # Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0)
377
+ _C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25
378
+
379
+ # Only used on test mode
380
+
381
+ # Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to
382
+ # balance obtaining high recall with not having too many low precision
383
+ # detections that will slow down inference post processing steps (like NMS)
384
+ # A default threshold of 0.0 increases AP by ~0.2-0.3 but significantly slows down
385
+ # inference.
386
+ _C.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05
387
+ # Overlap threshold used for non-maximum suppression (suppress boxes with
388
+ # IoU >= this threshold)
389
+ _C.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5
390
+ # If True, augment proposals with ground-truth boxes before sampling proposals to
391
+ # train ROI heads.
392
+ _C.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT = True
393
+
394
+ # Use soft NMS instead of standard NMS if set to True
395
+ _C.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
396
+ # See soft NMS paper for definition of these options
397
+ _C.MODEL.ROI_HEADS.SOFT_NMS_METHOD = "gaussian" # "linear"
398
+ _C.MODEL.ROI_HEADS.SOFT_NMS_SIGMA = 0.5
399
+ # For the linear_threshold we use NMS_THRESH_TEST
400
+ _C.MODEL.ROI_HEADS.SOFT_NMS_PRUNE = 0.001
401
+
402
+ # ---------------------------------------------------------------------------- #
403
+ # Box Head
404
+ # ---------------------------------------------------------------------------- #
405
+ _C.MODEL.ROI_BOX_HEAD = CN()
406
+ # C4 doesn't use the head name option
407
+ # Options for non-C4 models: FastRCNNConvFCHead,
408
+ _C.MODEL.ROI_BOX_HEAD.NAME = ""
409
+ # Options are: "smooth_l1", "giou"
410
+ _C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE = "smooth_l1"
411
+ # The final scaling coefficient on the box regression loss, used to balance the magnitude of its
412
+ # gradients with other losses in the model. See also `MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT`.
413
+ _C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT = 1.0
414
+ # Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets
415
+ # These are empirically chosen to approximately lead to unit variance targets
416
+ _C.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0)
417
+ # The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
418
+ _C.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA = 0.0
419
+ _C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14
420
+ _C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0
421
+ # Type of pooling operation applied to the incoming feature map for each RoI
422
+ _C.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignV2"
423
+
424
+ _C.MODEL.ROI_BOX_HEAD.NUM_FC = 0
425
+ # Hidden layer dimension for FC layers in the RoI box head
426
+ _C.MODEL.ROI_BOX_HEAD.FC_DIM = 1024
427
+ _C.MODEL.ROI_BOX_HEAD.NUM_CONV = 0
428
+ # Channel dimension for Conv layers in the RoI box head
429
+ _C.MODEL.ROI_BOX_HEAD.CONV_DIM = 256
430
+ # Normalization method for the convolution layers.
431
+ # Options: "" (no norm), "GN", "SyncBN".
432
+ _C.MODEL.ROI_BOX_HEAD.NORM = ""
433
+ # Whether to use class agnostic for bbox regression
434
+ _C.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG = False
435
+ # If true, RoI heads use bounding boxes predicted by the box head rather than proposal boxes.
436
+ _C.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES = False
437
+
438
+ # ---------------------------------------------------------------------------- #
439
+ # Cascaded Box Head
440
+ # ---------------------------------------------------------------------------- #
441
+ _C.MODEL.ROI_BOX_CASCADE_HEAD = CN()
442
+ # The number of cascade stages is implicitly defined by the length of the following two configs.
443
+ _C.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS = (
444
+ (10.0, 10.0, 5.0, 5.0),
445
+ (20.0, 20.0, 10.0, 10.0),
446
+ (30.0, 30.0, 15.0, 15.0),
447
+ )
448
+ _C.MODEL.ROI_BOX_CASCADE_HEAD.IOUS = (0.5, 0.6, 0.7)
449
+
450
+
451
+ # ---------------------------------------------------------------------------- #
452
+ # Mask Head
453
+ # ---------------------------------------------------------------------------- #
454
+ _C.MODEL.ROI_MASK_HEAD = CN()
455
+ _C.MODEL.ROI_MASK_HEAD.NAME = "MaskRCNNConvUpsampleHead"
456
+ _C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14
457
+ _C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0
458
+ _C.MODEL.ROI_MASK_HEAD.NUM_CONV = 0 # The number of convs in the mask head
459
+ _C.MODEL.ROI_MASK_HEAD.CONV_DIM = 256
460
+ # Normalization method for the convolution layers.
461
+ # Options: "" (no norm), "GN", "SyncBN".
462
+ _C.MODEL.ROI_MASK_HEAD.NORM = ""
463
+ # Whether to use class agnostic for mask prediction
464
+ _C.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK = False
465
+ # Type of pooling operation applied to the incoming feature map for each RoI
466
+ _C.MODEL.ROI_MASK_HEAD.POOLER_TYPE = "ROIAlignV2"
467
+
468
+
469
+ # ---------------------------------------------------------------------------- #
470
+ # Keypoint Head
471
+ # ---------------------------------------------------------------------------- #
472
+ _C.MODEL.ROI_KEYPOINT_HEAD = CN()
473
+ _C.MODEL.ROI_KEYPOINT_HEAD.NAME = "KRCNNConvDeconvUpsampleHead"
474
+ _C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14
475
+ _C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0
476
+ _C.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS = tuple(512 for _ in range(8))
477
+ _C.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 17 # 17 is the number of keypoints in COCO.
478
+
479
+ # Images with too few (or no) keypoints are excluded from training.
480
+ _C.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE = 1
481
+ # Normalize by the total number of visible keypoints in the minibatch if True.
482
+ # Otherwise, normalize by the total number of keypoints that could ever exist
483
+ # in the minibatch.
484
+ # The keypoint softmax loss is only calculated on visible keypoints.
485
+ # Since the number of visible keypoints can vary significantly between
486
+ # minibatches, this has the effect of up-weighting the importance of
487
+ # minibatches with few visible keypoints. (Imagine the extreme case of
488
+ # only one visible keypoint versus N: in the case of N, each one
489
+ # contributes 1/N to the gradient compared to the single keypoint
490
+ # determining the gradient direction). Instead, we can normalize the
491
+ # loss by the total number of keypoints, if it were the case that all
492
+ # keypoints were visible in a full minibatch. (Returning to the example,
493
+ # this means that the one visible keypoint contributes as much as each
494
+ # of the N keypoints.)
495
+ _C.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS = True
496
+ # Multi-task loss weight to use for keypoints
497
+ # Recommended values:
498
+ # - use 1.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is True
499
+ # - use 4.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is False
500
+ _C.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT = 1.0
501
+ # Type of pooling operation applied to the incoming feature map for each RoI
502
+ _C.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE = "ROIAlignV2"
503
+
504
+ # ---------------------------------------------------------------------------- #
505
+ # Semantic Segmentation Head
506
+ # ---------------------------------------------------------------------------- #
507
+ _C.MODEL.SEM_SEG_HEAD = CN()
508
+ _C.MODEL.SEM_SEG_HEAD.NAME = "SemSegFPNHead"
509
+ _C.MODEL.SEM_SEG_HEAD.IN_FEATURES = ["p2", "p3", "p4", "p5"]
510
+ # Label in the semantic segmentation ground truth that is ignored, i.e., no loss is calculated for
511
+ # the corresponding pixel.
512
+ _C.MODEL.SEM_SEG_HEAD.IGNORE_VALUE = 255
513
+ # Number of classes in the semantic segmentation head
514
+ _C.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 54
515
+ # Number of channels in the 3x3 convs inside semantic-FPN heads.
516
+ _C.MODEL.SEM_SEG_HEAD.CONVS_DIM = 128
517
+ # Outputs from semantic-FPN heads are up-scaled to the COMMON_STRIDE stride.
518
+ _C.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
519
+ # Normalization method for the convolution layers. Options: "" (no norm), "GN".
520
+ _C.MODEL.SEM_SEG_HEAD.NORM = "GN"
521
+ _C.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT = 1.0
522
+
523
+ _C.MODEL.PANOPTIC_FPN = CN()
524
+ # Scaling of all losses from instance detection / segmentation head.
525
+ _C.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT = 1.0
526
+
527
+ # options when combining instance & semantic segmentation outputs
528
+ _C.MODEL.PANOPTIC_FPN.COMBINE = CN({"ENABLED": True}) # "COMBINE.ENABLED" is deprecated & not used
529
+ _C.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH = 0.5
530
+ _C.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT = 4096
531
+ _C.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.5
532
+
533
+
534
+ # ---------------------------------------------------------------------------- #
535
+ # RetinaNet Head
536
+ # ---------------------------------------------------------------------------- #
537
+ _C.MODEL.RETINANET = CN()
538
+
539
+ # This is the number of foreground classes.
540
+ _C.MODEL.RETINANET.NUM_CLASSES = 80
541
+
542
+ _C.MODEL.RETINANET.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"]
543
+
544
+ # Convolutions to use in the cls and bbox tower
545
+ # NOTE: this doesn't include the last conv for logits
546
+ _C.MODEL.RETINANET.NUM_CONVS = 4
547
+
548
+ # IoU overlap ratio [bg, fg] for labeling anchors.
549
+ # Anchors with < bg are labeled negative (0)
550
+ # Anchors with >= bg and < fg are ignored (-1)
551
+ # Anchors with >= fg are labeled positive (1)
552
+ _C.MODEL.RETINANET.IOU_THRESHOLDS = [0.4, 0.5]
553
+ _C.MODEL.RETINANET.IOU_LABELS = [0, -1, 1]
554
+
555
+ # Prior prob for rare case (i.e. foreground) at the beginning of training.
556
+ # This is used to set the bias for the logits layer of the classifier subnet.
557
+ # This improves training stability in the case of heavy class imbalance.
558
+ _C.MODEL.RETINANET.PRIOR_PROB = 0.01
559
+
560
+ # Inference cls score threshold, only anchors with score > INFERENCE_TH are
561
+ # considered for inference (to improve speed)
562
+ _C.MODEL.RETINANET.SCORE_THRESH_TEST = 0.05
563
+ # Select topk candidates before NMS
564
+ _C.MODEL.RETINANET.TOPK_CANDIDATES_TEST = 1000
565
+ _C.MODEL.RETINANET.NMS_THRESH_TEST = 0.5
566
+
567
+ # Weights on (dx, dy, dw, dh) for normalizing Retinanet anchor regression targets
568
+ _C.MODEL.RETINANET.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
569
+
570
+ # Loss parameters
571
+ _C.MODEL.RETINANET.FOCAL_LOSS_GAMMA = 2.0
572
+ _C.MODEL.RETINANET.FOCAL_LOSS_ALPHA = 0.25
573
+ _C.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA = 0.1
574
+ # Options are: "smooth_l1", "giou"
575
+ _C.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = "smooth_l1"
576
+
577
+ # One of BN, SyncBN, FrozenBN, GN
578
+ # Only supports GN until unshared norm is implemented
579
+ _C.MODEL.RETINANET.NORM = ""
580
+
581
+
582
+ # ---------------------------------------------------------------------------- #
583
+ # ResNe[X]t options (ResNets = {ResNet, ResNeXt})
584
+ # Note that parts of a resnet may be used for both the backbone and the head
585
+ # These options apply to both
586
+ # ---------------------------------------------------------------------------- #
587
+ _C.MODEL.RESNETS = CN()
588
+
589
+ _C.MODEL.RESNETS.DEPTH = 50
590
+ _C.MODEL.RESNETS.OUT_FEATURES = ["res4"] # res4 for C4 backbone, res2..5 for FPN backbone
591
+
592
+ # Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt
593
+ _C.MODEL.RESNETS.NUM_GROUPS = 1
594
+
595
+ # Options: FrozenBN, GN, "SyncBN", "BN"
596
+ _C.MODEL.RESNETS.NORM = "FrozenBN"
597
+
598
+ # Baseline width of each group.
599
+ # Scaling this parameter will scale the width of all bottleneck layers.
600
+ _C.MODEL.RESNETS.WIDTH_PER_GROUP = 64
601
+
602
+ # Place the stride 2 conv on the 1x1 filter
603
+ # Use True only for the original MSRA ResNet; use False for C2 and Torch models
604
+ _C.MODEL.RESNETS.STRIDE_IN_1X1 = True
605
+
606
+ # Apply dilation in stage "res5"
607
+ _C.MODEL.RESNETS.RES5_DILATION = 1
608
+
609
+ # Output width of res2. Scaling this parameter will scale the width of all 1x1 convs in ResNet
610
+ # For R18 and R34, this needs to be set to 64
611
+ _C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
612
+ _C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
613
+
614
+ # Apply Deformable Convolution in stages
615
+ # Specify if apply deform_conv on Res2, Res3, Res4, Res5
616
+ _C.MODEL.RESNETS.DEFORM_ON_PER_STAGE = [False, False, False, False]
617
+ # Use True to use modulated deform_conv (DeformableV2, https://arxiv.org/abs/1811.11168);
618
+ # Use False for DeformableV1.
619
+ _C.MODEL.RESNETS.DEFORM_MODULATED = False
620
+ # Number of groups in deformable conv.
621
+ _C.MODEL.RESNETS.DEFORM_NUM_GROUPS = 1
622
+
623
+
624
+ # ---------------------------------------------------------------------------- #
625
+ # Swin options
626
+ # Note that parts of a resnet may be used for both the backbone and the head
627
+ # These options apply to both
628
+ # ---------------------------------------------------------------------------- #
629
+ _C.MODEL.SPEC = CN()
630
+ _C.MODEL.SPEC.EMBED_DIM = 512
631
+
632
+ _C.MODEL.SPEC.VISION = CN()
633
+ _C.MODEL.SPEC.VISION.PATCH_SIZE = 4
634
+ _C.MODEL.SPEC.VISION.IN_CHANS = 3
635
+ _C.MODEL.SPEC.VISION.EMBED_DIM = 96
636
+ _C.MODEL.SPEC.VISION.DEPTHS = [2, 2, 6, 2]
637
+ _C.MODEL.SPEC.VISION.NUM_HEADS = [3, 6, 12, 24]
638
+ _C.MODEL.SPEC.VISION.WINDOW_SIZE = 7
639
+ _C.MODEL.SPEC.VISION.MLP_RATIO = 4.
640
+ _C.MODEL.SPEC.VISION.DROP_RATE = .0
641
+ _C.MODEL.SPEC.VISION.ATTN_DROP_RATE = .0
642
+ _C.MODEL.SPEC.VISION.DROP_PATH_RATE = .0
643
+ _C.MODEL.SPEC.VISION.QKV_BIAS = True
644
+ _C.MODEL.SPEC.VISION.QK_SCALE = False
645
+ _C.MODEL.SPEC.VISION.APE = False
646
+ _C.MODEL.SPEC.VISION.PATCH_NORM = True
647
+ _C.MODEL.SPEC.VISION.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"]
648
+
649
+ _C.MODEL.SPEC.TEXT = CN()
650
+ _C.MODEL.SPEC.TEXT.NAME = 'transformer'
651
+ _C.MODEL.SPEC.TEXT.LOAD_PRETRAINED = False
652
+ _C.MODEL.SPEC.TEXT.PRETRAINED = ''
653
+ _C.MODEL.SPEC.TEXT.TOKENIZER = 'clip'
654
+ _C.MODEL.SPEC.TEXT.CONTEXT_LENGTH = 77
655
+ _C.MODEL.SPEC.TEXT.WIDTH = 512
656
+ _C.MODEL.SPEC.TEXT.HEADS = 8
657
+ _C.MODEL.SPEC.TEXT.LAYERS = 12
658
+ _C.MODEL.SPEC.TEXT.AUTOGRESSIVE = True
659
+
660
+ # ---------------------------------------------------------------------------- #
661
+ # Solver
662
+ # ---------------------------------------------------------------------------- #
663
+ _C.SOLVER = CN()
664
+
665
+ # See detectron2/solver/build.py for LR scheduler options
666
+ _C.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
667
+
668
+ _C.SOLVER.MAX_ITER = 40000
669
+
670
+ _C.SOLVER.BASE_LR = 0.001
671
+
672
+ _C.SOLVER.MOMENTUM = 0.9
673
+
674
+ _C.SOLVER.NESTEROV = False
675
+
676
+ _C.SOLVER.WEIGHT_DECAY = 0.0001
677
+ # The weight decay that's applied to parameters of normalization layers
678
+ # (typically the affine transformation)
679
+ _C.SOLVER.WEIGHT_DECAY_NORM = 0.0
680
+
681
+ _C.SOLVER.GAMMA = 0.1
682
+ # The iteration number to decrease learning rate by GAMMA.
683
+ _C.SOLVER.STEPS = (30000,)
684
+
685
+ _C.SOLVER.WARMUP_FACTOR = 1.0 / 1000
686
+ _C.SOLVER.WARMUP_ITERS = 1000
687
+ _C.SOLVER.WARMUP_METHOD = "linear"
688
+
689
+ # Save a checkpoint after every this number of iterations
690
+ _C.SOLVER.CHECKPOINT_PERIOD = 5000
691
+
692
+ # Number of images per batch across all machines. This is also the number
693
+ # of training images per step (i.e. per iteration). If we use 16 GPUs
694
+ # and IMS_PER_BATCH = 32, each GPU will see 2 images per batch.
695
+ # May be adjusted automatically if REFERENCE_WORLD_SIZE is set.
696
+ _C.SOLVER.IMS_PER_BATCH = 16
697
+
698
+ # The reference number of workers (GPUs) this config is meant to train with.
699
+ # It takes no effect when set to 0.
700
+ # With a non-zero value, it will be used by DefaultTrainer to compute a desired
701
+ # per-worker batch size, and then scale the other related configs (total batch size,
702
+ # learning rate, etc) to match the per-worker batch size.
703
+ # See documentation of `DefaultTrainer.auto_scale_workers` for details:
704
+ _C.SOLVER.REFERENCE_WORLD_SIZE = 0
705
+
706
+ # Detectron v1 (and previous detection code) used a 2x higher LR and 0 WD for
707
+ # biases. This is not useful (at least for recent models). You should avoid
708
+ # changing these and they exist only to reproduce Detectron v1 training if
709
+ # desired.
710
+ _C.SOLVER.BIAS_LR_FACTOR = 1.0
711
+ _C.SOLVER.WEIGHT_DECAY_BIAS = _C.SOLVER.WEIGHT_DECAY
712
+
713
+ # Gradient clipping
714
+ _C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False})
715
+ # Type of gradient clipping, currently 2 values are supported:
716
+ # - "value": the absolute values of elements of each gradients are clipped
717
+ # - "norm": the norm of the gradient for each parameter is clipped thus
718
+ # affecting all elements in the parameter
719
+ _C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value"
720
+ # Maximum absolute value used for clipping gradients
721
+ _C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0
722
+ # Floating point number p for L-p norm to be used with the "norm"
723
+ # gradient clipping type; for L-inf, please specify .inf
724
+ _C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0
725
+
726
+ # Enable automatic mixed precision for training
727
+ # Note that this does not change model's inference behavior.
728
+ # To use AMP in inference, run inference under autocast()
729
+ _C.SOLVER.AMP = CN({"ENABLED": False})
730
+
731
+ # ---------------------------------------------------------------------------- #
732
+ # Specific test options
733
+ # ---------------------------------------------------------------------------- #
734
+ _C.TEST = CN()
735
+ # For end-to-end tests to verify the expected accuracy.
736
+ # Each item is [task, metric, value, tolerance]
737
+ # e.g.: [['bbox', 'AP', 38.5, 0.2]]
738
+ _C.TEST.EXPECTED_RESULTS = []
739
+ # The period (in terms of steps) to evaluate the model during training.
740
+ # Set to 0 to disable.
741
+ _C.TEST.EVAL_PERIOD = 0
742
+ # The sigmas used to calculate keypoint OKS. See http://cocodataset.org/#keypoints-eval
743
+ # When empty, it will use the defaults in COCO.
744
+ # Otherwise it should be a list[float] with the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
745
+ _C.TEST.KEYPOINT_OKS_SIGMAS = []
746
+ # Maximum number of detections to return per image during inference (100 is
747
+ # based on the limit established for the COCO dataset).
748
+ _C.TEST.DETECTIONS_PER_IMAGE = 100
749
+
750
+ _C.TEST.AUG = CN({"ENABLED": False})
751
+ _C.TEST.AUG.MIN_SIZES = (400, 500, 600, 700, 800, 900, 1000, 1100, 1200)
752
+ _C.TEST.AUG.MAX_SIZE = 4000
753
+ _C.TEST.AUG.FLIP = True
754
+
755
+ _C.TEST.PRECISE_BN = CN({"ENABLED": False})
756
+ _C.TEST.PRECISE_BN.NUM_ITER = 200
757
+
758
+ # ---------------------------------------------------------------------------- #
759
+ # Misc options
760
+ # ---------------------------------------------------------------------------- #
761
+ # Directory where output files are written
762
+ _C.OUTPUT_DIR = "./output"
763
+ # Set seed to negative to fully randomize everything.
764
+ # Set seed to positive to use a fixed seed. Note that a fixed seed increases
765
+ # reproducibility but does not guarantee fully deterministic behavior.
766
+ # Disabling all parallelism further increases reproducibility.
767
+ _C.SEED = -1
768
+ # Benchmark different cudnn algorithms.
769
+ # If input images have very different sizes, this option will have large overhead
770
+ # for about 10k iterations. It usually hurts total time, but can benefit for certain models.
771
+ # If input images have the same or similar sizes, benchmark is often helpful.
772
+ _C.CUDNN_BENCHMARK = False
773
+ # The period (in terms of steps) for minibatch visualization at train time.
774
+ # Set to 0 to disable.
775
+ _C.VIS_PERIOD = 0
776
+
777
+ # global config is for quick hack purposes.
778
+ # You can set them in command line or config files,
779
+ # and access it with:
780
+ #
781
+ # from detectron2.config import global_cfg
782
+ # print(global_cfg.HACK)
783
+ #
784
+ # Do not commit any configs into it.
785
+ _C.GLOBAL = CN()
786
+ _C.GLOBAL.HACK = 1.0
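A short sketch of how these defaults are typically consumed (not part of the commit; the override values are illustrative, and configs/CLIP_fast_rcnn_R_50_C4.yaml is one of the config files added elsewhere in this commit):

    from detectron2.config import get_cfg

    cfg = get_cfg()  # starts from the _C tree defined above
    cfg.merge_from_file("configs/CLIP_fast_rcnn_R_50_C4.yaml")
    cfg.merge_from_list(["MODEL.CLIP.CROP_REGION_TYPE", "RPN",
                         "SOLVER.IMS_PER_BATCH", "8"])
    cfg.freeze()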
detectron2/config/instantiate.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import dataclasses
3
+ import logging
4
+ from collections import abc
5
+ from typing import Any
6
+
7
+ from detectron2.utils.registry import _convert_target_to_string, locate
8
+
9
+ __all__ = ["dump_dataclass", "instantiate"]
10
+
11
+
12
+ def dump_dataclass(obj: Any):
13
+ """
14
+ Dump a dataclass recursively into a dict that can be later instantiated.
15
+
16
+ Args:
17
+ obj: a dataclass object
18
+
19
+ Returns:
20
+ dict
21
+ """
22
+ assert dataclasses.is_dataclass(obj) and not isinstance(
23
+ obj, type
24
+ ), "dump_dataclass() requires an instance of a dataclass."
25
+ ret = {"_target_": _convert_target_to_string(type(obj))}
26
+ for f in dataclasses.fields(obj):
27
+ v = getattr(obj, f.name)
28
+ if dataclasses.is_dataclass(v):
29
+ v = dump_dataclass(v)
30
+ if isinstance(v, (list, tuple)):
31
+ v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v]
32
+ ret[f.name] = v
33
+ return ret
34
+
35
+
36
+ def instantiate(cfg):
37
+ """
38
+ Recursively instantiate objects defined in dictionaries by
39
+ "_target_" and arguments.
40
+
41
+ Args:
42
+ cfg: a dict-like object with "_target_" that defines the caller, and
43
+ other keys that define the arguments
44
+
45
+ Returns:
46
+ object instantiated by cfg
47
+ """
48
+ from omegaconf import ListConfig
49
+
50
+ if isinstance(cfg, ListConfig):
51
+ lst = [instantiate(x) for x in cfg]
52
+ return ListConfig(lst, flags={"allow_objects": True})
53
+ if isinstance(cfg, list):
54
+ # Specialize for list, because many classes take
55
+ # list[objects] as arguments, such as ResNet, DatasetMapper
56
+ return [instantiate(x) for x in cfg]
57
+
58
+ if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
59
+ # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
60
+ # but faster: https://github.com/facebookresearch/hydra/issues/1200
61
+ cfg = {k: instantiate(v) for k, v in cfg.items()}
62
+ cls = cfg.pop("_target_")
63
+ cls = instantiate(cls)
64
+
65
+ if isinstance(cls, str):
66
+ cls_name = cls
67
+ cls = locate(cls_name)
68
+ assert cls is not None, cls_name
69
+ else:
70
+ try:
71
+ cls_name = cls.__module__ + "." + cls.__qualname__
72
+ except Exception:
73
+ # target could be anything, so the above could fail
74
+ cls_name = str(cls)
75
+ assert callable(cls), f"_target_ {cls} does not define a callable object"
76
+ try:
77
+ return cls(**cfg)
78
+ except TypeError:
79
+ logger = logging.getLogger(__name__)
80
+ logger.error(f"Error when instantiating {cls_name}!")
81
+ raise
82
+ return cfg # return as-is if we don't know what to do
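A minimal sketch of the `_target_`-based instantiation implemented above (not part of the commit; assumes torch is installed and the layer parameters are arbitrary):

    import torch.nn as nn
    from detectron2.config.instantiate import instantiate

    # a plain dict (or DictConfig) whose "_target_" names a callable is built into an object
    conv_cfg = {
        "_target_": "torch.nn.Conv2d",
        "in_channels": 3,
        "out_channels": 16,
        "kernel_size": 3,
    }
    conv = instantiate(conv_cfg)
    assert isinstance(conv, nn.Conv2d)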
detectron2/config/lazy.py ADDED
@@ -0,0 +1,370 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import ast
3
+ import builtins
4
+ import importlib
5
+ import inspect
6
+ import logging
7
+ import os
8
+ import uuid
9
+ from collections import abc
10
+ from contextlib import contextmanager
11
+ from copy import deepcopy
12
+ from typing import List, Tuple, Union
13
+ import cloudpickle
14
+ import yaml
15
+ from omegaconf import DictConfig, ListConfig, OmegaConf
16
+
17
+ from detectron2.utils.file_io import PathManager
18
+ from detectron2.utils.registry import _convert_target_to_string
19
+
20
+ __all__ = ["LazyCall", "LazyConfig"]
21
+
22
+
23
+ class LazyCall:
24
+ """
25
+ Wrap a callable so that when it's called, the call will not be executed,
26
+ but returns a dict that describes the call.
27
+
28
+ LazyCall object has to be called with only keyword arguments. Positional
29
+ arguments are not yet supported.
30
+
31
+ Examples:
32
+ ::
33
+ from detectron2.config import instantiate, LazyCall
34
+
35
+ layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
36
+ layer_cfg.out_channels = 64 # can edit it afterwards
37
+ layer = instantiate(layer_cfg)
38
+ """
39
+
40
+ def __init__(self, target):
41
+ if not (callable(target) or isinstance(target, (str, abc.Mapping))):
42
+ raise TypeError(
43
+ "target of LazyCall must be a callable or defines a callable! Got {target}"
44
+ )
45
+ self._target = target
46
+
47
+ def __call__(self, **kwargs):
48
+ kwargs["_target_"] = self._target
49
+ return DictConfig(content=kwargs, flags={"allow_objects": True})
50
+
51
+
52
+ def _visit_dict_config(cfg, func):
53
+ """
54
+ Apply func recursively to all DictConfig in cfg.
55
+ """
56
+ if isinstance(cfg, DictConfig):
57
+ func(cfg)
58
+ for v in cfg.values():
59
+ _visit_dict_config(v, func)
60
+ elif isinstance(cfg, ListConfig):
61
+ for v in cfg:
62
+ _visit_dict_config(v, func)
63
+
64
+
65
+ def _validate_py_syntax(filename):
66
+ # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
67
+ with PathManager.open(filename, "r") as f:
68
+ content = f.read()
69
+ try:
70
+ ast.parse(content)
71
+ except SyntaxError as e:
72
+ raise SyntaxError(f"Config file {filename} has syntax error!") from e
73
+
74
+
75
+ def _cast_to_config(obj):
76
+ # if given a dict, return DictConfig instead
77
+ if isinstance(obj, dict):
78
+ return DictConfig(obj, flags={"allow_objects": True})
79
+ return obj
80
+
81
+
82
+ _CFG_PACKAGE_NAME = "detectron2._cfg_loader"
83
+ """
84
+ A namespace to put all imported config into.
85
+ """
86
+
87
+
88
+ def _random_package_name(filename):
89
+ # generate a random package name when loading config files
90
+ return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename)
91
+
92
+
93
+ @contextmanager
94
+ def _patch_import():
95
+ """
96
+ Enhance relative import statements in config files, so that they:
97
+ 1. locate files purely based on relative location, regardless of packages.
98
+ e.g. you can import file without having __init__
99
+ 2. do not cache modules globally; modifications of module state have no side effect
100
+ 3. support other storage systems through PathManager
101
+ 4. imported dicts are turned into omegaconf.DictConfig automatically
102
+ """
103
+ old_import = builtins.__import__
104
+
105
+ def find_relative_file(original_file, relative_import_path, level):
106
+ cur_file = os.path.dirname(original_file)
107
+ for _ in range(level - 1):
108
+ cur_file = os.path.dirname(cur_file)
109
+ cur_name = relative_import_path.lstrip(".")
110
+ for part in cur_name.split("."):
111
+ cur_file = os.path.join(cur_file, part)
112
+ # NOTE: directory import is not handled. Because then it's unclear
113
+ # if such import should produce python module or DictConfig. This can
114
+ # be discussed further if needed.
115
+ if not cur_file.endswith(".py"):
116
+ cur_file += ".py"
117
+ if not PathManager.isfile(cur_file):
118
+ raise ImportError(
119
+ f"Cannot import name {relative_import_path} from "
120
+ f"{original_file}: {cur_file} has to exist."
121
+ )
122
+ return cur_file
123
+
124
+ def new_import(name, globals=None, locals=None, fromlist=(), level=0):
125
+ if (
126
+ # Only deal with relative imports inside config files
127
+ level != 0
128
+ and globals is not None
129
+ and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
130
+ ):
131
+ cur_file = find_relative_file(globals["__file__"], name, level)
132
+ _validate_py_syntax(cur_file)
133
+ spec = importlib.machinery.ModuleSpec(
134
+ _random_package_name(cur_file), None, origin=cur_file
135
+ )
136
+ module = importlib.util.module_from_spec(spec)
137
+ module.__file__ = cur_file
138
+ with PathManager.open(cur_file) as f:
139
+ content = f.read()
140
+ exec(compile(content, cur_file, "exec"), module.__dict__)
141
+ for name in fromlist: # turn imported dict into DictConfig automatically
142
+ val = _cast_to_config(module.__dict__[name])
143
+ module.__dict__[name] = val
144
+ return module
145
+ return old_import(name, globals, locals, fromlist=fromlist, level=level)
146
+
147
+ builtins.__import__ = new_import
148
+ yield new_import
149
+ builtins.__import__ = old_import
150
+
151
+
152
+ class LazyConfig:
153
+ """
154
+ Provide methods to save, load, and override an omegaconf config object
155
+ which may contain definition of lazily-constructed objects.
156
+ """
157
+
158
+ @staticmethod
159
+ def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
160
+ """
161
+ Similar to :meth:`load()`, but load path relative to the caller's
162
+ source file.
163
+
164
+ This has the same functionality as a relative import, except that this method
165
+ accepts filename as a string, so more characters are allowed in the filename.
166
+ """
167
+ caller_frame = inspect.stack()[1]
168
+ caller_fname = caller_frame[0].f_code.co_filename
169
+ assert caller_fname != "<string>", "load_rel Unable to find caller"
170
+ caller_dir = os.path.dirname(caller_fname)
171
+ filename = os.path.join(caller_dir, filename)
172
+ return LazyConfig.load(filename, keys)
173
+
174
+ @staticmethod
175
+ def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
176
+ """
177
+ Load a config file.
178
+
179
+ Args:
180
+ filename: absolute path or relative path w.r.t. the current working directory
181
+ keys: keys to load and return. If not given, return all keys
182
+ (whose values are config objects) in a dict.
183
+ """
184
+ has_keys = keys is not None
185
+ filename = filename.replace("/./", "/") # redundant
186
+ if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
187
+ raise ValueError(f"Config file {filename} has to be a python or yaml file.")
188
+ if filename.endswith(".py"):
189
+ _validate_py_syntax(filename)
190
+
191
+ with _patch_import():
192
+ # Record the filename
193
+ module_namespace = {
194
+ "__file__": filename,
195
+ "__package__": _random_package_name(filename),
196
+ }
197
+ with PathManager.open(filename) as f:
198
+ content = f.read()
199
+ # Compile first with filename to:
200
+ # 1. make filename appears in stacktrace
201
+ # 2. make load_rel able to find its parent's (possibly remote) location
202
+ exec(compile(content, filename, "exec"), module_namespace)
203
+
204
+ ret = module_namespace
205
+ else:
206
+ with PathManager.open(filename) as f:
207
+ obj = yaml.unsafe_load(f)
208
+ ret = OmegaConf.create(obj, flags={"allow_objects": True})
209
+
210
+ if has_keys:
211
+ if isinstance(keys, str):
212
+ return _cast_to_config(ret[keys])
213
+ else:
214
+ return tuple(_cast_to_config(ret[a]) for a in keys)
215
+ else:
216
+ if filename.endswith(".py"):
217
+ # when not specified, only load those that are config objects
218
+ ret = DictConfig(
219
+ {
220
+ name: _cast_to_config(value)
221
+ for name, value in ret.items()
222
+ if isinstance(value, (DictConfig, ListConfig, dict))
223
+ and not name.startswith("_")
224
+ },
225
+ flags={"allow_objects": True},
226
+ )
227
+ return ret
228
+
229
+ @staticmethod
230
+ def save(cfg, filename: str):
231
+ """
232
+ Args:
233
+ cfg: an omegaconf config object
234
+ filename: yaml file name to save the config file
235
+ """
236
+ logger = logging.getLogger(__name__)
237
+ try:
238
+ cfg = deepcopy(cfg)
239
+ except Exception:
240
+ pass
241
+ else:
242
+ # if it's deep-copyable, then...
243
+ def _replace_type_by_name(x):
244
+ if "_target_" in x and callable(x._target_):
245
+ try:
246
+ x._target_ = _convert_target_to_string(x._target_)
247
+ except AttributeError:
248
+ pass
249
+
250
+ # not necessary, but makes the yaml look nicer
251
+ _visit_dict_config(cfg, _replace_type_by_name)
252
+
253
+ try:
254
+ with PathManager.open(filename, "w") as f:
255
+ dict = OmegaConf.to_container(cfg, resolve=False)
256
+ dumped = yaml.dump(dict, default_flow_style=None, allow_unicode=True, width=9999)
257
+ f.write(dumped)
258
+ except Exception:
259
+ logger.exception("Unable to serialize the config to yaml. Error:")
260
+ new_filename = filename + ".pkl"
261
+ try:
262
+ # retry by pickle
263
+ with PathManager.open(new_filename, "wb") as f:
264
+ cloudpickle.dump(cfg, f)
265
+ logger.warning(f"Config saved using cloudpickle at {new_filename} ...")
266
+ except Exception:
267
+ pass
268
+
269
+ @staticmethod
270
+ def apply_overrides(cfg, overrides: List[str]):
271
+ """
272
+ In-place override contents of cfg.
273
+
274
+ Args:
275
+ cfg: an omegaconf config object
276
+ overrides: list of strings in the format of "a=b" to override configs.
277
+ See https://hydra.cc/docs/next/advanced/override_grammar/basic/
278
+ for syntax.
279
+
280
+ Returns:
281
+ the cfg object
282
+ """
283
+
284
+ def safe_update(cfg, key, value):
285
+ parts = key.split(".")
286
+ for idx in range(1, len(parts)):
287
+ prefix = ".".join(parts[:idx])
288
+ v = OmegaConf.select(cfg, prefix, default=None)
289
+ if v is None:
290
+ break
291
+ if not OmegaConf.is_config(v):
292
+ raise KeyError(
293
+ f"Trying to update key {key}, but {prefix} "
294
+ f"is not a config, but has type {type(v)}."
295
+ )
296
+ OmegaConf.update(cfg, key, value, merge=True)
297
+
298
+ from hydra.core.override_parser.overrides_parser import OverridesParser
299
+
300
+ parser = OverridesParser.create()
301
+ overrides = parser.parse_overrides(overrides)
302
+ for o in overrides:
303
+ key = o.key_or_group
304
+ value = o.value()
305
+ if o.is_delete():
306
+ # TODO support this
307
+ raise NotImplementedError("deletion is not yet a supported override")
308
+ safe_update(cfg, key, value)
309
+ return cfg
310
+
311
+ @staticmethod
312
+ def to_py(cfg, prefix: str = "cfg."):
313
+ """
314
+ Convert a config object into its equivalent Python code.
315
+
316
+ Args:
317
+ cfg: an omegaconf config object
318
+ prefix: root name for the resulting code (default: "cfg.")
319
+
320
+
321
+ Returns:
322
+ str of formatted Python code
323
+ """
324
+ import black
325
+
326
+ cfg = OmegaConf.to_container(cfg, resolve=True)
327
+
328
+ def _to_str(obj, prefix=None, inside_call=False):
329
+ if prefix is None:
330
+ prefix = []
331
+ if isinstance(obj, abc.Mapping) and "_target_" in obj:
332
+ # Dict representing a function call
333
+ target = _convert_target_to_string(obj.pop("_target_"))
334
+ args = []
335
+ for k, v in sorted(obj.items()):
336
+ args.append(f"{k}={_to_str(v, inside_call=True)}")
337
+ args = ", ".join(args)
338
+ call = f"{target}({args})"
339
+ return "".join(prefix) + call
340
+ elif isinstance(obj, abc.Mapping) and not inside_call:
341
+ # Dict that is not inside a call is a list of top-level config objects that we
342
+ # render as one object per line with dot separated prefixes
343
+ key_list = []
344
+ for k, v in sorted(obj.items()):
345
+ if isinstance(v, abc.Mapping) and "_target_" not in v:
346
+ key_list.append(_to_str(v, prefix=prefix + [k + "."]))
347
+ else:
348
+ key = "".join(prefix) + k
349
+ key_list.append(f"{key}={_to_str(v)}")
350
+ return "\n".join(key_list)
351
+ elif isinstance(obj, abc.Mapping):
352
+ # Dict that is inside a call is rendered as a regular dict
353
+ return (
354
+ "{"
355
+ + ",".join(
356
+ f"{repr(k)}: {_to_str(v, inside_call=inside_call)}"
357
+ for k, v in sorted(obj.items())
358
+ )
359
+ + "}"
360
+ )
361
+ elif isinstance(obj, list):
362
+ return "[" + ",".join(_to_str(x, inside_call=inside_call) for x in obj) + "]"
363
+ else:
364
+ return repr(obj)
365
+
366
+ py_str = _to_str(cfg, prefix=[prefix])
367
+ try:
368
+ return black.format_str(py_str, mode=black.Mode())
369
+ except black.InvalidInput:
370
+ return py_str
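As a quick reference, here is a minimal sketch of driving the LazyConfig API defined above. The config path, the override keys, and the `instantiate` import are placeholders that follow upstream Detectron2 conventions rather than anything guaranteed by this commit:

    from detectron2.config import LazyConfig, instantiate  # `instantiate` assumed exported as in upstream Detectron2

    # load a .py config file; returns a DictConfig of its top-level config objects
    cfg = LazyConfig.load("configs/my_config.py")  # placeholder path

    # apply hydra-style dotted overrides in place
    cfg = LazyConfig.apply_overrides(cfg, ["model.depth=101", "train.max_iter=90000"])  # placeholder keys

    # serialize to yaml; falls back to cloudpickle if the config cannot be dumped
    LazyConfig.save(cfg, "output/config.yaml")

    # build the lazily-declared object (assumes the config defines a top-level `model`)
    model = instantiate(cfg.model)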
detectron2/data/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from . import transforms # isort:skip
3
+
4
+ from .build import (
5
+ build_batch_data_loader,
6
+ build_detection_test_loader,
7
+ build_detection_train_loader,
8
+ get_detection_dataset_dicts,
9
+ load_proposals_into_dataset,
10
+ print_instances_class_histogram,
11
+ )
12
+ from .catalog import DatasetCatalog, MetadataCatalog, Metadata
13
+ from .common import DatasetFromList, MapDataset
14
+ from .dataset_mapper import DatasetMapper
15
+
16
+ # ensure the builtin datasets are registered
17
+ from . import datasets, samplers # isort:skip
18
+
19
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
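Because the module ends by importing `datasets` and `samplers` for their registration side effects, importing the package is enough to populate the builtin catalogs. A small sketch, assuming the usual COCO builtins ship with this repo:

    import detectron2.data  # side effect: builtin datasets get registered
    from detectron2.data import DatasetCatalog, MetadataCatalog

    print(DatasetCatalog.list()[:5])  # a few registered builtin dataset names
    print(MetadataCatalog.get("coco_2017_train").thing_classes[:3])  # assumes the COCO builtin is present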
detectron2/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (859 Bytes).
 
detectron2/data/__pycache__/build.cpython-39.pyc ADDED
Binary file (16.9 kB).
 
detectron2/data/__pycache__/catalog.cpython-39.pyc ADDED
Binary file (7.6 kB).
 
detectron2/data/__pycache__/clip_build.cpython-39.pyc ADDED
Binary file (4.32 kB).
 
detectron2/data/__pycache__/common.cpython-39.pyc ADDED
Binary file (6.84 kB).
 
detectron2/data/__pycache__/dataset_mapper.cpython-39.pyc ADDED
Binary file (5.89 kB).
 
detectron2/data/__pycache__/detection_utils.cpython-39.pyc ADDED
Binary file (18.2 kB).
 
detectron2/data/build.py ADDED
@@ -0,0 +1,536 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import itertools
3
+ import logging
4
+ import numpy as np
5
+ import operator
6
+ import pickle
7
+ import torch.utils.data
8
+ from tabulate import tabulate
9
+ from termcolor import colored
10
+
11
+ from detectron2.config import configurable
12
+ from detectron2.structures import BoxMode
13
+ from detectron2.utils.comm import get_world_size
14
+ from detectron2.utils.env import seed_all_rng
15
+ from detectron2.utils.file_io import PathManager
16
+ from detectron2.utils.logger import _log_api_usage, log_first_n
17
+
18
+ from .catalog import DatasetCatalog, MetadataCatalog
19
+ from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset
20
+ from .dataset_mapper import DatasetMapper
21
+ from .detection_utils import check_metadata_consistency
22
+ from .samplers import InferenceSampler, RepeatFactorTrainingSampler, TrainingSampler
23
+
24
+ from .clip_build import make_clip_dataset
25
+
26
+ """
27
+ This file contains the default logic to build a dataloader for training or testing.
28
+ """
29
+
30
+ __all__ = [
31
+ "build_batch_data_loader",
32
+ "build_detection_train_loader",
33
+ "build_detection_test_loader",
34
+ "get_detection_dataset_dicts",
35
+ "load_proposals_into_dataset",
36
+ "print_instances_class_histogram",
37
+ ]
38
+
39
+
40
+ def filter_images_with_only_crowd_annotations(dataset_dicts):
41
+ """
42
+ Filter out images with no annotations or only crowd annotations
43
+ (i.e., images without non-crowd annotations).
44
+ A common training-time preprocessing on COCO dataset.
45
+
46
+ Args:
47
+ dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
48
+
49
+ Returns:
50
+ list[dict]: the same format, but filtered.
51
+ """
52
+ num_before = len(dataset_dicts)
53
+
54
+ def valid(anns):
55
+ for ann in anns:
56
+ if ann.get("iscrowd", 0) == 0:
57
+ return True
58
+ return False
59
+
60
+ dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
61
+ num_after = len(dataset_dicts)
62
+ logger = logging.getLogger(__name__)
63
+ logger.info(
64
+ "Removed {} images with no usable annotations. {} images left.".format(
65
+ num_before - num_after, num_after
66
+ )
67
+ )
68
+ return dataset_dicts
69
+
70
+
71
+ def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
72
+ """
73
+ Filter out images with too few keypoints.
74
+
75
+ Args:
76
+ dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
77
+
78
+ Returns:
79
+ list[dict]: the same format as dataset_dicts, but filtered.
80
+ """
81
+ num_before = len(dataset_dicts)
82
+
83
+ def visible_keypoints_in_image(dic):
84
+ # Each keypoints field has the format [x1, y1, v1, ...], where v is visibility
85
+ annotations = dic["annotations"]
86
+ return sum(
87
+ (np.array(ann["keypoints"][2::3]) > 0).sum()
88
+ for ann in annotations
89
+ if "keypoints" in ann
90
+ )
91
+
92
+ dataset_dicts = [
93
+ x for x in dataset_dicts if visible_keypoints_in_image(x) >= min_keypoints_per_image
94
+ ]
95
+ num_after = len(dataset_dicts)
96
+ logger = logging.getLogger(__name__)
97
+ logger.info(
98
+ "Removed {} images with fewer than {} keypoints.".format(
99
+ num_before - num_after, min_keypoints_per_image
100
+ )
101
+ )
102
+ return dataset_dicts
103
+
104
+
105
+ def load_proposals_into_dataset(dataset_dicts, proposal_file):
106
+ """
107
+ Load precomputed object proposals into the dataset.
108
+
109
+ The proposal file should be a pickled dict with the following keys:
110
+
111
+ - "ids": list[int] or list[str], the image ids
112
+ - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
113
+ - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
114
+ corresponding to the boxes.
115
+ - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.
116
+
117
+ Args:
118
+ dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
119
+ proposal_file (str): file path of pre-computed proposals, in pkl format.
120
+
121
+ Returns:
122
+ list[dict]: the same format as dataset_dicts, but added proposal field.
123
+ """
124
+ logger = logging.getLogger(__name__)
125
+ logger.info("Loading proposals from: {}".format(proposal_file))
126
+
127
+ with PathManager.open(proposal_file, "rb") as f:
128
+ proposals = pickle.load(f, encoding="latin1")
129
+
130
+ # Rename the key names in D1 proposal files
131
+ rename_keys = {"indexes": "ids", "scores": "objectness_logits"}
132
+ for key in rename_keys:
133
+ if key in proposals:
134
+ proposals[rename_keys[key]] = proposals.pop(key)
135
+
136
+ # Fetch the indexes of all proposals that are in the dataset
137
+ # Convert image_id to str since they could be int.
138
+ img_ids = set({str(record["image_id"]) for record in dataset_dicts})
139
+ id_to_index = {str(id): i for i, id in enumerate(proposals["ids"]) if str(id) in img_ids}
140
+
141
+ # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS'
142
+ bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS
143
+
144
+ for record in dataset_dicts:
145
+ # Get the index of the proposal
146
+ i = id_to_index[str(record["image_id"])]
147
+
148
+ boxes = proposals["boxes"][i]
149
+ objectness_logits = proposals["objectness_logits"][i]
150
+ # Sort the proposals in descending order of the scores
151
+ inds = objectness_logits.argsort()[::-1]
152
+ record["proposal_boxes"] = boxes[inds]
153
+ record["proposal_objectness_logits"] = objectness_logits[inds]
154
+ record["proposal_bbox_mode"] = bbox_mode
155
+
156
+ return dataset_dicts
157
+
158
+
159
+ def print_instances_class_histogram(dataset_dicts, class_names):
160
+ """
161
+ Args:
162
+ dataset_dicts (list[dict]): list of dataset dicts.
163
+ class_names (list[str]): list of class names (zero-indexed).
164
+ """
165
+ num_classes = len(class_names)
166
+ hist_bins = np.arange(num_classes + 1)
167
+ histogram = np.zeros((num_classes,), dtype=int)  # use builtin int; np.int is deprecated/removed in recent NumPy
168
+ for entry in dataset_dicts:
169
+ annos = entry["annotations"]
170
+ classes = np.asarray(
171
+ [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=int
172
+ )
173
+ if len(classes):
174
+ assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
175
+ assert (
176
+ classes.max() < num_classes
177
+ ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
178
+ histogram += np.histogram(classes, bins=hist_bins)[0]
179
+
180
+ N_COLS = min(6, len(class_names) * 2)
181
+
182
+ def short_name(x):
183
+ # make long class names shorter. useful for lvis
184
+ if len(x) > 13:
185
+ return x[:11] + ".."
186
+ return x
187
+
188
+ data = list(
189
+ itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
190
+ )
191
+ total_num_instances = sum(data[1::2])
192
+ data.extend([None] * (N_COLS - (len(data) % N_COLS)))
193
+ if num_classes > 1:
194
+ data.extend(["total", total_num_instances])
195
+ data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
196
+ table = tabulate(
197
+ data,
198
+ headers=["category", "#instances"] * (N_COLS // 2),
199
+ tablefmt="pipe",
200
+ numalign="left",
201
+ stralign="center",
202
+ )
203
+ log_first_n(
204
+ logging.INFO,
205
+ "Distribution of instances among all {} categories:\n".format(num_classes)
206
+ + colored(table, "cyan"),
207
+ key="message",
208
+ )
209
+
210
+
211
+ def get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, proposal_files=None):
212
+ """
213
+ Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
214
+
215
+ Args:
216
+ names (str or list[str]): a dataset name or a list of dataset names
217
+ filter_empty (bool): whether to filter out images without instance annotations
218
+ min_keypoints (int): filter out images with fewer keypoints than
219
+ `min_keypoints`. Set to 0 to do nothing.
220
+ proposal_files (list[str]): if given, a list of object proposal files
221
+ that match each dataset in `names`.
222
+
223
+ Returns:
224
+ list[dict]: a list of dicts following the standard dataset dict format.
225
+ """
226
+ if isinstance(names, str):
227
+ names = [names]
228
+ assert len(names), names
229
+ dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]
230
+ for dataset_name, dicts in zip(names, dataset_dicts):
231
+ assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
232
+
233
+ if proposal_files is not None:
234
+ assert len(names) == len(proposal_files)
235
+ # load precomputed proposals from proposal files
236
+ dataset_dicts = [
237
+ load_proposals_into_dataset(dataset_i_dicts, proposal_file)
238
+ for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
239
+ ]
240
+
241
+ dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
242
+
243
+ has_instances = "annotations" in dataset_dicts[0]
244
+ if filter_empty and has_instances:
245
+ dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
246
+ if min_keypoints > 0 and has_instances:
247
+ dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
248
+
249
+ if has_instances:
250
+ try:
251
+ class_names = MetadataCatalog.get(names[0]).thing_classes
252
+ check_metadata_consistency("thing_classes", names)
253
+ print_instances_class_histogram(dataset_dicts, class_names)
254
+ except AttributeError: # class names are not available for this dataset
255
+ pass
256
+
257
+ assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
258
+ return dataset_dicts
259
+
260
+
261
+ def build_batch_data_loader(
262
+ dataset, sampler, total_batch_size, *, aspect_ratio_grouping=False, num_workers=0
263
+ ):
264
+ """
265
+ Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
266
+ 1. it supports aspect ratio grouping options
267
+ 2. it uses no "batch collation", because this is common for detection training
268
+
269
+ Args:
270
+ dataset (torch.utils.data.Dataset): map-style PyTorch dataset. Can be indexed.
271
+ sampler (torch.utils.data.sampler.Sampler): a sampler that produces indices
272
+ total_batch_size, aspect_ratio_grouping, num_workers: see
273
+ :func:`build_detection_train_loader`.
274
+
275
+ Returns:
276
+ iterable[list]. Length of each list is the batch size of the current
277
+ GPU. Each element in the list comes from the dataset.
278
+ """
279
+ world_size = get_world_size()
280
+ assert (
281
+ total_batch_size > 0 and total_batch_size % world_size == 0
282
+ ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
283
+ total_batch_size, world_size
284
+ )
285
+
286
+ batch_size = total_batch_size // world_size
287
+ if aspect_ratio_grouping:
288
+ data_loader = torch.utils.data.DataLoader(
289
+ dataset,
290
+ sampler=sampler,
291
+ num_workers=num_workers,
292
+ batch_sampler=None,
293
+ collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements
294
+ worker_init_fn=worker_init_reset_seed,
295
+ ) # yield individual mapped dict
296
+ return AspectRatioGroupedDataset(data_loader, batch_size)
297
+ else:
298
+ batch_sampler = torch.utils.data.sampler.BatchSampler(
299
+ sampler, batch_size, drop_last=True
300
+ ) # drop_last so the batch always has the same size
301
+ return torch.utils.data.DataLoader(
302
+ dataset,
303
+ num_workers=num_workers,
304
+ batch_sampler=batch_sampler,
305
+ collate_fn=trivial_batch_collator,
306
+ worker_init_fn=worker_init_reset_seed,
307
+ )
308
+
309
+
310
+ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
311
+ if 'yfcc100m' in cfg.DATASETS.TRAIN: # dataset, transform/aug., sampler for image-text pairs training
312
+ logger = logging.getLogger(__name__)
313
+ logger.info("Creating dataset {}".format(cfg.DATASETS.TRAIN))
314
+ datasets, precomputed_tokens, dataset_classes = make_clip_dataset(
315
+ cfg, is_train=True,
316
+ transforms=None, # for training, we use our own defined transforms
317
+ )
318
+ dataset = datasets[0] # during training, a single (possibly concatenated) dataset is returned
319
+ if sampler is None:
320
+ sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
321
+ logger = logging.getLogger(__name__)
322
+ logger.info("Using training sampler {}".format(sampler_name))
323
+ if sampler_name == "TrainingSampler":
324
+ sampler = TrainingSampler(len(dataset))
325
+ elif sampler_name == "RepeatFactorTrainingSampler":
326
+ repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
327
+ dataset, cfg.DATALOADER.REPEAT_THRESHOLD
328
+ )
329
+ sampler = RepeatFactorTrainingSampler(repeat_factors)
330
+ else:
331
+ raise ValueError("Unknown training sampler: {}".format(sampler_name))
332
+ return {
333
+ "dataset": dataset,
334
+ "sampler": sampler,
335
+ "mapper": None,
336
+ "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
337
+ "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
338
+ "num_workers": cfg.DATALOADER.NUM_WORKERS,
339
+ }
340
+ # the following is the default code in Detectron2
341
+ if dataset is None:
342
+ dataset = get_detection_dataset_dicts(
343
+ cfg.DATASETS.TRAIN,
344
+ filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
345
+ min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
346
+ if cfg.MODEL.KEYPOINT_ON
347
+ else 0,
348
+ proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
349
+ )
350
+ _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])
351
+
352
+ if mapper is None:
353
+ mapper = DatasetMapper(cfg, True)
354
+
355
+ if sampler is None:
356
+ sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
357
+ logger = logging.getLogger(__name__)
358
+ logger.info("Using training sampler {}".format(sampler_name))
359
+ if sampler_name == "TrainingSampler":
360
+ sampler = TrainingSampler(len(dataset))
361
+ elif sampler_name == "RepeatFactorTrainingSampler":
362
+ repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
363
+ dataset, cfg.DATALOADER.REPEAT_THRESHOLD
364
+ )
365
+ sampler = RepeatFactorTrainingSampler(repeat_factors)
366
+ else:
367
+ raise ValueError("Unknown training sampler: {}".format(sampler_name))
368
+
369
+ return {
370
+ "dataset": dataset,
371
+ "sampler": sampler,
372
+ "mapper": mapper,
373
+ "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
374
+ "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
375
+ "num_workers": cfg.DATALOADER.NUM_WORKERS,
376
+ }
377
+
378
+
379
+ # TODO can allow dataset as an iterable or IterableDataset to make this function more general
380
+ @configurable(from_config=_train_loader_from_config)
381
+ def build_detection_train_loader(
382
+ dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
383
+ ):
384
+ """
385
+ Build a dataloader for object detection with some default features.
386
+ This interface is experimental.
387
+
388
+ Args:
389
+ dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
390
+ or a map-style pytorch dataset. They can be obtained by using
391
+ :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
392
+ mapper (callable): a callable which takes a sample (dict) from dataset and
393
+ returns the format to be consumed by the model.
394
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
395
+ sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
396
+ indices to be applied on ``dataset``. Default to :class:`TrainingSampler`,
397
+ which coordinates an infinite random shuffle sequence across all workers.
398
+ total_batch_size (int): total batch size across all workers. Batching
399
+ simply puts data into a list.
400
+ aspect_ratio_grouping (bool): whether to group images with similar
401
+ aspect ratio for efficiency. When enabled, it requires each
402
+ element in dataset be a dict with keys "width" and "height".
403
+ num_workers (int): number of parallel data loading workers
404
+
405
+ Returns:
406
+ torch.utils.data.DataLoader:
407
+ a dataloader. Each output from it is a ``list[mapped_element]`` of length
408
+ ``total_batch_size / num_gpus`` (the per-GPU batch size), where ``mapped_element`` is produced
409
+ by the ``mapper``.
410
+ """
411
+ if isinstance(dataset, list):
412
+ dataset = DatasetFromList(dataset, copy=False)
413
+ if mapper is not None:
414
+ dataset = MapDataset(dataset, mapper)
415
+ if sampler is None:
416
+ sampler = TrainingSampler(len(dataset))
417
+ assert isinstance(sampler, torch.utils.data.sampler.Sampler)
418
+ return build_batch_data_loader(
419
+ dataset,
420
+ sampler,
421
+ total_batch_size,
422
+ aspect_ratio_grouping=aspect_ratio_grouping,
423
+ num_workers=num_workers,
424
+ )
425
+
426
+
427
+ def _test_loader_from_config(cfg, dataset_name, mapper=None):
428
+ """
429
+ Uses the given `dataset_name` argument (instead of the names in cfg), because the
430
+ standard practice is to evaluate each test set individually (not combining them).
431
+ """
432
+ if 'yfcc100m' in cfg.DATASETS.TEST: # dataset, no {transform/aug., sampler for image-text pairs training}
433
+ logger = logging.getLogger(__name__)
434
+ logger.info("Creating dataset {}".format(cfg.DATASETS.TEST))
435
+ datasets, precomputed_tokens, dataset_classes = make_clip_dataset(
436
+ cfg, is_train=False,
437
+ transforms=None, # transforms are built inside make_clip_dataset when None is passed
438
+ )
439
+ dataset = datasets[0] # for testing, make_clip_dataset returns a list of datasets; use the first one
440
+ return {
441
+ "dataset": dataset,
442
+ "mapper": None,
443
+ "num_workers": cfg.DATALOADER.NUM_WORKERS,
444
+ }
445
+
446
+ # the following is the default code in Detectron2
447
+ dataset = get_detection_dataset_dicts(
448
+ [dataset_name],
449
+ filter_empty=False,
450
+ proposal_files=[
451
+ cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]
452
+ ]
453
+ if cfg.MODEL.LOAD_PROPOSALS
454
+ else None,
455
+ )
456
+ if mapper is None:
457
+ mapper = DatasetMapper(cfg, False)
458
+ if cfg.MODEL.META_ARCHITECTURE == 'CLIPRCNN': # speed up when using CLIP in inference
459
+ return {"dataset": dataset, "mapper": mapper, "num_workers": cfg.DATALOADER.NUM_WORKERS,\
460
+ "clip_batch_size": cfg.MODEL.CLIP.IMS_PER_BATCH_TEST}
461
+ return {"dataset": dataset, "mapper": mapper, "num_workers": cfg.DATALOADER.NUM_WORKERS}
462
+
463
+
464
+ @configurable(from_config=_test_loader_from_config)
465
+ def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0, clip_batch_size=None):
466
+ """
467
+ Similar to `build_detection_train_loader`, but uses a batch size of 1,
468
+ and :class:`InferenceSampler`. This sampler coordinates all workers to
469
+ produce the exact set of all samples.
470
+ This interface is experimental.
471
+
472
+ Args:
473
+ dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
474
+ or a map-style pytorch dataset. They can be obtained by using
475
+ :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
476
+ mapper (callable): a callable which takes a sample (dict) from dataset
477
+ and returns the format to be consumed by the model.
478
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
479
+ sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
480
+ indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
481
+ which splits the dataset across all workers.
482
+ num_workers (int): number of parallel data loading workers
483
+
484
+ Returns:
485
+ DataLoader: a torch DataLoader, that loads the given detection
486
+ dataset, with test-time transformation and batching.
487
+
488
+ Examples:
489
+ ::
490
+ data_loader = build_detection_test_loader(
491
+ DatasetRegistry.get("my_test"),
492
+ mapper=DatasetMapper(...))
493
+
494
+ # or, instantiate with a CfgNode:
495
+ data_loader = build_detection_test_loader(cfg, "my_test")
496
+ """
497
+ if isinstance(dataset, list):
498
+ dataset = DatasetFromList(dataset, copy=False)
499
+ if mapper is not None:
500
+ dataset = MapDataset(dataset, mapper)
501
+ if sampler is None:
502
+ sampler = InferenceSampler(len(dataset))
503
+
504
+ if clip_batch_size: # multiple images per gpu
505
+ world_size = get_world_size()
506
+ batch_size = clip_batch_size // world_size
507
+ batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size, drop_last=False)
508
+ data_loader = torch.utils.data.DataLoader(
509
+ dataset,
510
+ num_workers=num_workers,
511
+ batch_sampler=batch_sampler,
512
+ collate_fn=trivial_batch_collator,
513
+ )
514
+ return data_loader
515
+ # Always use 1 image per worker during inference since this is the
516
+ # standard when reporting inference time in papers.
517
+ batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False)
518
+ data_loader = torch.utils.data.DataLoader(
519
+ dataset,
520
+ num_workers=num_workers,
521
+ batch_sampler=batch_sampler,
522
+ collate_fn=trivial_batch_collator,
523
+ )
524
+ return data_loader
525
+
526
+
527
+ def trivial_batch_collator(batch):
528
+ """
529
+ A batch collator that does nothing.
530
+ """
531
+ return batch
532
+
533
+
534
+ def worker_init_reset_seed(worker_id):
535
+ initial_seed = torch.initial_seed() % 2 ** 31
536
+ seed_all_rng(initial_seed + worker_id)
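A minimal sketch of how these loader builders are typically invoked. The dataset names and batch size are placeholders, the datasets must already be registered in DatasetCatalog, and the cfg-based calls go through the `configurable` from-config hooks defined above:

    from detectron2.config import get_cfg
    from detectron2.data import build_detection_train_loader, build_detection_test_loader

    cfg = get_cfg()
    cfg.DATASETS.TRAIN = ("coco_2017_train",)  # placeholder registered dataset
    cfg.DATASETS.TEST = ("coco_2017_val",)
    cfg.SOLVER.IMS_PER_BATCH = 16              # total batch size across all GPUs

    train_loader = build_detection_train_loader(cfg)                 # uses _train_loader_from_config
    test_loader = build_detection_test_loader(cfg, "coco_2017_val")  # uses _test_loader_from_config

    batch = next(iter(train_loader))  # list[dict] with one mapped element per image on this GPU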
detectron2/data/catalog.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+ import types
5
+ from collections import UserDict
6
+ from typing import List
7
+
8
+ from detectron2.utils.logger import log_first_n
9
+
10
+ __all__ = ["DatasetCatalog", "MetadataCatalog", "Metadata"]
11
+
12
+
13
+ class _DatasetCatalog(UserDict):
14
+ """
15
+ A global dictionary that stores information about the datasets and how to obtain them.
16
+
17
+ It contains a mapping from strings
18
+ (which are names that identify a dataset, e.g. "coco_2014_train")
19
+ to a function which parses the dataset and returns the samples in the
20
+ format of `list[dict]`.
21
+
22
+ The returned dicts should be in Detectron2 Dataset format (See DATASETS.md for details)
23
+ if used with the data loader functionalities in `data/build.py,data/detection_transform.py`.
24
+
25
+ The purpose of having this catalog is to make it easy to choose
26
+ different datasets, by just using the strings in the config.
27
+ """
28
+
29
+ def register(self, name, func):
30
+ """
31
+ Args:
32
+ name (str): the name that identifies a dataset, e.g. "coco_2014_train".
33
+ func (callable): a callable which takes no arguments and returns a list of dicts.
34
+ It must return the same results if called multiple times.
35
+ """
36
+ assert callable(func), "You must register a function with `DatasetCatalog.register`!"
37
+ assert name not in self, "Dataset '{}' is already registered!".format(name)
38
+ self[name] = func
39
+
40
+ def get(self, name):
41
+ """
42
+ Call the registered function and return its results.
43
+
44
+ Args:
45
+ name (str): the name that identifies a dataset, e.g. "coco_2014_train".
46
+
47
+ Returns:
48
+ list[dict]: dataset annotations.
49
+ """
50
+ try:
51
+ f = self[name]
52
+ except KeyError as e:
53
+ raise KeyError(
54
+ "Dataset '{}' is not registered! Available datasets are: {}".format(
55
+ name, ", ".join(list(self.keys()))
56
+ )
57
+ ) from e
58
+ return f()
59
+
60
+ def list(self) -> List[str]:
61
+ """
62
+ List all registered datasets.
63
+
64
+ Returns:
65
+ list[str]
66
+ """
67
+ return list(self.keys())
68
+
69
+ def remove(self, name):
70
+ """
71
+ Alias of ``pop``.
72
+ """
73
+ self.pop(name)
74
+
75
+ def __str__(self):
76
+ return "DatasetCatalog(registered datasets: {})".format(", ".join(self.keys()))
77
+
78
+ __repr__ = __str__
79
+
80
+
81
+ DatasetCatalog = _DatasetCatalog()
82
+ DatasetCatalog.__doc__ = (
83
+ _DatasetCatalog.__doc__
84
+ + """
85
+ .. automethod:: detectron2.data.catalog.DatasetCatalog.register
86
+ .. automethod:: detectron2.data.catalog.DatasetCatalog.get
87
+ """
88
+ )
89
+
90
+
91
+ class Metadata(types.SimpleNamespace):
92
+ """
93
+ A class that supports simple attribute setter/getter.
94
+ It is intended for storing metadata of a dataset and make it accessible globally.
95
+
96
+ Examples:
97
+ ::
98
+ # somewhere when you load the data:
99
+ MetadataCatalog.get("mydataset").thing_classes = ["person", "dog"]
100
+
101
+ # somewhere when you print statistics or visualize:
102
+ classes = MetadataCatalog.get("mydataset").thing_classes
103
+ """
104
+
105
+ # the name of the dataset
106
+ # set default to N/A so that `self.name` in the errors will not trigger getattr again
107
+ name: str = "N/A"
108
+
109
+ _RENAMED = {
110
+ "class_names": "thing_classes",
111
+ "dataset_id_to_contiguous_id": "thing_dataset_id_to_contiguous_id",
112
+ "stuff_class_names": "stuff_classes",
113
+ }
114
+
115
+ def __getattr__(self, key):
116
+ if key in self._RENAMED:
117
+ log_first_n(
118
+ logging.WARNING,
119
+ "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
120
+ n=10,
121
+ )
122
+ return getattr(self, self._RENAMED[key])
123
+
124
+ # "name" exists in every metadata
125
+ if len(self.__dict__) > 1:
126
+ raise AttributeError(
127
+ "Attribute '{}' does not exist in the metadata of dataset '{}'. Available "
128
+ "keys are {}.".format(key, self.name, str(self.__dict__.keys()))
129
+ )
130
+ else:
131
+ raise AttributeError(
132
+ f"Attribute '{key}' does not exist in the metadata of dataset '{self.name}': "
133
+ "metadata is empty."
134
+ )
135
+
136
+ def __setattr__(self, key, val):
137
+ if key in self._RENAMED:
138
+ log_first_n(
139
+ logging.WARNING,
140
+ "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
141
+ n=10,
142
+ )
143
+ setattr(self, self._RENAMED[key], val)
144
+
145
+ # Ensure that metadata of the same name stays consistent
146
+ try:
147
+ oldval = getattr(self, key)
148
+ assert oldval == val, (
149
+ "Attribute '{}' in the metadata of '{}' cannot be set "
150
+ "to a different value!\n{} != {}".format(key, self.name, oldval, val)
151
+ )
152
+ except AttributeError:
153
+ super().__setattr__(key, val)
154
+
155
+ def as_dict(self):
156
+ """
157
+ Returns all the metadata as a dict.
158
+ Note that modifications to the returned dict will not reflect on the Metadata object.
159
+ """
160
+ return copy.copy(self.__dict__)
161
+
162
+ def set(self, **kwargs):
163
+ """
164
+ Set multiple metadata with kwargs.
165
+ """
166
+ for k, v in kwargs.items():
167
+ setattr(self, k, v)
168
+ return self
169
+
170
+ def get(self, key, default=None):
171
+ """
172
+ Access an attribute and return its value if exists.
173
+ Otherwise return default.
174
+ """
175
+ try:
176
+ return getattr(self, key)
177
+ except AttributeError:
178
+ return default
179
+
180
+
181
+ class _MetadataCatalog(UserDict):
182
+ """
183
+ MetadataCatalog is a global dictionary that provides access to
184
+ :class:`Metadata` of a given dataset.
185
+
186
+ The metadata associated with a certain name is a singleton: once created, the
187
+ metadata will stay alive and will be returned by future calls to ``get(name)``.
188
+
189
+ It's like global variables, so don't abuse it.
190
+ It's meant for storing knowledge that's constant and shared across the execution
191
+ of the program, e.g.: the class names in COCO.
192
+ """
193
+
194
+ def get(self, name):
195
+ """
196
+ Args:
197
+ name (str): name of a dataset (e.g. coco_2014_train).
198
+
199
+ Returns:
200
+ Metadata: The :class:`Metadata` instance associated with this name,
201
+ or create an empty one if none is available.
202
+ """
203
+ assert len(name)
204
+ r = super().get(name, None)
205
+ if r is None:
206
+ r = self[name] = Metadata(name=name)
207
+ return r
208
+
209
+ def list(self):
210
+ """
211
+ List all registered metadata.
212
+
213
+ Returns:
214
+ list[str]: keys (names of datasets) of all registered metadata
215
+ """
216
+ return list(self.keys())
217
+
218
+ def remove(self, name):
219
+ """
220
+ Alias of ``pop``.
221
+ """
222
+ self.pop(name)
223
+
224
+ def __str__(self):
225
+ return "MetadataCatalog(registered metadata: {})".format(", ".join(self.keys()))
226
+
227
+ __repr__ = __str__
228
+
229
+
230
+ MetadataCatalog = _MetadataCatalog()
231
+ MetadataCatalog.__doc__ = (
232
+ _MetadataCatalog.__doc__
233
+ + """
234
+ .. automethod:: detectron2.data.catalog.MetadataCatalog.get
235
+ """
236
+ )
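A minimal sketch of registering a custom dataset with the two catalogs; the dataset name, loader function, and class list below are placeholders:

    from detectron2.data import DatasetCatalog, MetadataCatalog

    def load_my_dataset():
        # must return list[dict] in the Detectron2 Dataset format, and the same result on every call
        return [{"file_name": "img_0.jpg", "image_id": 0, "height": 480, "width": 640, "annotations": []}]

    DatasetCatalog.register("my_dataset_train", load_my_dataset)
    MetadataCatalog.get("my_dataset_train").set(thing_classes=["cat", "dog"])

    dicts = DatasetCatalog.get("my_dataset_train")  # invokes the registered loader function
    classes = MetadataCatalog.get("my_dataset_train").thing_classes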
detectron2/data/clip_build.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ import bisect
3
+ import copy
4
+ import logging
5
+ import os
6
+ import torch
7
+ import torch.utils.data
8
+ import torch.distributed
9
+ from torch.utils.data.dataset import ConcatDataset
10
+
11
+ from .catalog import DatasetCatalog
12
+ from .clip_datasets.clip_img_txt_pair_tsv import CLIPImgTxtPairTSVDataset
13
+
14
+ from .transforms.build import build_clip_transforms
15
+
16
+ def config_tsv_dataset_args(cfg, dataset_file, factory_name=None, is_train=True):
17
+ ############### code removed as tsv_dataset_name = factory_name = "CLIPImgTxtPairTSVDataset" ##############
18
+ if factory_name is not None:
19
+ tsv_dataset_name = factory_name
20
+
21
+ if tsv_dataset_name in ["CLIPImgTxtPairTSVDataset"]:
22
+ # no need for extra arguments
23
+ args = {}
24
+ args['args'] = cfg
25
+ args['seq_len'] = cfg.DATASETS.MAX_SEQ_LENGTH # cfg.max_seq_length
26
+
27
+ return args, tsv_dataset_name
28
+
29
+
30
+ def build_dataset(cfg, transforms, dataset_catalog, is_train=True, is_aux=False):
31
+ """
32
+ Arguments:
33
+ cfg: config file.
34
+ transforms (callable): transforms to apply to each (image, target) sample
35
+ dataset_catalog (DatasetCatalog): contains the information on how to construct a dataset.
36
+ is_train (bool): whether to setup the dataset for training or testing
37
+ """
38
+
39
+ dataset_list = (cfg.DATASETS.TRAIN if not is_aux else cfg.DATASETS.AUX) if is_train else cfg.DATASETS.TEST
40
+ factory_list = (cfg.DATASETS.FACTORY_TRAIN if not is_aux else cfg.DATASETS.FACTORY_AUX) if is_train else cfg.DATASETS.FACTORY_TEST
41
+ path_list = (cfg.DATASETS.PATH_TRAIN if not is_aux else cfg.DATASETS.PATH_AUX) if is_train else cfg.DATASETS.PATH_TEST
42
+
43
+ if not isinstance(dataset_list, (list, tuple)):
44
+ raise RuntimeError(
45
+ "dataset_list should be a list of strings, got {}".format(dataset_list))
46
+ if not isinstance(factory_list, (list, tuple)):
47
+ raise RuntimeError(
48
+ "factory_list should be a list of strings, got {}".format(factory_list))
49
+ datasets = []
50
+ target_offset = 0
51
+ for i, dataset_name in enumerate(dataset_list):
52
+ factory_name = factory_list[i] if i < len(factory_list) else None
53
+
54
+ if factory_name == "CLIPImgTxtPairTSVDataset":
55
+ dataset_names_merged = dataset_name.split('+')
56
+ path_lists_merged = path_list[i].split('+')
57
+
58
+ assert len(dataset_names_merged) == len(path_lists_merged), "number of datasets must match that of dataset paths"
59
+
60
+ image_tsv_list = []
61
+ text_tsv_list = []
62
+ dataset_name_list = []
63
+ map_files = []
64
+ max_num_tsv = 20 # maximum tsv files to load within a given folder
65
+
66
+ for dname, dpath in zip(dataset_names_merged, path_lists_merged):
67
+ args, tsv_dataset_name = config_tsv_dataset_args(
68
+ cfg, dataset_name, factory_name, is_train
69
+ )
70
+ factory = CLIPImgTxtPairTSVDataset if tsv_dataset_name in ["CLIPImgTxtPairTSVDataset"] else None
71
+ prev_len = len(image_tsv_list)
72
+
73
+ isFile = os.path.isfile(dpath)
74
+ if isFile:
75
+ dpath_listed_files = [os.path.basename(dpath)]
76
+ dpath = os.path.dirname(dpath)
77
+ else:
78
+ dpath_listed_files = sorted(os.listdir(dpath))
79
+
80
+ for filename in dpath_listed_files:
81
+ if ("images" in filename or "image" in filename or "img" in filename) and filename.endswith(".tsv"):
82
+ image_tsv_list.append(os.path.join(dpath, filename))
83
+ if "images" in filename: # "images" - "text"
84
+ text_tsv_list.append(os.path.join(dpath, filename.replace("images", "text")))
85
+ elif "image" in filename: # "image"-"text"
86
+ text_tsv_list.append(os.path.join(dpath, filename.replace("image", "text")))
87
+ elif "img" in filename: # "img"-"caption"
88
+ text_tsv_list.append(os.path.join(dpath, filename.replace("img", "caption")))
89
+ if len(image_tsv_list) - prev_len == max_num_tsv:
90
+ break
91
+ dataset_name_list += [dname] * (len(image_tsv_list) - prev_len)
92
+
93
+ if dname == "imagenet22k":
94
+ map_files += [os.path.join(dpath, 'darknet_data_imagenet.labels.list')] * (len(image_tsv_list) - prev_len)
95
+ else:
96
+ map_files += [None] * (len(image_tsv_list) - prev_len)
97
+
98
+ assert len(image_tsv_list) == len(text_tsv_list), \
99
+ "the number of image tsv files must equal the number of text tsv files; check your data!"
100
+
101
+ args["image_tsv_file"] = image_tsv_list
102
+ args["text_tsv_file"] = text_tsv_list
103
+ args["dataset_name"] = dataset_name_list
104
+ args["map_file"] = map_files
105
+ args["filtered_datasets"] = cfg.DATASETS.FILTERED_CLASSIFICATION_DATASETS
106
+ assert len(image_tsv_list) == len(text_tsv_list) == len(dataset_name_list) == len(map_files)
107
+
108
+ print("number of image tsv files: ", len(image_tsv_list))
109
+ print("number of text tsv files: ", len(text_tsv_list))
110
+
111
+ args["is_train"] = is_train
112
+ args["transforms"] = transforms
113
+ args["target_offset"] = target_offset
114
+ if "bpe" in cfg.INPUT.TEXT_TOKENIZER:
115
+ from detectron2.data.datasets.clip_prompt_utils import SimpleTokenizer as _Tokenizer
116
+ tokenizer = _Tokenizer()
117
+ args["tokenizer_type"] = "bpe"
118
+ args["tokenizer"] = tokenizer
119
+ # make dataset from factory
120
+ dataset = factory(**args)
121
+ datasets.append(dataset)
122
+
123
+ precomputed_tokens = {}
124
+ dataset_classes = {}
125
+ for dataset in datasets:
126
+ if hasattr(dataset, "input_ids_all_classes"):
127
+ precomputed_tokens["imagenet"] = \
128
+ [dataset.input_ids_all_classes, dataset.input_mask_all_classes, dataset.segment_ids_all_classes]
129
+ if hasattr(dataset, "classnames"):
130
+ if isinstance(dataset.classnames, dict):
131
+ dataset_classes.update(dataset.classnames)
132
+ else:
133
+ dataset_classes[dataset.dataset_name] = dataset.classnames
134
+
135
+ # for testing, return a list of datasets
136
+ if not is_train:
137
+ return datasets, precomputed_tokens, dataset_classes
138
+
139
+ if len(datasets) == 0:
140
+ return None, None, None
141
+
142
+ # for training, concatenate all datasets into a single one
143
+ dataset = datasets[0]
144
+ if len(datasets) > 1:
145
+ dataset = ConcatDataset(datasets)
146
+ return [dataset], precomputed_tokens, dataset_classes
147
+
148
+
149
+ def make_clip_dataset(cfg, is_train=True, is_aux=False, transforms=None):
150
+ if transforms is None:
151
+ transforms = build_clip_transforms(cfg, is_train)
152
+ print("data transforms: ")
153
+ print(transforms)
154
+ datasets, precomputed_tokens, dataset_classes = build_dataset(cfg, transforms, DatasetCatalog, is_train, is_aux)
155
+
156
+ if not datasets:
157
+ return None, None, None
158
+ return datasets, precomputed_tokens, dataset_classes
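For orientation, a rough sketch of how make_clip_dataset is driven by the loader-from-config path in build.py. The CLIP-specific keys (DATASETS.FACTORY_TRAIN, DATASETS.PATH_TRAIN, DATASETS.MAX_SEQ_LENGTH, INPUT.TEXT_TOKENIZER, ...) are assumed to be added by this repo's config defaults, and the paths are placeholders:

    from detectron2.config import get_cfg
    from detectron2.data.clip_build import make_clip_dataset

    cfg = get_cfg()  # assumed to already contain the CLIP-specific keys used in this file
    cfg.DATASETS.TRAIN = ("yfcc100m",)
    cfg.DATASETS.FACTORY_TRAIN = ("CLIPImgTxtPairTSVDataset",)
    cfg.DATASETS.PATH_TRAIN = ("/path/to/yfcc100m_tsvs",)  # folder holding paired image/text tsv files

    datasets, precomputed_tokens, dataset_classes = make_clip_dataset(cfg, is_train=True)
    dataset = datasets[0]  # training returns a single (possibly concatenated) dataset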