subhc committed
Commit 5e88f62
1 parent: 8694ca4

Code Commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +2 -0
  2. README.md +5 -5
  3. app.py +158 -0
  4. checkpoints/checkpoint_best.pth +3 -0
  5. config.py +377 -0
  6. configs/README.md +16 -0
  7. configs/maskformer/Base-ADE20K-150.yaml +60 -0
  8. configs/maskformer/Base-unsup-vidseg.yaml +3 -0
  9. configs/maskformer/cityscapes-19/Base-Cityscapes-19.yaml +59 -0
  10. configs/maskformer/cityscapes-19/maskformer_R101_bs16_90k.yaml +36 -0
  11. configs/maskformer/cityscapes-19/maskformer_R101c_bs16_90k.yaml +16 -0
  12. configs/maskformer/maskformer_R50_bs16_160k.yaml +27 -0
  13. configs/maskformer/maskformer_R50_bs16_160k_dino.yaml +31 -0
  14. configs/maskformer/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml +45 -0
  15. configs/maskformer/swin/maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml +45 -0
  16. configs/maskformer/swin/maskformer_swin_small_bs16_160k.yaml +23 -0
  17. configs/maskformer/swin/maskformer_swin_tiny_bs16_160k.yaml +23 -0
  18. configs/maskformer/swin/maskformer_swin_tiny_bs16_160k_moby.yaml +23 -0
  19. datasets/__init__.py +2 -0
  20. datasets/flow_eval_detectron.py +209 -0
  21. datasets/flow_pair_detectron.py +275 -0
  22. determinism.py +24 -0
  23. dist.py +34 -0
  24. eval_utils.py +282 -0
  25. flow_reconstruction.py +54 -0
  26. losses/__init__.py +28 -0
  27. losses/reconstruction_loss.py +85 -0
  28. main.py +270 -0
  29. mask_former/__init__.py +19 -0
  30. mask_former/config.py +85 -0
  31. mask_former/data/__init__.py +2 -0
  32. mask_former/data/dataset_mappers/__init__.py +1 -0
  33. mask_former/data/dataset_mappers/detr_panoptic_dataset_mapper.py +180 -0
  34. mask_former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py +165 -0
  35. mask_former/data/dataset_mappers/mask_former_semantic_dataset_mapper.py +184 -0
  36. mask_former/data/datasets/__init__.py +7 -0
  37. mask_former/data/datasets/register_ade20k_full.py +964 -0
  38. mask_former/data/datasets/register_ade20k_panoptic.py +387 -0
  39. mask_former/data/datasets/register_coco_stuff_10k.py +223 -0
  40. mask_former/data/datasets/register_mapillary_vistas.py +507 -0
  41. mask_former/mask_former_model.py +355 -0
  42. mask_former/modeling/__init__.py +9 -0
  43. mask_former/modeling/backbone/__init__.py +1 -0
  44. mask_former/modeling/backbone/swin.py +772 -0
  45. mask_former/modeling/backbone/vit.py +441 -0
  46. mask_former/modeling/criterion.py +187 -0
  47. mask_former/modeling/heads/__init__.py +1 -0
  48. mask_former/modeling/heads/big_pixel_decoder.py +228 -0
  49. mask_former/modeling/heads/mask_former_head.py +120 -0
  50. mask_former/modeling/heads/mask_former_head_baseline.py +123 -0
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/checkpoint_best.pth filter=lfs diff=lfs merge=lfs -text
+samples/1920px-Woman_at_work,_Gujarat.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
-title: Guess What Moves
-emoji: 🐨
-colorFrom: red
-colorTo: gray
+title: GWM
+emoji: 🏄
+colorFrom: purple
+colorTo: red
 sdk: gradio
 sdk_version: 3.17.0
 app_file: app.py
@@ -10,4 +10,4 @@ pinned: false
 license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+This is a demo for https://www.robots.ox.ac.uk/~vgg/research/gwm/.
app.py ADDED
@@ -0,0 +1,158 @@
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ try:
5
+ import detectron2
6
+ except ImportError:
7
+ os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
8
+
9
+ import logging
10
+ logging.disable(logging.CRITICAL) # comment out to enable verbose logging
11
+
12
+ #########################################################
13
+ import pathlib
14
+ import gradio as gr
15
+ import numpy as np
16
+ import PIL.Image as Image
17
+ import os
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn.functional as F
21
+ import matplotlib.pyplot as plt
22
+ from PIL import Image
23
+ from collections import defaultdict
24
+ from pathlib import Path
25
+ from torch.utils.data import Dataset, DataLoader
26
+ from torchvision import transforms as T
27
+ from tqdm import tqdm
28
+ from types import SimpleNamespace
29
+ from detectron2.checkpoint import DetectionCheckpointer
30
+ from detectron2.data import MetadataCatalog
31
+ from detectron2.utils.visualizer import Visualizer
32
+
33
+ import config
34
+ import utils as ut
35
+ from eval_utils import MaskMerger
36
+ from mask_former_trainer import setup, Trainer
37
+
38
+
39
+ def load_model_cfg(dataset=None):
40
+
41
+ args = SimpleNamespace(config_file='configs/maskformer/maskformer_R50_bs16_160k_dino.yaml', opts=["GWM.DATASET", dataset], wandb_sweep_mode=False, resume_path=str('checkpoints/checkpoint_best.pth'), eval_only=True)
42
+ cfg = setup(args)
43
+ cfg.defrost()
44
+ cfg.MODEL.DEVICE = 'cpu'
45
+ cfg.freeze()
46
+ random_state = ut.random_state.PytorchRNGState(seed=cfg.SEED).to(torch.device(cfg.MODEL.DEVICE))
47
+
48
+ model = Trainer.build_model(cfg)
49
+ checkpointer = DetectionCheckpointer(model,
50
+ random_state=random_state,
51
+ save_dir=None)
52
+
53
+ checkpoint_path = 'checkpoints/checkpoint_best.pth'
54
+ checkpoint = checkpointer.resume_or_load(checkpoint_path, resume=False)
55
+ model.eval()
56
+
57
+ return model, cfg
58
+
59
+ def edgeness(masks):
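+ # Scores how much of each mask lies on the 2-pixel border strips (left, right, top, bottom);
+ # each side's fraction saturates to 1 once it exceeds 0.3. evaluate_image below treats the
+ # mask with the highest score as the background.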
60
+
61
+ em = torch.zeros(1, masks.shape[-2], masks.shape[-1], device=masks.device)
62
+ lm = em.clone()
63
+ lm[..., :2] = 1.
64
+ rm = em.clone()
65
+ rm[...,-2:] = 1.
66
+ tm = em.clone()
67
+ tm[..., :2, :] = 1.
68
+ bm = em.clone()
69
+ bm[..., -2:,:] = 1.
70
+
71
+ one = torch.tensor(1.,dtype= masks.dtype, device=masks.device)
72
+
73
+ l = (masks * lm).flatten(-2).sum(-1) / lm.sum()
74
+ l = torch.where(l > 0.3, one, l)
75
+ r = (masks * rm).flatten(-2).sum(-1) / rm.sum()
76
+ r = torch.where(r > 0.3, one, r)
77
+ t = (masks * tm).flatten(-2).sum(-1) / tm.sum()
78
+ t = torch.where(t > 0.3, one, t)
79
+ b = (masks * bm).flatten(-2).sum(-1) / bm.sum()
80
+ b = torch.where(b > 0.3, one, b)
81
+ return (l + r + t + b )
82
+
83
+ def expand2sizedivisible(pil_img, background_color, size_divisibility):
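+ # Pads a PIL image with a constant background so both sides become multiples of
+ # size_divisibility, keeping the original image centred; cropfromsizedivisible undoes this.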
84
+ width, height = pil_img.size
85
+ if width % size_divisibility == 0 and height % size_divisibility == 0:
86
+ return pil_img
87
+ result = Image.new(pil_img.mode, (width + (size_divisibility - width%size_divisibility)%size_divisibility, height + (size_divisibility - height%size_divisibility)%size_divisibility), background_color)
88
+ result.paste(pil_img, (((size_divisibility - width%size_divisibility)%size_divisibility) // 2, ((size_divisibility - height%size_divisibility)%size_divisibility) // 2))
89
+
90
+ return result
91
+
92
+ def cropfromsizedivisible(img, size_divisibility, orig_size):
93
+ height, width = img.shape[:2]
94
+ owidth, oheight = orig_size
95
+ result = img[(height-oheight)//2:oheight+(height-oheight)//2, (width-owidth)//2:owidth+(width-owidth)//2]
96
+
97
+ return result
98
+
99
+
100
+ def evaluate_image(image_path):
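+ # Full single-image pipeline: build the model from the DAVIS config and checkpoint, run
+ # MaskFormer on the (padded) image, reduce the predicted mask channels to two with the
+ # DINO-feature MaskMerger when needed, pick the background mask via edgeness, and overlay
+ # the remaining foreground mask on the input.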
101
+ binary_threshold = 0.5
102
+ metadata = MetadataCatalog.get("__unused")
103
+
104
+ model, cfg = load_model_cfg("DAVIS")
105
+
106
+ merger = MaskMerger(cfg, model, merger_model="dino_vitb8")
107
+
108
+
109
+ image_pil = Image.open(image_path).convert('RGB')
110
+
111
+ image_pil.thumbnail((384, 384))
112
+
113
+ osize = image_pil.size
114
+ if cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY > 0:
115
+ image_pil = expand2sizedivisible(image_pil, 0, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY)
116
+
117
+ image = np.asarray(image_pil)
118
+ image_pt = torch.from_numpy(np.array(image)).permute(2,0,1)
119
+
120
+ with torch.no_grad():
121
+ sample = [{'rgb': image_pt}]
122
+ preds = model.forward_base(sample, keys=['rgb'], get_eval=True)
123
+ masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)
124
+
125
+ K = masks_raw.shape[1]
126
+ if K > 2:
127
+ masks_softmaxed = torch.softmax(masks_raw, dim=1)
128
+ masks_dict = merger(sample, masks_softmaxed)
129
+ K = 2
130
+ masks = masks_dict['cos']
131
+ else:
132
+ print(K)
133
+ masks = masks_raw.softmax(1)
134
+ masks_raw = F.interpolate(masks, size=(image_pt.shape[-2], image_pt.shape[-1]), mode='bilinear') # t s 1 h w
135
+ bg = edgeness(masks_raw)[0].argmax().item()
136
+
137
+ masks = masks_raw[0] > binary_threshold
138
+ frame_visualizer = Visualizer(image, metadata)
139
+ out = frame_visualizer.overlay_instances(
140
+ masks=masks[[int(bg==0)]],
141
+ alpha=0.3,
142
+ assigned_colors=[(1,0,1)]
143
+ ).get_image()
144
+
145
+ return cropfromsizedivisible(out, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, osize)
146
+
147
+
148
+ paths = sorted(pathlib.Path('samples').glob('*.jpg'))
149
+ css = ".component-1 {height: 256px !important;}"
150
+ demo = gr.Interface(
151
+ fn=evaluate_image,
152
+ inputs=gr.Image(label='Image', type='filepath'),
153
+ outputs=gr.Image(label='Annotated Image', type='numpy'),
154
+ examples=[[path.as_posix()] for path in paths],
155
+ title="Guess What Moves",
156
+ description="#### Unsupervised Image segmentation mode of [Guess What Moves](https://www.robots.ox.ac.uk/~vgg/research/gwm/)",
157
+ css=css)
158
+ demo.queue().launch()
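
As written, app.py starts the Gradio server at import time (`demo.queue().launch()` runs at module level). For headless experimentation one would first move that call under an `if __name__ == "__main__":` guard (an assumed local modification, not part of this commit). With that change, a minimal sketch of driving the same inference function directly:

```
# Sketch only: assumes demo.queue().launch() in app.py has been moved under a
# __main__ guard and that the LFS checkpoint (checkpoints/checkpoint_best.pth) is present.
from PIL import Image

from app import evaluate_image

overlay = evaluate_image("samples/1920px-Woman_at_work,_Gujarat.jpg")  # HxWx3 uint8 array
Image.fromarray(overlay).save("overlay.png")
```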
checkpoints/checkpoint_best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a573e43ffc78e84dbf8b4f2c9e31195bed3c44d8e7be942daba92c7437b8b17d
3
+ size 63672283
config.py ADDED
@@ -0,0 +1,377 @@
1
+ import copy
2
+ import itertools
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch.utils.data
9
+ from detectron2.config import CfgNode as CN
10
+
11
+ import utils
12
+ from datasets import FlowPairDetectron, FlowEvalDetectron
13
+
14
+ logger = logging.getLogger('gwm')
15
+
16
+ def scan_train_flow(folders, res, pairs, basepath):
17
+ pair_list = [p for p in itertools.combinations(pairs, 2)]
18
+
19
+ flow_dir = {}
20
+ for pair in pair_list:
21
+ p1, p2 = pair
22
+ flowpairs = []
23
+ for f in folders:
24
+ path1 = basepath / f'Flows_gap{p1}' / res / f
25
+ path2 = basepath / f'Flows_gap{p2}' / res / f
26
+
27
+ flows1 = [p.name for p in path1.glob('*.flo')]
28
+ flows2 = [p.name for p in path2.glob('*.flo')]
29
+
30
+ flows1 = sorted(flows1)
31
+ flows2 = sorted(flows2)
32
+
33
+ intersect = list(set(flows1).intersection(flows2))
34
+ intersect.sort()
35
+
36
+ flowpair = np.array([[path1 / i, path2 / i] for i in intersect])
37
+ flowpairs += [flowpair]
38
+ flow_dir['gap_{}_{}'.format(p1, p2)] = flowpairs
39
+
40
+ # flow_dir is a dictionary, with keys indicating the flow gap, and each value is a list of sequence names,
41
+ # each item then is an array with Nx2, N indicates the number of available pairs.
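+ # e.g. for gaps (1, 2) on DAVIS this yields entries like
+ # {'gap_1_2': [array([[<basepath>/Flows_gap1/1080p/dog/00000.flo,
+ #                      <basepath>/Flows_gap2/1080p/dog/00000.flo], ...]), ...],
+ #  'gap_1_-1': [...], ...}  # one (N, 2) array per sequence folder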
42
+ return flow_dir
43
+
44
+
45
+ def setup_dataset(cfg=None, multi_val=False):
46
+ dataset_str = cfg.GWM.DATASET
47
+ if '+' in dataset_str:
48
+ datasets = dataset_str.split('+')
49
+ logger.info(f'Multiple datasets detected: {datasets}')
50
+ train_datasets = []
51
+ val_datasets = []
52
+ for ds in datasets:
53
+ proxy_cfg = copy.deepcopy(cfg)
54
+ proxy_cfg.merge_from_list(['GWM.DATASET', ds])
55
+ train_ds, val_ds = setup_dataset(proxy_cfg, multi_val=multi_val)
56
+ train_datasets.append(train_ds)
57
+ val_datasets.append(val_ds)
58
+ logger.info(f'Multiple datasets detected: {datasets}')
59
+ logger.info(f'Validation dataset is still: {datasets[0]}')
60
+ return torch.utils.data.ConcatDataset(train_datasets), val_datasets[0]
61
+
62
+ resolution = cfg.GWM.RESOLUTION # h,w
63
+ res = ""
64
+ with_gt = True
65
+ pairs = [1, 2, -1, -2]
66
+ trainval_data_dir = None
67
+
68
+ if cfg.GWM.DATASET == 'DAVIS':
69
+ basepath = '/DAVIS2016'
70
+ img_dir = '/DAVIS2016/JPEGImages/480p'
71
+ gt_dir = '/DAVIS2016/Annotations/480p'
72
+
73
+ val_flow_dir = '/DAVIS2016/Flows_gap1/1080p'
74
+ val_seq = ['dog', 'cows', 'goat', 'camel', 'libby', 'parkour', 'soapbox', 'blackswan', 'bmx-trees',
75
+ 'kite-surf', 'car-shadow', 'breakdance', 'dance-twirl', 'scooter-black', 'drift-chicane',
76
+ 'motocross-jump', 'horsejump-high', 'drift-straight', 'car-roundabout', 'paragliding-launch']
77
+ val_data_dir = [val_flow_dir, img_dir, gt_dir]
78
+ res = "1080p"
79
+
80
+ elif cfg.GWM.DATASET in ['FBMS']:
81
+ basepath = '/FBMS_clean'
82
+ img_dir = '/FBMS_clean/JPEGImages/'
83
+ gt_dir = '/FBMS_clean/Annotations/'
84
+
85
+ val_flow_dir = '/FBMS_val/Flows_gap1/'
86
+ val_seq = ['camel01', 'cars1', 'cars10', 'cars4', 'cars5', 'cats01', 'cats03', 'cats06',
87
+ 'dogs01', 'dogs02', 'farm01', 'giraffes01', 'goats01', 'horses02', 'horses04',
88
+ 'horses05', 'lion01', 'marple12', 'marple2', 'marple4', 'marple6', 'marple7', 'marple9',
89
+ 'people03', 'people1', 'people2', 'rabbits02', 'rabbits03', 'rabbits04', 'tennis']
90
+ val_img_dir = '/FBMS_val/JPEGImages/'
91
+ val_gt_dir = '/FBMS_val/Annotations/'
92
+ val_data_dir = [val_flow_dir, val_img_dir, val_gt_dir]
93
+ with_gt = False
94
+ pairs = [3, 6, -3, -6]
95
+
96
+ elif cfg.GWM.DATASET in ['STv2']:
97
+ basepath = '/SegTrackv2'
98
+ img_dir = '/SegTrackv2/JPEGImages'
99
+ gt_dir = '/SegTrackv2/Annotations'
100
+
101
+ val_flow_dir = '/SegTrackv2/Flows_gap1/'
102
+ val_seq = ['drift', 'birdfall', 'girl', 'cheetah', 'worm', 'parachute', 'monkeydog',
103
+ 'hummingbird', 'soldier', 'bmx', 'frog', 'penguin', 'monkey', 'bird_of_paradise']
104
+ val_data_dir = [val_flow_dir, img_dir, gt_dir]
105
+
106
+ else:
107
+ raise ValueError('Unknown Setting/Dataset.')
108
+
109
+ # Switching this section to pathlib, which should prevent double // errors in paths and dict keys
110
+
111
+ root_path_str = cfg.GWM.DATA_ROOT
112
+ logger.info(f"Found DATA_ROOT in config: {root_path_str}")
113
+ root_path_str = '../data'
114
+
115
+ if root_path_str.startswith('/'):
116
+ root_path = Path(f"/{root_path_str.lstrip('/').rstrip('/')}")
117
+ else:
118
+ root_path = Path(f"{root_path_str.lstrip('/').rstrip('/')}")
119
+
120
+ logger.info(f"Loading dataset from: {root_path}")
121
+
122
+ basepath = root_path / basepath.lstrip('/').rstrip('/')
123
+ img_dir = root_path / img_dir.lstrip('/').rstrip('/')
124
+ gt_dir = root_path / gt_dir.lstrip('/').rstrip('/')
125
+ val_data_dir = [root_path / path.lstrip('/').rstrip('/') for path in val_data_dir]
126
+
127
+ folders = [p.name for p in (basepath / f'Flows_gap{pairs[0]}' / res).iterdir() if p.is_dir()]
128
+ folders = sorted(folders)
129
+
130
+ # flow_dir is a dictionary, with keys indicating the flow gap, and each value is a list of sequence names,
131
+ # each item then is an array with Nx2, N indicates the number of available pairs.
132
+
133
+ flow_dir = scan_train_flow(folders, res, pairs, basepath)
134
+ data_dir = [flow_dir, img_dir, gt_dir]
135
+
136
+ force1080p = ('DAVIS' not in cfg.GWM.DATASET) and 'RGB_BIG' in cfg.GWM.SAMPLE_KEYS
137
+
138
+ enable_photometric_augmentations = cfg.FLAGS.INF_TPS
139
+
140
+ train_dataset = FlowPairDetectron(data_dir=data_dir,
141
+ resolution=resolution,
142
+ to_rgb=cfg.GWM.FLOW2RGB,
143
+ size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1,
144
+ enable_photo_aug=enable_photometric_augmentations,
145
+ flow_clip=cfg.GWM.FLOW_CLIP,
146
+ norm=cfg.GWM.FLOW_NORM,
147
+ force1080p=force1080p,
148
+ flow_res=cfg.GWM.FLOW_RES, )
149
+ if multi_val:
150
+ print(f"Using multiple validation datasets from {val_data_dir}")
151
+ val_dataset = [FlowEvalDetectron(data_dir=val_data_dir,
152
+ resolution=resolution,
153
+ pair_list=pairs,
154
+ val_seq=[vs],
155
+ to_rgb=cfg.GWM.FLOW2RGB,
156
+ with_rgb=False,
157
+ size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1,
158
+ flow_clip=cfg.GWM.FLOW_CLIP,
159
+ norm=cfg.GWM.FLOW_NORM,
160
+ force1080p=force1080p) for vs in val_seq]
161
+ for vs, vds in zip(val_seq, val_dataset):
162
+ print(f"Validation dataset for {vs}: {len(vds)}")
163
+ if len(vds) == 0:
164
+ raise ValueError(f"Empty validation dataset for {vs}")
165
+
166
+ if cfg.GWM.TTA_AS_TRAIN:
167
+ if trainval_data_dir is None:
168
+ trainval_data_dir = val_data_dir
169
+ else:
170
+ trainval_data_dir = [root_path / path.lstrip('/').rstrip('/') for path in trainval_data_dir]
171
+ trainval_dataset = []
172
+ tvd_basepath = root_path / str(trainval_data_dir[0].relative_to(root_path)).split('/')[0]
173
+ print("TVD BASE DIR", tvd_basepath)
174
+ for vs in val_seq:
175
+ tvd_data_dir = [scan_train_flow([vs], res, pairs, tvd_basepath), *trainval_data_dir[1:]]
176
+ tvd = FlowPairDetectron(data_dir=tvd_data_dir,
177
+ resolution=resolution,
178
+ to_rgb=cfg.GWM.FLOW2RGB,
179
+ size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1,
180
+ enable_photo_aug=cfg.GWM.LOSS_MULT.EQV is not None,
181
+ flow_clip=cfg.GWM.FLOW_CLIP,
182
+ norm=cfg.GWM.FLOW_NORM,
183
+ force1080p=force1080p,
184
+ flow_res=cfg.GWM.FLOW_RES, )
185
+ trainval_dataset.append(tvd)
186
+ print(f'Seq {trainval_data_dir[0]}/{vs} dataset: {len(tvd)}')
187
+ else:
188
+ if trainval_data_dir is None:
189
+ trainval_dataset = val_dataset
190
+ else:
191
+ trainval_data_dir = [root_path / path.lstrip('/').rstrip('/') for path in trainval_data_dir]
192
+ trainval_dataset = []
193
+ for vs in val_seq:
194
+ tvd = FlowEvalDetectron(data_dir=trainval_data_dir,
195
+ resolution=resolution,
196
+ pair_list=pairs,
197
+ val_seq=[vs],
198
+ to_rgb=cfg.GWM.FLOW2RGB,
199
+ with_rgb=False,
200
+ size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1,
201
+ flow_clip=cfg.GWM.FLOW_CLIP,
202
+ norm=cfg.GWM.FLOW_NORM,
203
+ force1080p=force1080p)
204
+ trainval_dataset.append(tvd)
205
+ print(f'Seq {trainval_data_dir[0]}/{vs} dataset: {len(tvd)}')
206
+ return train_dataset, val_dataset, trainval_dataset
207
+ val_dataset = FlowEvalDetectron(data_dir=val_data_dir,
208
+ resolution=resolution,
209
+ pair_list=pairs,
210
+ val_seq=val_seq,
211
+ to_rgb=cfg.GWM.FLOW2RGB,
212
+ with_rgb=False,
213
+ size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1,
214
+ flow_clip=cfg.GWM.FLOW_CLIP,
215
+ norm=cfg.GWM.FLOW_NORM,
216
+ force1080p=force1080p)
217
+
218
+ return train_dataset, val_dataset
219
+
220
+
221
+ def loaders(cfg):
222
+ train_dataset, val_dataset = setup_dataset(cfg)
223
+ logger.info(f"Sourcing data from {val_dataset.data_dir[0]}")
224
+
225
+ if cfg.FLAGS.DEV_DATA:
226
+ subset = cfg.SOLVER.IMS_PER_BATCH * 3
227
+ train_dataset = torch.utils.data.Subset(train_dataset, list(range(subset)))
228
+ val_dataset = torch.utils.data.Subset(val_dataset, list(range(subset)))
229
+
230
+ g = torch.Generator()
231
+ data_generator_seed = int(torch.randint(int(1e6), (1,)).item())
232
+ logger.info(f"Dataloaders generator seed {data_generator_seed}")
233
+ g.manual_seed(data_generator_seed)
234
+
235
+ train_loader = torch.utils.data.DataLoader(train_dataset,
236
+ num_workers=cfg.DATALOADER.NUM_WORKERS,
237
+ batch_size=cfg.SOLVER.IMS_PER_BATCH,
238
+ collate_fn=lambda x: x,
239
+ shuffle=True,
240
+ pin_memory=True,
241
+ drop_last=True,
242
+ persistent_workers=cfg.DATALOADER.NUM_WORKERS > 0,
243
+ worker_init_fn=utils.random_state.worker_init_function,
244
+ generator=g
245
+ )
246
+ val_loader = torch.utils.data.DataLoader(val_dataset,
247
+ num_workers=cfg.DATALOADER.NUM_WORKERS,
248
+ batch_size=1,
249
+ shuffle=False,
250
+ pin_memory=True,
251
+ collate_fn=lambda x: x,
252
+ drop_last=False,
253
+ persistent_workers=cfg.DATALOADER.NUM_WORKERS > 0,
254
+ worker_init_fn=utils.random_state.worker_init_function,
255
+ generator=g)
256
+ return train_loader, val_loader
257
+
258
+
259
+ def multi_loaders(cfg):
260
+ train_dataset, val_datasets, train_val_datasets = setup_dataset(cfg, multi_val=True)
261
+ logger.info(f"Sourcing multiple loaders from {len(val_datasets)}")
262
+ logger.info(f"Sourcing data from {val_datasets[0].data_dir[0]}")
263
+
264
+ g = torch.Generator()
265
+ data_generator_seed = int(torch.randint(int(1e6), (1,)).item())
266
+ logger.info(f"Dataloaders generator seed {data_generator_seed}")
267
+ g.manual_seed(data_generator_seed)
268
+
269
+ train_loader = torch.utils.data.DataLoader(train_dataset,
270
+ num_workers=cfg.DATALOADER.NUM_WORKERS,
271
+ batch_size=cfg.SOLVER.IMS_PER_BATCH,
272
+ collate_fn=lambda x: x,
273
+ shuffle=True,
274
+ pin_memory=True,
275
+ drop_last=True,
276
+ persistent_workers=cfg.DATALOADER.NUM_WORKERS > 0,
277
+ worker_init_fn=utils.random_state.worker_init_function,
278
+ generator=g
279
+ )
280
+
281
+ val_loaders = [(torch.utils.data.DataLoader(val_dataset,
282
+ num_workers=0,
283
+ batch_size=1,
284
+ shuffle=False,
285
+ pin_memory=True,
286
+ collate_fn=lambda x: x,
287
+ drop_last=False,
288
+ persistent_workers=False,
289
+ worker_init_fn=utils.random_state.worker_init_function,
290
+ generator=g),
291
+ torch.utils.data.DataLoader(tv_dataset,
292
+ num_workers=0,
293
+ batch_size=cfg.SOLVER.IMS_PER_BATCH,
294
+ shuffle=True,
295
+ pin_memory=False,
296
+ collate_fn=lambda x: x,
297
+ drop_last=False,
298
+ persistent_workers=False,
299
+ worker_init_fn=utils.random_state.worker_init_function,
300
+ generator=g))
301
+ for val_dataset, tv_dataset in zip(val_datasets, train_val_datasets)]
302
+
303
+ return train_loader, val_loaders
304
+
305
+
306
+ def add_gwm_config(cfg):
307
+ cfg.GWM = CN()
308
+ cfg.GWM.MODEL = "MASKFORMER"
309
+ cfg.GWM.RESOLUTION = (128, 224)
310
+ cfg.GWM.FLOW_RES = (480, 854)
311
+ cfg.GWM.SAMPLE_KEYS = ["rgb"]
312
+ cfg.GWM.ADD_POS_EMB = False
313
+ cfg.GWM.CRITERION = "L2"
314
+ cfg.GWM.L1_OPTIMIZE = False
315
+ cfg.GWM.HOMOGRAPHY = 'quad' # False
316
+ cfg.GWM.HOMOGRAPHY_SUBSAMPLE = 8
317
+ cfg.GWM.HOMOGRAPHY_SKIP = 0.4
318
+ cfg.GWM.DATASET = 'DAVIS'
319
+ cfg.GWM.DATA_ROOT = None
320
+ cfg.GWM.FLOW2RGB = False
321
+ cfg.GWM.SIMPLE_REC = False
322
+ cfg.GWM.DAVIS_SINGLE_VID = None
323
+ cfg.GWM.USE_MULT_FLOW = False
324
+ cfg.GWM.FLOW_COLORSPACE_REC = None
325
+
326
+ cfg.GWM.FLOW_CLIP_U_LOW = float('-inf')
327
+ cfg.GWM.FLOW_CLIP_U_HIGH = float('inf')
328
+ cfg.GWM.FLOW_CLIP_V_LOW = float('-inf')
329
+ cfg.GWM.FLOW_CLIP_V_HIGH = float('inf')
330
+
331
+ cfg.GWM.FLOW_CLIP = float('inf')
332
+ cfg.GWM.FLOW_NORM = False
333
+
334
+ cfg.GWM.LOSS_MULT = CN()
335
+ cfg.GWM.LOSS_MULT.REC = 0.03
336
+ cfg.GWM.LOSS_MULT.HEIR_W = [0.1, 0.3, 0.6]
337
+
338
+
339
+ cfg.GWM.TTA = 100 # Test-time-adaptation
340
+ cfg.GWM.TTA_AS_TRAIN = False # Use train-like data logic for test-time-adaptation
341
+
342
+ cfg.GWM.LOSS = 'OG'
343
+
344
+ cfg.FLAGS = CN()
345
+ cfg.FLAGS.MAKE_VIS_VIDEOS = False # Making videos is kinda slow
346
+ cfg.FLAGS.EXTENDED_FLOW_RECON_VIS = False # Does not cost much
347
+ cfg.FLAGS.COMP_NLL_FOR_GT = False # Should we log loss against ground truth?
348
+ cfg.FLAGS.DEV_DATA = False
349
+ cfg.FLAGS.KEEP_ALL = True # Keep all checkpoints
350
+ cfg.FLAGS.ORACLE_CHECK = False # Use oracle check to estimate max performance when grouping multiple components
351
+
352
+ cfg.FLAGS.INF_TPS = False
353
+
354
+ # cfg.FLAGS.UNFREEZE_AT = [(1, 10000), (0, 20000), (-1, 30000)]
355
+ cfg.FLAGS.UNFREEZE_AT = [(4, 0), (2, 500), (1, 1000), (-1, 10000)]
356
+
357
+ cfg.FLAGS.IGNORE_SIZE_DIV = False
358
+
359
+ cfg.FLAGS.IGNORE_TMP = True
360
+
361
+ cfg.WANDB = CN()
362
+ cfg.WANDB.ENABLE = False
363
+ cfg.WANDB.BASEDIR = '../'
364
+
365
+ cfg.DEBUG = False
366
+
367
+ cfg.LOG_ID = 'exp'
368
+ cfg.LOG_FREQ = 250
369
+ cfg.OUTPUT_BASEDIR = '../outputs'
370
+ cfg.SLURM = False
371
+ cfg.SKIP_TB = False
372
+ cfg.TOTAL_ITER = 20000
373
+ cfg.CONFIG_FILE = None
374
+
375
+ if os.environ.get('SLURM_JOB_ID', None):
376
+ cfg.LOG_ID = os.environ.get('SLURM_JOB_NAME', cfg.LOG_ID)
377
+ logger.info(f"Setting name {cfg.LOG_ID} based on SLURM job name")
configs/README.md ADDED
@@ -0,0 +1,16 @@
1
+ ## Usage
2
+
3
+
4
+ Use `main.py --config-file=<config-file>`
5
+
6
+ No need to specify `GWM.MODEL`; it is already defined inside the config files.
7
+
8
+
9
+ ### Available configs:
10
+ ```
11
+ maskformer/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml
12
+ maskformer/swin/maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml
13
+ maskformer/swin/maskformer_swin_small_bs16_160k.yaml
14
+ maskformer/swin/maskformer_swin_tiny_bs16_160k.yaml
15
+ maskformer/maskformer_R50_bs16_160k.yaml
16
+ ```
configs/maskformer/Base-ADE20K-150.yaml ADDED
@@ -0,0 +1,60 @@
1
+ _BASE_: Base-unsup-vidseg.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ FREEZE_AT: 0
5
+ NAME: "build_resnet_backbone"
6
+ WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
7
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
8
+ PIXEL_STD: [58.395, 57.120, 57.375]
9
+ RESNETS:
10
+ DEPTH: 50
11
+ STEM_TYPE: "basic" # not used
12
+ STEM_OUT_CHANNELS: 64
13
+ STRIDE_IN_1X1: False
14
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
15
+ # NORM: "SyncBN"
16
+ RES5_MULTI_GRID: [1, 1, 1] # not used
17
+ DATASETS:
18
+ TRAIN: ("ade20k_sem_seg_train",)
19
+ TEST: ("ade20k_sem_seg_val",)
20
+ SOLVER:
21
+ IMS_PER_BATCH: 16
22
+ BASE_LR: 0.0001
23
+ MAX_ITER: 160000
24
+ WARMUP_FACTOR: 1.0
25
+ WARMUP_ITERS: 0
26
+ WEIGHT_DECAY: 0.0001
27
+ OPTIMIZER: "ADAMW"
28
+ LR_SCHEDULER_NAME: "WarmupPolyLR"
29
+ BACKBONE_MULTIPLIER: 0.1
30
+ CLIP_GRADIENTS:
31
+ ENABLED: True
32
+ CLIP_TYPE: "full_model"
33
+ CLIP_VALUE: 0.01
34
+ NORM_TYPE: 2.0
35
+ INPUT:
36
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"]
37
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
38
+ MIN_SIZE_TEST: 512
39
+ MAX_SIZE_TRAIN: 2048
40
+ MAX_SIZE_TEST: 2048
41
+ CROP:
42
+ ENABLED: True
43
+ TYPE: "absolute"
44
+ SIZE: (512, 512)
45
+ SINGLE_CATEGORY_MAX_AREA: 1.0
46
+ COLOR_AUG_SSD: True
47
+ SIZE_DIVISIBILITY: 512 # used in dataset mapper
48
+ FORMAT: "RGB"
49
+ DATASET_MAPPER_NAME: "mask_former_semantic"
50
+ TEST:
51
+ EVAL_PERIOD: 5000
52
+ AUG:
53
+ ENABLED: False
54
+ MIN_SIZES: [256, 384, 512, 640, 768, 896]
55
+ MAX_SIZE: 3584
56
+ FLIP: True
57
+ DATALOADER:
58
+ FILTER_EMPTY_ANNOTATIONS: True
59
+ NUM_WORKERS: 10
60
+ VERSION: 2
configs/maskformer/Base-unsup-vidseg.yaml ADDED
@@ -0,0 +1,3 @@
1
+ SEED: 42
2
+ GWM:
3
+ MODEL: "MASKFORMER"
configs/maskformer/cityscapes-19/Base-Cityscapes-19.yaml ADDED
@@ -0,0 +1,59 @@
1
+ MODEL:
2
+ BACKBONE:
3
+ FREEZE_AT: 0
4
+ NAME: "build_resnet_backbone"
5
+ WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
6
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
7
+ PIXEL_STD: [58.395, 57.120, 57.375]
8
+ RESNETS:
9
+ DEPTH: 50
10
+ STEM_TYPE: "basic" # not used
11
+ STEM_OUT_CHANNELS: 64
12
+ STRIDE_IN_1X1: False
13
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
14
+ # NORM: "SyncBN"
15
+ RES5_MULTI_GRID: [1, 1, 1] # not used
16
+ DATASETS:
17
+ TRAIN: ("cityscapes_fine_sem_seg_train",)
18
+ TEST: ("cityscapes_fine_sem_seg_val",)
19
+ SOLVER:
20
+ IMS_PER_BATCH: 16
21
+ BASE_LR: 0.0001
22
+ MAX_ITER: 90000
23
+ WARMUP_FACTOR: 1.0
24
+ WARMUP_ITERS: 0
25
+ WEIGHT_DECAY: 0.0001
26
+ OPTIMIZER: "ADAMW"
27
+ LR_SCHEDULER_NAME: "WarmupPolyLR"
28
+ BACKBONE_MULTIPLIER: 0.1
29
+ CLIP_GRADIENTS:
30
+ ENABLED: True
31
+ CLIP_TYPE: "full_model"
32
+ CLIP_VALUE: 0.01
33
+ NORM_TYPE: 2.0
34
+ INPUT:
35
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 1024) for x in range(5, 21)]"]
36
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
37
+ MIN_SIZE_TEST: 1024
38
+ MAX_SIZE_TRAIN: 4096
39
+ MAX_SIZE_TEST: 2048
40
+ CROP:
41
+ ENABLED: True
42
+ TYPE: "absolute"
43
+ SIZE: (512, 1024)
44
+ SINGLE_CATEGORY_MAX_AREA: 1.0
45
+ COLOR_AUG_SSD: True
46
+ SIZE_DIVISIBILITY: -1
47
+ FORMAT: "RGB"
48
+ DATASET_MAPPER_NAME: "mask_former_semantic"
49
+ TEST:
50
+ EVAL_PERIOD: 5000
51
+ AUG:
52
+ ENABLED: False
53
+ MIN_SIZES: [512, 768, 1024, 1280, 1536, 1792]
54
+ MAX_SIZE: 4096
55
+ FLIP: True
56
+ DATALOADER:
57
+ FILTER_EMPTY_ANNOTATIONS: True
58
+ NUM_WORKERS: 4
59
+ VERSION: 2
configs/maskformer/cityscapes-19/maskformer_R101_bs16_90k.yaml ADDED
@@ -0,0 +1,36 @@
1
+ _BASE_: Base-Cityscapes-19.yaml
2
+ MODEL:
3
+ WEIGHTS: "R-101.pkl"
4
+ RESNETS:
5
+ DEPTH: 101
6
+ STEM_TYPE: "basic" # not used
7
+ STEM_OUT_CHANNELS: 64
8
+ STRIDE_IN_1X1: False
9
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
10
+ # NORM: "SyncBN"
11
+ RES5_MULTI_GRID: [1, 1, 1] # not used
12
+ META_ARCHITECTURE: "MaskFormer"
13
+ SEM_SEG_HEAD:
14
+ NAME: "MaskFormerHead"
15
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
16
+ IGNORE_VALUE: 255
17
+ NUM_CLASSES: 19
18
+ COMMON_STRIDE: 4 # not used, hard-coded
19
+ LOSS_WEIGHT: 1.0
20
+ CONVS_DIM: 256
21
+ MASK_DIM: 256
22
+ NORM: "GN"
23
+ MASK_FORMER:
24
+ TRANSFORMER_IN_FEATURE: "res5"
25
+ DEEP_SUPERVISION: True
26
+ NO_OBJECT_WEIGHT: 0.1
27
+ DICE_WEIGHT: 1.0
28
+ MASK_WEIGHT: 20.0
29
+ HIDDEN_DIM: 256
30
+ NUM_OBJECT_QUERIES: 100
31
+ NHEADS: 8
32
+ DROPOUT: 0.1
33
+ DIM_FEEDFORWARD: 2048
34
+ ENC_LAYERS: 0
35
+ DEC_LAYERS: 6
36
+ PRE_NORM: False
configs/maskformer/cityscapes-19/maskformer_R101c_bs16_90k.yaml ADDED
@@ -0,0 +1,16 @@
1
+ _BASE_: maskformer_R101_bs16_90k.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ FREEZE_AT: 0
5
+ NAME: "build_resnet_deeplab_backbone"
6
+ WEIGHTS: "detectron2://DeepLab/R-103.pkl"
7
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
8
+ PIXEL_STD: [58.395, 57.120, 57.375]
9
+ RESNETS:
10
+ DEPTH: 101
11
+ STEM_TYPE: "deeplab"
12
+ STEM_OUT_CHANNELS: 128
13
+ STRIDE_IN_1X1: False
14
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
15
+ # NORM: "SyncBN"
16
+ RES5_MULTI_GRID: [1, 2, 4]
configs/maskformer/maskformer_R50_bs16_160k.yaml ADDED
@@ -0,0 +1,27 @@
1
+ _BASE_: Base-ADE20K-150.yaml
2
+ MODEL:
3
+ META_ARCHITECTURE: "MaskFormer"
4
+ SEM_SEG_HEAD:
5
+ NAME: "MaskFormerHead"
6
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
7
+ IGNORE_VALUE: 255
8
+ NUM_CLASSES: 2
9
+ COMMON_STRIDE: 4 # not used, hard-coded
10
+ LOSS_WEIGHT: 1.0
11
+ CONVS_DIM: 256
12
+ MASK_DIM: 256
13
+ NORM: "GN"
14
+ MASK_FORMER:
15
+ TRANSFORMER_IN_FEATURE: "res5"
16
+ DEEP_SUPERVISION: False
17
+ NO_OBJECT_WEIGHT: 0.1
18
+ DICE_WEIGHT: 1.0
19
+ MASK_WEIGHT: 20.0
20
+ HIDDEN_DIM: 256
21
+ NUM_OBJECT_QUERIES: 2
22
+ NHEADS: 8
23
+ DROPOUT: 0.1
24
+ DIM_FEEDFORWARD: 2048
25
+ ENC_LAYERS: 0
26
+ DEC_LAYERS: 6
27
+ PRE_NORM: False
configs/maskformer/maskformer_R50_bs16_160k_dino.yaml ADDED
@@ -0,0 +1,31 @@
1
+ _BASE_: ./maskformer_R50_bs16_160k.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2ViTTransformer"
5
+ FREEZE_AT: -1
6
+ SWIN:
7
+ EMBED_DIM: 768
8
+ DEPTHS: [2, 2, 6, 2]
9
+ NUM_HEADS: [3, 6, 12, 24]
10
+ WINDOW_SIZE: 7
11
+ APE: False
12
+ DROP_PATH_RATE: 0.3
13
+ PATCH_NORM: True
14
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
15
+ PIXEL_STD: [58.395, 57.120, 57.375]
16
+ WEIGHTS: None
17
+ MASK_FORMER:
18
+ NUM_OBJECT_QUERIES: 4
19
+ SEM_SEG_HEAD:
20
+ PIXEL_DECODER_NAME: BigPixelDecoder
21
+ SOLVER:
22
+ BASE_LR: 0.00015
23
+ IMS_PER_BATCH: 8
24
+ WARMUP_FACTOR: 1e-6
25
+ WARMUP_ITERS: 1500
26
+ WEIGHT_DECAY: 0.01
27
+ WEIGHT_DECAY_NORM: 0.0
28
+ WEIGHT_DECAY_EMBED: 0.0
29
+ BACKBONE_MULTIPLIER: 1.0
30
+ FLAGS:
31
+ UNFREEZE_AT: []
configs/maskformer/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml ADDED
@@ -0,0 +1,45 @@
1
+ _BASE_: ../maskformer_R50_bs16_160k.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2SwinTransformer"
5
+ SWIN:
6
+ EMBED_DIM: 128
7
+ DEPTHS: [2, 2, 18, 2]
8
+ NUM_HEADS: [4, 8, 16, 32]
9
+ WINDOW_SIZE: 12
10
+ APE: False
11
+ DROP_PATH_RATE: 0.3
12
+ PATCH_NORM: True
13
+ PRETRAIN_IMG_SIZE: 384
14
+ WEIGHTS: "pretrained_weights/swin_base_patch4_window12_384_22k.pkl"
15
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
16
+ PIXEL_STD: [58.395, 57.120, 57.375]
17
+ SOLVER:
18
+ BASE_LR: 0.00006
19
+ WARMUP_FACTOR: 1e-6
20
+ WARMUP_ITERS: 1500
21
+ WEIGHT_DECAY: 0.01
22
+ WEIGHT_DECAY_NORM: 0.0
23
+ WEIGHT_DECAY_EMBED: 0.0
24
+ BACKBONE_MULTIPLIER: 1.0
25
+ INPUT:
26
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"]
27
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
28
+ MIN_SIZE_TEST: 640
29
+ MAX_SIZE_TRAIN: 2560
30
+ MAX_SIZE_TEST: 2560
31
+ CROP:
32
+ ENABLED: True
33
+ TYPE: "absolute"
34
+ SIZE: (640, 640)
35
+ SINGLE_CATEGORY_MAX_AREA: 1.0
36
+ COLOR_AUG_SSD: True
37
+ SIZE_DIVISIBILITY: 640 # used in dataset mapper
38
+ FORMAT: "RGB"
39
+ TEST:
40
+ EVAL_PERIOD: 5000
41
+ AUG:
42
+ ENABLED: False
43
+ MIN_SIZES: [320, 480, 640, 800, 960, 1120]
44
+ MAX_SIZE: 4480
45
+ FLIP: True
configs/maskformer/swin/maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml ADDED
@@ -0,0 +1,45 @@
1
+ _BASE_: ../maskformer_R50_bs16_160k.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2SwinTransformer"
5
+ SWIN:
6
+ EMBED_DIM: 192
7
+ DEPTHS: [2, 2, 18, 2]
8
+ NUM_HEADS: [6, 12, 24, 48]
9
+ WINDOW_SIZE: 12
10
+ APE: False
11
+ DROP_PATH_RATE: 0.3
12
+ PATCH_NORM: True
13
+ PRETRAIN_IMG_SIZE: 384
14
+ WEIGHTS: "pretrained_weights/swin_large_patch4_window12_384_22k.pkl"
15
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
16
+ PIXEL_STD: [58.395, 57.120, 57.375]
17
+ SOLVER:
18
+ BASE_LR: 0.00006
19
+ WARMUP_FACTOR: 1e-6
20
+ WARMUP_ITERS: 1500
21
+ WEIGHT_DECAY: 0.01
22
+ WEIGHT_DECAY_NORM: 0.0
23
+ WEIGHT_DECAY_EMBED: 0.0
24
+ BACKBONE_MULTIPLIER: 1.0
25
+ INPUT:
26
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"]
27
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
28
+ MIN_SIZE_TEST: 640
29
+ MAX_SIZE_TRAIN: 2560
30
+ MAX_SIZE_TEST: 2560
31
+ CROP:
32
+ ENABLED: True
33
+ TYPE: "absolute"
34
+ SIZE: (640, 640)
35
+ SINGLE_CATEGORY_MAX_AREA: 1.0
36
+ COLOR_AUG_SSD: True
37
+ SIZE_DIVISIBILITY: 640 # used in dataset mapper
38
+ FORMAT: "RGB"
39
+ TEST:
40
+ EVAL_PERIOD: 5000
41
+ AUG:
42
+ ENABLED: False
43
+ MIN_SIZES: [320, 480, 640, 800, 960, 1120]
44
+ MAX_SIZE: 4480
45
+ FLIP: True
configs/maskformer/swin/maskformer_swin_small_bs16_160k.yaml ADDED
@@ -0,0 +1,23 @@
1
+ _BASE_: ../maskformer_R50_bs16_160k.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2SwinTransformer"
5
+ SWIN:
6
+ EMBED_DIM: 96
7
+ DEPTHS: [2, 2, 18, 2]
8
+ NUM_HEADS: [3, 6, 12, 24]
9
+ WINDOW_SIZE: 7
10
+ APE: False
11
+ DROP_PATH_RATE: 0.3
12
+ PATCH_NORM: True
13
+ WEIGHTS: "pretrained_weights/swin_small_patch4_window7_224.pkl"
14
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
15
+ PIXEL_STD: [58.395, 57.120, 57.375]
16
+ SOLVER:
17
+ BASE_LR: 0.00006
18
+ WARMUP_FACTOR: 1e-6
19
+ WARMUP_ITERS: 1500
20
+ WEIGHT_DECAY: 0.01
21
+ WEIGHT_DECAY_NORM: 0.0
22
+ WEIGHT_DECAY_EMBED: 0.0
23
+ BACKBONE_MULTIPLIER: 1.0
configs/maskformer/swin/maskformer_swin_tiny_bs16_160k.yaml ADDED
@@ -0,0 +1,23 @@
1
+ _BASE_: ../maskformer_R50_bs16_160k.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2SwinTransformer"
5
+ SWIN:
6
+ EMBED_DIM: 96
7
+ DEPTHS: [2, 2, 6, 2]
8
+ NUM_HEADS: [3, 6, 12, 24]
9
+ WINDOW_SIZE: 7
10
+ APE: False
11
+ DROP_PATH_RATE: 0.3
12
+ PATCH_NORM: True
13
+ WEIGHTS: "pretrained_weights/swin_tiny_patch4_window7_224.pkl"
14
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
15
+ PIXEL_STD: [58.395, 57.120, 57.375]
16
+ SOLVER:
17
+ BASE_LR: 0.00006
18
+ WARMUP_FACTOR: 1e-6
19
+ WARMUP_ITERS: 1500
20
+ WEIGHT_DECAY: 0.01
21
+ WEIGHT_DECAY_NORM: 0.0
22
+ WEIGHT_DECAY_EMBED: 0.0
23
+ BACKBONE_MULTIPLIER: 1.0
configs/maskformer/swin/maskformer_swin_tiny_bs16_160k_moby.yaml ADDED
@@ -0,0 +1,23 @@
1
+ _BASE_: ../maskformer_R50_bs16_160k.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2SwinTransformer"
5
+ SWIN:
6
+ EMBED_DIM: 96
7
+ DEPTHS: [2, 2, 6, 2]
8
+ NUM_HEADS: [3, 6, 12, 24]
9
+ WINDOW_SIZE: 7
10
+ APE: False
11
+ DROP_PATH_RATE: 0.3
12
+ PATCH_NORM: True
13
+ WEIGHTS: "pretrained_weights/moby_swin_t_300ep_pretrained.pkl"
14
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
15
+ PIXEL_STD: [58.395, 57.120, 57.375]
16
+ SOLVER:
17
+ BASE_LR: 0.00006
18
+ WARMUP_FACTOR: 1e-6
19
+ WARMUP_ITERS: 1500
20
+ WEIGHT_DECAY: 0.01
21
+ WEIGHT_DECAY_NORM: 0.0
22
+ WEIGHT_DECAY_EMBED: 0.0
23
+ BACKBONE_MULTIPLIER: 1.0
datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .flow_eval_detectron import FlowEvalDetectron
2
+ from .flow_pair_detectron import FlowPairDetectron
datasets/flow_eval_detectron.py ADDED
@@ -0,0 +1,209 @@
1
+ import math
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import detectron2.data.transforms as DT
6
+ import einops
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from PIL import Image
11
+ from detectron2.data import detection_utils as d2_utils
12
+ from detectron2.structures import Instances, BitMasks
13
+ from sklearn.model_selection import train_test_split
14
+ from torch.utils.data import Dataset
15
+
16
+ from utils.data import read_flow
17
+
18
+
19
+ class FlowEvalDetectron(Dataset):
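+ # Per-frame evaluation dataset: each item is a one-element list holding a dict with the
+ # resized flow ('flow'), RGB ('rgb', 'original_rgb'), ground-truth masks ('sem_seg',
+ # 'sem_seg_ori', 'instances') and padding applied to match SIZE_DIVISIBILITY.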
20
+ def __init__(self, data_dir, resolution, pair_list, val_seq, to_rgb=False, with_rgb=False, size_divisibility=None,
21
+ small_val=0, flow_clip=1., norm=True, read_big=True, eval_size=True, force1080p=False):
22
+ self.val_seq = val_seq
23
+ self.to_rgb = to_rgb
24
+ self.with_rgb = with_rgb
25
+ self.data_dir = data_dir
26
+ self.pair_list = pair_list
27
+ self.resolution = resolution
28
+
29
+ self.eval_size = eval_size
30
+
31
+ self.samples = []
32
+ self.samples_fid = {}
33
+ for v in self.val_seq:
34
+ seq_dir = Path(self.data_dir[0]) / v
35
+ frames_paths = sorted(seq_dir.glob('*.flo'))
36
+ self.samples_fid[str(seq_dir)] = {fp: i for i, fp in enumerate(frames_paths)}
37
+ self.samples.extend(frames_paths)
38
+ self.samples = [os.path.join(x.parent.name, x.name) for x in self.samples]
39
+ if small_val > 0:
40
+ _, self.samples = train_test_split(self.samples, test_size=small_val, random_state=42)
41
+ self.gaps = ['gap{}'.format(i) for i in pair_list]
42
+ self.neg_gaps = ['gap{}'.format(-i) for i in pair_list]
43
+ self.size_divisibility = size_divisibility
44
+ self.ignore_label = -1
45
+ self.transforms = DT.AugmentationList([
46
+ DT.Resize(self.resolution, interp=Image.BICUBIC),
47
+ ])
48
+ self.flow_clip=flow_clip
49
+ self.norm_flow=norm
50
+ self.read_big=read_big
51
+ self.force1080p_transforms=None
52
+ if force1080p:
53
+ self.force1080p_transforms = DT.AugmentationList([
54
+ DT.Resize((1088, 1920), interp=Image.BICUBIC),
55
+ ])
56
+
57
+
58
+ def __len__(self):
59
+ return len(self.samples)
60
+
61
+ def __getitem__(self, idx):
62
+ dataset_dicts = []
63
+
64
+ dataset_dict = {}
65
+ flow_dir = Path(self.data_dir[0]) / self.samples[idx]
66
+ fid = self.samples_fid[str(flow_dir.parent)][flow_dir]
67
+ flo = einops.rearrange(read_flow(str(flow_dir), self.resolution, self.to_rgb), 'c h w -> h w c')
68
+ dataset_dict["gap"] = 'gap1'
69
+
70
+ suffix = '.png' if 'CLEVR' in self.samples[idx] else '.jpg'
71
+ rgb_dir = (self.data_dir[1] / self.samples[idx]).with_suffix(suffix)
72
+ gt_dir = (self.data_dir[2] / self.samples[idx]).with_suffix('.png')
73
+
74
+ rgb = d2_utils.read_image(str(rgb_dir)).astype(np.float32)
75
+ original_rgb = torch.as_tensor(np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)).clip(0., 255.))).float()
76
+ if self.read_big:
77
+ rgb_big = d2_utils.read_image(str(rgb_dir).replace('480p', '1080p')).astype(np.float32)
78
+ rgb_big = (torch.as_tensor(np.ascontiguousarray(rgb_big))[:, :, :3]).permute(2, 0, 1).clamp(0., 255.)
79
+ if self.force1080p_transforms is not None:
80
+ rgb_big = F.interpolate(rgb_big[None], size=(1080, 1920), mode='bicubic').clamp(0., 255.)[0]
81
+
82
+ input = DT.AugInput(rgb)
83
+
84
+ # Apply the augmentation:
85
+ preprocessing_transforms = self.transforms(input) # type: DT.Transform
86
+ rgb = input.image
87
+ rgb = np.transpose(rgb, (2, 0, 1))
88
+ rgb = rgb.clip(0., 255.)
89
+ d2_utils.check_image_size(dataset_dict, flo)
90
+
91
+ if gt_dir.exists():
92
+ sem_seg_gt_ori = d2_utils.read_image(gt_dir)
93
+ sem_seg_gt = preprocessing_transforms.apply_segmentation(sem_seg_gt_ori)
94
+ if sem_seg_gt.ndim == 3:
95
+ sem_seg_gt = sem_seg_gt[:, :, 0]
96
+ sem_seg_gt_ori = sem_seg_gt_ori[:, :, 0]
97
+ if sem_seg_gt.max() == 255:
98
+ sem_seg_gt = (sem_seg_gt > 128).astype(int)
99
+ sem_seg_gt_ori = (sem_seg_gt_ori > 128).astype(int)
100
+ else:
101
+ sem_seg_gt = np.zeros((self.resolution[0], self.resolution[1]))
102
+ sem_seg_gt_ori = np.zeros((original_rgb.shape[-2], original_rgb.shape[-1]))
103
+
104
+ gwm_dir = (Path(str(self.data_dir[2]).replace('Annotations', 'gwm')) / self.samples[idx]).with_suffix(
105
+ '.png')
106
+ if gwm_dir.exists():
107
+ gwm_seg_gt = preprocessing_transforms.apply_segmentation(d2_utils.read_image(str(gwm_dir)))
108
+ gwm_seg_gt = np.array(gwm_seg_gt)
109
+ if gwm_seg_gt.ndim == 3:
110
+ gwm_seg_gt = gwm_seg_gt[:, :, 0]
111
+ if gwm_seg_gt.max() == 255:
112
+ gwm_seg_gt[gwm_seg_gt == 255] = 1
113
+ else:
114
+ gwm_seg_gt = None
115
+
116
+ if sem_seg_gt is None:
117
+ raise ValueError(
118
+ "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
119
+ dataset_dict["file_name"]
120
+ )
121
+ )
122
+
123
+ # Pad image and segmentation label here!
124
+ if self.to_rgb:
125
+ flo = torch.as_tensor(np.ascontiguousarray(flo.transpose(2, 0, 1))) / 2 + .5
126
+ flo = flo * 255
127
+ else:
128
+ flo = torch.as_tensor(np.ascontiguousarray(flo.transpose(2, 0, 1)))
129
+ if self.norm_flow:
130
+ flo = flo/(flo ** 2).sum(0).max().sqrt()
131
+ flo = flo.clip(-self.flow_clip, self.flow_clip)
132
+
133
+ rgb = torch.as_tensor(np.ascontiguousarray(rgb)).float()
134
+ if sem_seg_gt is not None:
135
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
136
+ sem_seg_gt_ori = torch.as_tensor(sem_seg_gt_ori.astype("long"))
137
+ if gwm_seg_gt is not None:
138
+ gwm_seg_gt = torch.as_tensor(gwm_seg_gt.astype("long"))
139
+
140
+ if self.size_divisibility > 0:
141
+ image_size = (flo.shape[-2], flo.shape[-1])
142
+ padding_size = [
143
+ 0,
144
+ int(self.size_divisibility * math.ceil(image_size[1] // self.size_divisibility)) - image_size[1],
145
+ 0,
146
+ int(self.size_divisibility * math.ceil(image_size[0] // self.size_divisibility)) - image_size[0],
147
+ ]
148
+ flo = F.pad(flo, padding_size, value=0).contiguous()
149
+ rgb = F.pad(rgb, padding_size, value=128).contiguous()
150
+ if sem_seg_gt is not None:
151
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
152
+ if gwm_seg_gt is not None:
153
+ gwm_seg_gt = F.pad(gwm_seg_gt, padding_size, value=self.ignore_label).contiguous()
154
+
155
+ image_shape = (flo.shape[-2], flo.shape[-1]) # h, w
156
+ if self.eval_size:
157
+ image_shape = (sem_seg_gt_ori.shape[-2], sem_seg_gt_ori.shape[-1])
158
+
159
+
160
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
161
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
162
+ # Therefore it's important to use torch.Tensor.
163
+ dataset_dict["flow"] = flo
164
+ dataset_dict["rgb"] = rgb
165
+
166
+
167
+ dataset_dict["original_rgb"] = F.interpolate(original_rgb[None], mode='bicubic', size=sem_seg_gt_ori.shape[-2:], align_corners=False).clip(0.,255.)[0]
168
+ if self.read_big:
169
+ dataset_dict["RGB_BIG"] = rgb_big
170
+
171
+ dataset_dict["category"] = str(gt_dir).split('/')[-2:]
172
+ dataset_dict['frame_id'] = fid
173
+
174
+ if sem_seg_gt is not None:
175
+ dataset_dict["sem_seg"] = sem_seg_gt.long()
176
+ dataset_dict["sem_seg_ori"] = sem_seg_gt_ori.long()
177
+
178
+ if gwm_seg_gt is not None:
179
+ dataset_dict["gwm_seg"] = gwm_seg_gt.long()
180
+
181
+ if "annotations" in dataset_dict:
182
+ raise ValueError("Semantic segmentation dataset should not have 'annotations'.")
183
+
184
+ # Prepare per-category binary masks
185
+ if sem_seg_gt is not None:
186
+ sem_seg_gt = sem_seg_gt.numpy()
187
+ instances = Instances(image_shape)
188
+ classes = np.unique(sem_seg_gt)
189
+ # remove ignored region
190
+ classes = classes[classes != self.ignore_label]
191
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
192
+
193
+ masks = []
194
+ for class_id in classes:
195
+ masks.append(sem_seg_gt == class_id)
196
+
197
+ if len(masks) == 0:
198
+ # Some image does not have annotation (all ignored)
199
+ instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
200
+ else:
201
+ masks = BitMasks(
202
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
203
+ )
204
+ instances.gt_masks = masks.tensor
205
+
206
+ dataset_dict["instances"] = instances
207
+ dataset_dicts.append(dataset_dict)
208
+
209
+ return dataset_dicts
datasets/flow_pair_detectron.py ADDED
@@ -0,0 +1,275 @@
1
+ import math
2
+ from pathlib import Path
3
+ import random
4
+
5
+ import detectron2.data.transforms as DT
6
+ import einops
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as T
11
+ from PIL import Image
12
+ from detectron2.data import detection_utils as d2_utils
13
+ from detectron2.structures import Instances, BitMasks
14
+ from torch.utils.data import Dataset
15
+
16
+ from utils.data import read_flow, read_flo
17
+
18
+
19
+ def load_flow_tensor(path, resize=None, normalize=True, align_corners=True):
20
+ """
21
+ Load flow, scale the pixel values according to the resized scale.
22
+ If normalize is true, return rescaled in normalized pixel coordinates
23
+ where pixel coordinates are in range [-1, 1].
24
+ NOTE: RAFT USES ALIGN_CORNERS=TRUE SO WE NEED TO ACCOUNT FOR THIS
25
+ Returns (2, H, W) float32
26
+ """
27
+ flow = read_flo(path).astype(np.float32)
28
+ H, W, _ = flow.shape
29
+ h, w = (H, W) if resize is None else resize
30
+ u, v = flow[..., 0], flow[..., 1]
31
+ if normalize:
32
+ if align_corners:
33
+ u = 2.0 * u / (W - 1)
34
+ v = 2.0 * v / (H - 1)
35
+ else:
36
+ u = 2.0 * u / W
37
+ v = 2.0 * v / H
38
+ else:
39
+ h, w = resize
40
+ u = w * u / W
41
+ v = h * v / H
42
+
43
+ if h != H or w !=W:
44
+ u = Image.fromarray(u).resize((w, h), Image.ANTIALIAS)
45
+ v = Image.fromarray(v).resize((w, h), Image.ANTIALIAS)
46
+ u, v = np.array(u), np.array(v)
47
+ return torch.from_numpy(np.stack([u, v], axis=0))
48
+
49
+
50
+ class FlowPairDetectron(Dataset):
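+ # Training dataset: each item picks a random flow gap, video and frame pair and returns a
+ # one-element list holding a dict with two flow fields ('flow', 'flow_2'), the matching RGB
+ # frame (optionally photometrically augmented as 'rgb_aug') and any available masks.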
51
+ def __init__(self, data_dir, resolution, to_rgb=False, size_divisibility=None, enable_photo_aug=False, flow_clip=1., norm=True, read_big=True, force1080p=False, flow_res=None):
52
+ self.eval = eval
53
+ self.to_rgb = to_rgb
54
+ self.data_dir = data_dir
55
+ self.flow_dir = {k: [e for e in v if e.shape[0] > 0] for k, v in data_dir[0].items()}
56
+ self.flow_dir = {k: v for k, v in self.flow_dir.items() if len(v) > 0}
57
+ self.resolution = resolution
58
+ self.size_divisibility = size_divisibility
59
+ self.ignore_label = -1
60
+ self.transforms = DT.AugmentationList([
61
+ DT.Resize(self.resolution, interp=Image.BICUBIC),
62
+ ])
63
+ self.photometric_aug = T.Compose([
64
+ T.RandomApply(torch.nn.ModuleList([T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)]),
65
+ p=0.8),
66
+ T.RandomGrayscale(p=0.2),
67
+ ]) if enable_photo_aug else None
68
+ self.flow_clip=flow_clip
69
+ self.norm_flow=norm
70
+ self.read_big = read_big
71
+ self.force1080p_transforms = None
72
+ if force1080p:
73
+ self.force1080p_transforms = DT.AugmentationList([
74
+ DT.Resize((1088, 1920), interp=Image.BICUBIC),
75
+ ])
76
+ self.big_flow_resolution = flow_res
77
+
78
+ def __len__(self):
79
+ return sum([cat.shape[0] for cat in next(iter(self.flow_dir.values()))]) if len(
80
+ self.flow_dir.values()) > 0 else 0
81
+
82
+ def __getitem__(self, idx):
83
+
84
+ dataset_dicts = []
85
+
86
+ random_gap = random.choice(list(self.flow_dir.keys()))
87
+ flowgaps = self.flow_dir[random_gap]
88
+ vid = random.choice(flowgaps)
89
+ flos = random.choice(vid)
90
+ dataset_dict = {}
91
+
92
+ fname = Path(flos[0]).stem
93
+ dname = Path(flos[0]).parent.name
94
+ suffix = '.png' if 'CLEVR' in fname else '.jpg'
95
+ rgb_dir = (self.data_dir[1] / dname / fname).with_suffix(suffix)
96
+ gt_dir = (self.data_dir[2] / dname / fname).with_suffix('.png')
97
+
98
+ flo0 = einops.rearrange(read_flow(str(flos[0]), self.resolution, self.to_rgb), 'c h w -> h w c')
99
+ flo1 = einops.rearrange(read_flow(str(flos[1]), self.resolution, self.to_rgb), 'c h w -> h w c')
100
+ if self.big_flow_resolution is not None:
101
+ flo0_big = einops.rearrange(read_flow(str(flos[0]), self.big_flow_resolution, self.to_rgb), 'c h w -> h w c')
102
+ flo1_big = einops.rearrange(read_flow(str(flos[1]), self.big_flow_resolution, self.to_rgb), 'c h w -> h w c')
103
+ rgb = d2_utils.read_image(rgb_dir).astype(np.float32)
104
+ original_rgb = torch.as_tensor(np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)).clip(0., 255.))).float()
105
+ if self.read_big:
106
+ rgb_big = d2_utils.read_image(str(rgb_dir).replace('480p', '1080p')).astype(np.float32)
107
+ rgb_big = (torch.as_tensor(np.ascontiguousarray(rgb_big))[:, :, :3]).permute(2, 0, 1).clamp(0., 255.)
108
+ if self.force1080p_transforms is not None:
109
+ rgb_big = F.interpolate(rgb_big[None], size=(1080, 1920), mode='bicubic').clamp(0., 255.)[0]
110
+
111
+ # print('not here', rgb.min(), rgb.max())
112
+ input = DT.AugInput(rgb)
113
+
114
+ # Apply the augmentation:
115
+ preprocessing_transforms = self.transforms(input) # type: DT.Transform
116
+ rgb = input.image
117
+ if self.photometric_aug:
118
+ rgb_aug = Image.fromarray(rgb.astype(np.uint8))
119
+ rgb_aug = self.photometric_aug(rgb_aug)
120
+ rgb_aug = d2_utils.convert_PIL_to_numpy(rgb_aug, 'RGB')
121
+ rgb_aug = np.transpose(rgb_aug, (2, 0, 1)).astype(np.float32)
122
+ rgb = np.transpose(rgb, (2, 0, 1))
123
+ rgb = rgb.clip(0., 255.)
124
+ # print('here', rgb.min(), rgb.max())
125
+ d2_utils.check_image_size(dataset_dict, flo0)
126
+ if gt_dir.exists():
127
+ sem_seg_gt = d2_utils.read_image(str(gt_dir))
128
+ sem_seg_gt = preprocessing_transforms.apply_segmentation(sem_seg_gt)
129
+ # sem_seg_gt = cv2.resize(sem_seg_gt, (self.resolution[1], self.resolution[0]), interpolation=cv2.INTER_NEAREST)
130
+ if sem_seg_gt.ndim == 3:
131
+ sem_seg_gt = sem_seg_gt[:, :, 0]
132
+ if sem_seg_gt.max() == 255:
133
+ sem_seg_gt = (sem_seg_gt > 128).astype(int)
134
+ else:
135
+ sem_seg_gt = np.zeros((self.resolution[0], self.resolution[1]))
136
+
137
+
138
+ gwm_dir = (Path(str(self.data_dir[2]).replace('Annotations', 'gwm')) / dname / fname).with_suffix('.png')
139
+ if gwm_dir.exists():
140
+ gwm_seg_gt = d2_utils.read_image(str(gwm_dir))
141
+ gwm_seg_gt = preprocessing_transforms.apply_segmentation(gwm_seg_gt)
142
+ gwm_seg_gt = np.array(gwm_seg_gt)
143
+ # gwm_seg_gt = cv2.resize(gwm_seg_gt, (self.resolution[1], self.resolution[0]), interpolation=cv2.INTER_NEAREST)
144
+ if gwm_seg_gt.ndim == 3:
145
+ gwm_seg_gt = gwm_seg_gt[:, :, 0]
146
+ if gwm_seg_gt.max() == 255:
147
+ gwm_seg_gt[gwm_seg_gt == 255] = 1
148
+ else:
149
+ gwm_seg_gt = None
150
+
151
+ if sem_seg_gt is None:
152
+ raise ValueError(
153
+ "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
154
+ dataset_dict["file_name"]
155
+ )
156
+ )
157
+
158
+ # Pad image and segmentation label here!
159
+ if self.to_rgb:
160
+ flo0 = torch.as_tensor(np.ascontiguousarray(flo0.transpose(2, 0, 1))) / 2 + .5
161
+ flo0 = flo0 * 255
162
+ flo1 = torch.as_tensor(np.ascontiguousarray(flo1.transpose(2, 0, 1))) / 2 + .5
163
+ flo1 = flo1 * 255
164
+ if self.big_flow_resolution is not None:
165
+ flo0_big = torch.as_tensor(np.ascontiguousarray(flo0_big.transpose(2, 0, 1))) / 2 + .5
166
+ flo0_big = flo0_big * 255
167
+ flo1_big = torch.as_tensor(np.ascontiguousarray(flo1_big.transpose(2, 0, 1))) / 2 + .5
168
+ flo1_big = flo1_big * 255
169
+ else:
170
+ flo0 = torch.as_tensor(np.ascontiguousarray(flo0.transpose(2, 0, 1)))
171
+ flo1 = torch.as_tensor(np.ascontiguousarray(flo1.transpose(2, 0, 1)))
172
+
173
+ if self.norm_flow:
174
+ flo0 = flo0 / (flo0 ** 2).sum(0).max().sqrt()
175
+ flo1 = flo1 / (flo1 ** 2).sum(0).max().sqrt()
176
+
177
+ flo0 = flo0.clip(-self.flow_clip, self.flow_clip)
178
+ flo1 = flo1.clip(-self.flow_clip, self.flow_clip)
179
+
180
+ if self.big_flow_resolution is not None:
181
+ flo0_big = torch.as_tensor(np.ascontiguousarray(flo0_big.transpose(2, 0, 1)))
182
+ flo1_big = torch.as_tensor(np.ascontiguousarray(flo1_big.transpose(2, 0, 1)))
183
+ if self.norm_flow:
184
+ flo0_big = flo0_big / (flo0_big ** 2).sum(0).max().sqrt()
185
+ flo1_big = flo1_big / (flo1_big ** 2).sum(0).max().sqrt()
186
+ flo0_big = flo0_big.clip(-self.flow_clip, self.flow_clip)
187
+ flo1_big = flo1_big.clip(-self.flow_clip, self.flow_clip)
188
+
189
+ rgb = torch.as_tensor(np.ascontiguousarray(rgb))
190
+ if self.photometric_aug:
191
+ rgb_aug = torch.as_tensor(np.ascontiguousarray(rgb_aug))
192
+
193
+ if sem_seg_gt is not None:
194
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
195
+ if gwm_seg_gt is not None:
196
+ gwm_seg_gt = torch.as_tensor(gwm_seg_gt.astype("long"))
197
+
198
+ if self.size_divisibility > 0:
199
+ image_size = (flo0.shape[-2], flo0.shape[-1])
200
+ padding_size = [
201
+ 0,
202
+ int(self.size_divisibility * math.ceil(image_size[1] / self.size_divisibility)) - image_size[1],
203
+ 0,
204
+ int(self.size_divisibility * math.ceil(image_size[0] / self.size_divisibility)) - image_size[0],
205
+ ]
206
+ flo0 = F.pad(flo0, padding_size, value=0).contiguous()
207
+ flo1 = F.pad(flo1, padding_size, value=0).contiguous()
208
+ rgb = F.pad(rgb, padding_size, value=128).contiguous()
209
+ if self.photometric_aug:
210
+ rgb_aug = F.pad(rgb_aug, padding_size, value=128).contiguous()
211
+ if sem_seg_gt is not None:
212
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
213
+ if gwm_seg_gt is not None:
214
+ gwm_seg_gt = F.pad(gwm_seg_gt, padding_size, value=self.ignore_label).contiguous()
215
+
216
+ image_shape = (rgb.shape[-2], rgb.shape[-1]) # h, w
217
+
218
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
219
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
220
+ # Therefore it's important to use torch.Tensor.
221
+ dataset_dict["flow"] = flo0
222
+ dataset_dict["flow_2"] = flo1
223
+
224
+ # dataset_dict["flow_fwd"] = flo_norm_fwd
225
+ # dataset_dict["flow_bwd"] = flo_norm_bwd
226
+ # dataset_dict["flow_rgb"] = rgb_flo0
227
+ # dataset_dict["flow_gap"] = gap
228
+
229
+ dataset_dict["rgb"] = rgb
230
+ dataset_dict["original_rgb"] = original_rgb
231
+ if self.read_big:
232
+ dataset_dict["RGB_BIG"] = rgb_big
233
+ if self.photometric_aug:
234
+ dataset_dict["rgb_aug"] = rgb_aug
235
+
236
+ if self.big_flow_resolution is not None:
237
+ dataset_dict["flow_big"] = flo0_big
238
+ dataset_dict["flow_big_2"] = flo1_big
239
+
240
+
241
+ if sem_seg_gt is not None:
242
+ dataset_dict["sem_seg"] = sem_seg_gt.long()
243
+
244
+ if gwm_seg_gt is not None:
245
+ dataset_dict["gwm_seg"] = gwm_seg_gt.long()
246
+
247
+ if "annotations" in dataset_dict:
248
+ raise ValueError("Semantic segmentation dataset should not have 'annotations'.")
249
+
250
+ # Prepare per-category binary masks
251
+ if sem_seg_gt is not None:
252
+ sem_seg_gt = sem_seg_gt.numpy()
253
+ instances = Instances(image_shape)
254
+ classes = np.unique(sem_seg_gt)
255
+ # remove ignored region
256
+ classes = classes[classes != self.ignore_label]
257
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
258
+
259
+ masks = []
260
+ for class_id in classes:
261
+ masks.append(sem_seg_gt == class_id)
262
+
263
+ if len(masks) == 0:
264
+ # Some image does not have annotation (all ignored)
265
+ instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
266
+ else:
267
+ masks = BitMasks(
268
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
269
+ )
270
+ instances.gt_masks = masks.tensor
271
+
272
+ dataset_dict["instances"] = instances
273
+ dataset_dicts.append(dataset_dict)
274
+
275
+ return dataset_dicts
determinism.py ADDED
@@ -0,0 +1,24 @@
1
+ import os
2
+ lvl = int(os.environ.get('TRY_DETERMISM_LVL', '0'))
3
+ if lvl > 0:
4
+ print(f'Attempting to enable deterministic cuDNN and cuBLAS operations to lvl {lvl}')
5
+ if lvl >= 2:
6
+ # turn on deterministic operations
7
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" #Need to set before torch gets loaded
8
+ import torch
9
+ # Since using unstable torch version, it looks like 1.12.0.devXXXXXXX
10
+ if torch.version.__version__ >= '1.12.0':
11
+ torch.use_deterministic_algorithms(True, warn_only=(lvl < 3))
12
+ elif lvl >= 3:
13
+ torch.use_deterministic_algorithms(True) # This will throw errors if implementations are missing
14
+ else:
15
+ print(f"Torch version is only {torch.version.__version__}, which will cause errors on lvl {lvl}")
16
+ if lvl >= 1:
17
+ import torch
18
+ if torch.cuda.is_available():
19
+ torch.backends.cudnn.benchmark = False
20
+
21
+
22
+ def i_do_nothing_but_dont_remove_me_otherwise_things_break():
23
+ """This exists to prevent formatters from treating this file as dead code"""
24
+ pass
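The module above reads `TRY_DETERMISM_LVL` at import time, so it only takes effect when it is imported before `torch`; this is why `main.py` imports it first and then calls the keep-alive function. A minimal usage sketch, with an assumed level of 2 (the wrapper snippet itself is illustrative, not part of this commit):

```python
# Hypothetical launcher: set the env var, then import determinism BEFORE torch
# so CUBLAS_WORKSPACE_CONFIG is exported early enough to matter.
import os
os.environ['TRY_DETERMISM_LVL'] = '2'   # 0 disables; >=2 also requests deterministic torch algorithms

import determinism  # noqa: F401  (its module-level setup runs on import)
determinism.i_do_nothing_but_dont_remove_me_otherwise_things_break()

import torch  # imported only after the environment is prepared
```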
dist.py ADDED
@@ -0,0 +1,34 @@
1
+ import functools
2
+
3
+ import torch
4
+ import torch.distributions
5
+
6
+ import utils
7
+
8
+ LOGGER = utils.log.getLogger(__name__)
9
+
10
+ __defined_kl = False
11
+
12
+ EPS = 1e-5
13
+
14
+
15
+ def clamp_probs(probs):
16
+ probs = probs.clamp(EPS, 1. - EPS) # Will no longer sum to 1
17
+ return probs / probs.sum(-1, keepdim=True) # to simplex
18
+
19
+
20
+ def grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
21
+ hr = torch.arange(h + 2 * pad, device=device) - pad
22
+ wr = torch.arange(w + 2 * pad, device=device) - pad
23
+ if norm:
24
+ hr = hr / (h + 2 * pad - 1)
25
+ wr = wr / (w + 2 * pad - 1)
26
+ ig, jg = torch.meshgrid(hr, wr)
27
+ g = torch.stack([jg, ig]).to(dtype)[None]
28
+ return g
29
+
30
+
31
+ @functools.lru_cache(2)
32
+ def cached_grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
33
+ return grid(h, w, pad, device, dtype, norm)
34
+
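For reference, `grid()` returns a `1 x 2 x H x W` tensor whose first channel holds x (column) coordinates and whose second channel holds y (row) coordinates, optionally padded and normalised. A quick illustration of the default behaviour (no padding, no normalisation):

```python
# Sketch: inspect the coordinate grid produced by dist.grid().
from dist import grid

g = grid(2, 3)
print(g.shape)   # torch.Size([1, 2, 2, 3])
print(g[0, 0])   # x: tensor([[0., 1., 2.], [0., 1., 2.]])
print(g[0, 1])   # y: tensor([[0., 0., 0.], [1., 1., 1.]])
```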
eval_utils.py ADDED
@@ -0,0 +1,282 @@
1
+ import functools
2
+ import random
3
+ from collections import defaultdict
4
+
5
+ import einops
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ from sklearn.cluster import SpectralClustering
11
+ from tqdm import tqdm
12
+
13
+ import flow_reconstruction
14
+ from utils import visualisation, log, grid
15
+ from utils.vit_extractor import ViTExtractor
16
+
17
+ label_colors = visualisation.create_label_colormap()
18
+ logger = log.getLogger('gwm')
19
+
20
+
21
+ def __default_font(fontsize):
22
+ try:
23
+ FNT = ImageFont.truetype("dejavu/DejaVuSansMono.ttf", fontsize)
24
+ except OSError:
25
+ FNT = ImageFont.truetype("dejavu/DejaVuSans.ttf", fontsize)
26
+ return FNT
27
+
28
+
29
+ @functools.lru_cache(None) # cache the result
30
+ def autosized_default_font(size_limit: float) -> ImageFont.ImageFont:
31
+ fontsize = 1 # starting font size
32
+ font = __default_font(fontsize)
33
+ while font.getsize('test123')[1] < size_limit:
34
+ fontsize += 1
35
+ font = __default_font(fontsize)
36
+ fontsize -= 1
37
+ font = __default_font(fontsize)
38
+ return font
39
+
40
+
41
+ def iou(masks, gt, thres=0.5):
42
+ masks = (masks > thres).float()
43
+ intersect = torch.tensordot(masks, gt, dims=([-2, -1], [0, 1]))
44
+ union = masks.sum(dim=[-2, -1]) + gt.sum(dim=[-2, -1]) - intersect
45
+ return intersect / union.clip(min=1e-12)
46
+
47
+
48
+ def get_unsup_image_viz(model, cfg, sample, criterion):
49
+ if model.training:
50
+ model.eval()
51
+ preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
52
+ model.train()
53
+ else:
54
+ preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
55
+ return get_image_vis(model, cfg, sample, preds, criterion)
56
+
57
+ def get_vis_header(header_size, image_size, header_texts, header_height=20):
58
+ W, H = (image_size, header_height)
59
+ header_labels = []
60
+ font = autosized_default_font(0.8 * H)
61
+
62
+ for text in header_texts:
63
+ im = Image.new("RGB", (W, H), "white")
64
+ draw = ImageDraw.Draw(im)
65
+ w, h = draw.textsize(text, font=font)
66
+ draw.text(((W - w) / 2, (H - h) / 2), text, fill="black", font=font)
67
+ header_labels.append(torch.from_numpy(np.array(im)))
68
+ header_labels = torch.cat(header_labels, dim=1)
69
+ ret = (torch.ones((header_height, header_size, 3)) * 255)
70
+ ret[:, :header_labels.size(1)] = header_labels
71
+
72
+ return ret.permute(2, 0, 1).clip(0, 255).to(torch.uint8)
73
+
74
+ def get_image_vis(model, cfg, sample, preds, criterion):
75
+ masks_pred = torch.stack([x['sem_seg'] for x in preds], 0)
76
+
77
+ with torch.no_grad():
78
+ flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20)
79
+
80
+ masks_softmaxed = torch.softmax(masks_pred, dim=1)
81
+ masks_pred = masks_softmaxed
82
+ rec_flows = criterion.flow_reconstruction(sample, criterion.process_flow(sample, flow), masks_softmaxed)
83
+ rec_headers = ['rec_flow']
84
+ if len(rec_flows) > 1:
85
+ rec_headers.append('rec_bwd_flow')
86
+
87
+ rgb = torch.stack([x['rgb'] for x in sample])
88
+ flow = criterion.viz_flow(criterion.process_flow(sample, flow).cpu()) * 255
89
+ rec_flows = [
90
+ (criterion.viz_flow(rec_flow_.detach().cpu()) * 255).clip(0, 255).to(torch.uint8) for rec_flow_ in rec_flows
91
+ ]
92
+
93
+
94
+ gt_labels = torch.stack([x['sem_seg'] for x in sample])
95
+ gt = F.one_hot(gt_labels, gt_labels.max().item() + 1).permute(0, 3, 1, 2)
96
+ target_K = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
97
+ masks = F.one_hot(masks_pred.argmax(1).cpu(), target_K).permute(0, 3, 1, 2)
98
+ masks_each = torch.stack([masks_softmaxed, masks_softmaxed, masks_softmaxed], 2) * 255
99
+ masks_each = einops.rearrange(F.pad(masks_each.cpu(), pad=[0, 1], value=255), 'b n c h w -> b c h (n w)')
100
+
101
+ gt_seg = torch.einsum('b k h w, k c -> b c h w', gt, label_colors[:gt_labels.max().item() + 1])
102
+ pred_seg = torch.einsum('b k h w, k c -> b c h w', masks, label_colors[:target_K])
103
+ if all('gwm_seg' in d for d in sample):
104
+ gwm_labels = torch.stack([x['gwm_seg'] for x in sample])
105
+ mg = F.one_hot(gwm_labels, gwm_labels.max().item() + 1).permute(0, 3, 1, 2)
106
+ gwm_seg = torch.einsum('b k h w, k c -> b c h w', mg, label_colors[:gwm_labels.max().item() + 1])
107
+ image_viz = torch.cat(
108
+ [rgb, flow, F.pad(gt_seg.cpu(), pad=[0, 1], value=255), F.pad(gwm_seg, pad=[0, 1], value=255),
109
+ pred_seg.cpu(), *rec_flows], -1)
110
+ header_text = ['rgb', 'gt_flow', 'gt_seg', 'GWM', 'pred_seg', *rec_headers]
111
+ else:
112
+ image_viz = torch.cat([rgb, flow, gt_seg.cpu(), pred_seg.cpu(), *rec_flows], -1)
113
+ header_text = ['rgb', 'gt_flow', 'gt_seg', 'pred_seg', *rec_headers]
114
+
115
+ image_viz = torch.cat([image_viz, masks_each], -1)
116
+ header_text.extend(['slot'] * masks_softmaxed.shape[1])
117
+ if 'flow_edges' in sample[0]:
118
+ flow_edges = torch.stack([x['flow_edges'].to(image_viz.device) for x in sample])
119
+ if len(flow_edges.shape) >= 4:
120
+ flow_edges = flow_edges.sum(1, keepdim=len(flow_edges.shape) == 4)
121
+ flow_edges = flow_edges.expand(-1, 3, -1, -1)
122
+ flow_edges = flow_edges * 255
123
+ image_viz = torch.cat([image_viz, flow_edges], -1)
124
+ header_text.append('flow_edges')
125
+ image_viz = einops.rearrange(image_viz[:8], 'b c h w -> c (b h) w').detach().clip(0, 255).to(torch.uint8)
126
+
127
+ return image_viz, header_text
128
+
129
+
130
+ def get_frame_vis(model, cfg, sample, preds):
131
+ masks_pred = torch.stack([x['sem_seg'] for x in preds], 0)
132
+ flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20)
133
+
134
+ masks_softmaxed = torch.softmax(masks_pred, dim=1)
135
+ if cfg.GWM.SIMPLE_REC:
136
+ mask_denom = einops.reduce(masks_softmaxed, 'b k h w -> b k 1', 'sum') + 1e-7
137
+ means = torch.einsum('brhw, bchw -> brc', masks_softmaxed, flow) / mask_denom
138
+ rec_flow = torch.einsum('bkhw, bkc-> bchw', masks_softmaxed, means)
139
+ elif cfg.GWM.HOMOGRAPHY:
140
+ rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow)
141
+ else:
142
+ grid_x, grid_y = grid.get_meshgrid(cfg.GWM.RESOLUTION, model.device)
143
+ rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow, grid_x, grid_y)
144
+
145
+ rgb = torch.stack([x['rgb'] for x in sample])
146
+ flow = torch.stack([visualisation.flow2rgb_torch(x) for x in flow.cpu()]) * 255
147
+ rec_flow = torch.stack([visualisation.flow2rgb_torch(x) for x in rec_flow.detach().cpu()]) * 255
148
+
149
+ gt_labels = torch.stack([x['sem_seg'] for x in sample])
150
+ gt = F.one_hot(gt_labels, gt_labels.max().item() + 1).permute(0, 3, 1, 2)
151
+
152
+ masks = F.one_hot(masks_pred.argmax(1).cpu(), cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES).permute(0, 3, 1, 2)
153
+
154
+ gt_seg = torch.einsum('b k h w, k c -> b c h w', gt, label_colors[:gt_labels.max().item() + 1])
155
+ pred_seg = torch.einsum('b k h w, k c -> b c h w', masks, label_colors[:cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES])
156
+ frame_vis = torch.cat([rgb, flow, gt_seg.cpu(), pred_seg.cpu(), rec_flow.clip(0, 255).to(torch.uint8)], -1)
157
+ frame_vis = einops.rearrange(frame_vis, 'b c h w -> b c h w').detach().clip(0, 255).to(torch.uint8)
158
+ return frame_vis
159
+
160
+
161
+ def is_2comp_dataset(dataset):
162
+ if '+' in dataset:
163
+ d = dataset.split('+')[0].strip()
164
+ else:
165
+ d = dataset.strip()
166
+ logger.info_once(f"Is 2comp dataset? {d}")
167
+ for s in ['DAVIS', 'FBMS', 'STv2']:
168
+ if s in d:
169
+ return True
170
+ return d in ['DAVIS',
171
+ 'FBMS',
172
+ 'STv2']
173
+
174
+ def eval_unsupmf(cfg, val_loader, model, criterion, writer=None, writer_iteration=0, use_wandb=False):
175
+ logger.info(f'Running Evaluation: {cfg.LOG_ID} {"Simple" if cfg.GWM.SIMPLE_REC else "Gradient"}:')
176
+ logger.info(f'Model mode: {"train" if model.training else "eval"}, wandb: {use_wandb}')
177
+ logger.info(f'Dataset: {cfg.GWM.DATASET} # components: {cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES}')
178
+
179
+ merger = None
180
+ if cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES > 2:
181
+ merger = MaskMerger(cfg, model)
182
+
183
+ print_idxs = random.sample(range(len(val_loader)), k=10)
184
+
185
+ images_viz = []
186
+ ious_davis_eval = defaultdict(list)
187
+ ious = defaultdict(list)
188
+
189
+ for idx, sample in enumerate(tqdm(val_loader)):
190
+ t = 1
191
+ sample = [e for s in sample for e in s]
192
+ category = [s['category'] for s in sample]
193
+ preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
194
+ masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)
195
+
196
+ masks_softmaxed = torch.softmax(masks_raw, dim=1)
197
+ masks_dict = merger(sample, masks_softmaxed)
198
+
199
+ if writer and idx in print_idxs:
200
+ flow = torch.stack([x['flow'] for x in sample]).to(model.device)
201
+ img_viz, header_text = get_image_vis(model, cfg, sample, preds, criterion)
202
+ images_viz.append(img_viz)
203
+
204
+ masks = masks_dict['cos']
205
+ gt_seg = torch.stack([x['sem_seg_ori'] for x in sample]).cpu()
206
+ HW = gt_seg.shape[-2:]
207
+ if HW != masks.shape[-2:]:
208
+ logger.info_once(f"Upsampling predicted masks to {HW} for evaluation")
209
+ masks_softmaxed_sel = F.interpolate(masks.detach().cpu(), size=HW, mode='bilinear', align_corners=False)
210
+ else:
211
+ masks_softmaxed_sel = masks.detach().cpu()
212
+ masks_ = einops.rearrange(masks_softmaxed_sel, '(b t) s h w -> b t s 1 h w', t=t).detach()
213
+ gt_seg = einops.rearrange(gt_seg, 'b h w -> b 1 h w').float()
214
+ for i in range(masks_.size(0)):
215
+ masks_k = F.interpolate(masks_[i], size=(1, gt_seg.shape[-2], gt_seg.shape[-1])) # t s 1 h w
216
+ mask_iou = iou(masks_k[:, :, 0], gt_seg[i, 0], thres=0.5) # t s
217
+ iou_max, slot_max = mask_iou.max(dim=1)
218
+
219
+ ious[category[i][0]].append(iou_max)
220
+ frame_id = category[i][1]
221
+ ious_davis_eval[category[i][0]].append((frame_id.strip().replace('.png', ''), iou_max))
222
+
223
+ frameious = sum(ious.values(), [])
224
+ frame_mean_iou = torch.cat(frameious).sum().item() * 100 / len(frameious)
225
+ if 'DAVIS' in cfg.GWM.DATASET.split('+')[0]:
226
+ logger.info_once("Using DAVIS evaluator methods for evaluating IoU -- mean of mean of sequences without first frame")
227
+ seq_scores = dict()
228
+ for c in ious_davis_eval:
229
+ seq_scores[c] = np.nanmean([v.item() for n, v in ious_davis_eval[c] if int(n) > 1])
230
+
231
+ frame_mean_iou = np.nanmean(list(seq_scores.values())) * 100
232
+
233
+ if writer:
234
+ header = get_vis_header(images_viz[0].size(2), flow.size(3), header_text)
235
+ images_viz = torch.cat(images_viz, dim=1)
236
+ images_viz = torch.cat([header, images_viz], dim=1)
237
+ writer.add_image('val/images', images_viz, writer_iteration) # C H W
238
+ writer.add_scalar('eval/mIoU', frame_mean_iou, writer_iteration)
239
+
240
+ logger.info(f"mIoU: {frame_mean_iou:.3f} \n")
241
+ return frame_mean_iou
242
+
243
+
244
+ class MaskMerger:
245
+ def __init__(self, cfg, model, merger_model="dino_vits8"):
246
+ self.extractor = ViTExtractor(model_type=merger_model, device=model.device)
247
+ self.out_dim = 384
248
+
249
+ self.mu = torch.tensor(self.extractor.mean).to(model.device).view(1, -1, 1, 1)
250
+ self.sigma = torch.tensor(self.extractor.std).to(model.device).view(1, -1, 1, 1)
251
+ self.start_idx = 0
252
+
253
+ def get_feats(self, batch):
254
+ with torch.no_grad():
255
+ feat = self.extractor.extract_descriptors(batch, facet='key', layer=11, bin=False)
256
+ feat = feat.reshape(feat.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
257
+ return F.interpolate(feat, batch.shape[-2:], mode='bilinear')
258
+
259
+ def spectral(self, A):
260
+ clustering = SpectralClustering(n_clusters=2,
261
+ affinity='precomputed',
262
+ random_state=0).fit(A.detach().cpu().numpy())
263
+ return np.arange(A.shape[-1])[clustering.labels_ == 0], np.arange(A.shape[-1])[clustering.labels_ == 1]
264
+
265
+ def cos_merge(self, basis, masks):
266
+ basis = basis / torch.linalg.vector_norm(basis, dim=-1, keepdim=True).clamp(min=1e-6)
267
+ A = torch.einsum('brc, blc -> brl', basis, basis)[0].clamp(min=1e-6)
268
+ inda, indb = self.spectral(A)
269
+ return torch.stack([masks[:, inda].sum(1),
270
+ masks[:, indb].sum(1)], 1)
271
+
272
+ def __call__(self, sample, masks_softmaxed):
273
+ with torch.no_grad():
274
+ masks_softmaxed = masks_softmaxed[:, self.start_idx:]
275
+ batch = torch.stack([x['rgb'].to(masks_softmaxed.device) for x in sample], 0) / 255.0
276
+ features = self.get_feats((batch - self.mu) / self.sigma)
277
+ basis = torch.einsum('brhw, bchw -> brc', masks_softmaxed, features)
278
+ basis /= einops.reduce(masks_softmaxed, 'b r h w -> b r 1', 'sum').clamp_min(1e-12)
279
+
280
+ return {
281
+ 'cos': self.cos_merge(basis, masks_softmaxed),
282
+ }
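The `iou()` helper above thresholds each predicted slot at `thres` and computes intersection-over-union against a binary ground-truth mask with a `tensordot`, returning one score per slot. A tiny sanity check with synthetic masks (values chosen for illustration only):

```python
# Slot 0 matches the ground truth, slot 1 is its complement.
import torch
from eval_utils import iou

gt = torch.zeros(4, 4)
gt[:2] = 1.0                                      # top half is foreground
masks = torch.stack([gt * 0.9, (1 - gt) * 0.9])   # soft predictions for two slots
print(iou(masks, gt, thres=0.5))                  # -> tensor([1., 0.])
```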
flow_reconstruction.py ADDED
@@ -0,0 +1,54 @@
1
+ import sys
2
+
3
+ import torch
4
+
5
+ from dist import LOGGER
6
+
7
+
8
+ def lstq(A, F_u, F_v, lamda=0.01):
9
+ # cols = A.shape[2]
10
+ # assert all(cols == torch.linalg.matrix_rank(A)) # something better?
11
+ try:
12
+ Q, R = torch.linalg.qr(A)
13
+ theta_x = torch.bmm(torch.bmm(torch.linalg.inv(R), Q.transpose(1, 2)), F_u)
14
+ theta_y = torch.bmm(torch.bmm(torch.linalg.inv(R), Q.transpose(1, 2)), F_v)
15
+ except Exception:
16
+ LOGGER.exception("Least Squares failed")
17
+ sys.exit(-1)
18
+ return theta_x, theta_y
19
+
20
+ def get_quad_flow(masks_softmaxed, flow, grid_x, grid_y):
21
+ rec_flow = 0
22
+ for k in range(masks_softmaxed.size(1)):
23
+ mask = masks_softmaxed[:, k].unsqueeze(1)
24
+ _F = flow * mask
25
+ M = mask.flatten(1)
26
+ bs = _F.shape[0]
27
+ x = grid_x.unsqueeze(0).flatten(1)
28
+ y = grid_y.unsqueeze(0).flatten(1)
29
+
30
+ F_u = _F[:, 0].flatten(1).unsqueeze(2) # B x L x 1
31
+ F_v = _F[:, 1].flatten(1).unsqueeze(2) # B x L x 1
32
+ A = torch.stack([x * M, y * M, x * x * M, y * y * M, x * y * M, torch.ones_like(y) * M], 2) # B x L x 6
33
+
34
+ theta_x, theta_y = lstq(A, F_u, F_v, lamda=.01)
35
+ rec_flow_m = torch.stack([torch.einsum('bln,bnk->blk', A, theta_x).view(bs, *grid_x.shape),
36
+ torch.einsum('bln,bnk->blk', A, theta_y).view(bs, *grid_y.shape)], 1)
37
+
38
+ rec_flow += rec_flow_m
39
+ return rec_flow
40
+
41
+
42
+ SUBSAMPLE = 8
43
+ SKIP = 0.4
44
+ SIZE = 0.3
45
+ NITER = 50
46
+ METHOD = 'inv_score'
47
+
48
+ def set_subsample_skip(sub=None, skip=None, size=None, niter=None, method=None):
49
+ global SUBSAMPLE, SKIP, SIZE, NITER, METHOD
50
+ if sub is not None: SUBSAMPLE=sub
51
+ if skip is not None: SKIP=skip
52
+ if size is not None: SIZE=size
53
+ if niter is not None: NITER=niter
54
+ if method is not None: METHOD=method
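`get_quad_flow()` fits, for every mask slot, a quadratic flow model in the pixel coordinates (terms x, y, x^2, y^2, xy and a constant), with each design-matrix column weighted by the soft mask, and solves the least-squares problem through the QR factorisation in `lstq`. A self-contained toy sketch of that fit, using one uniform mask and made-up coefficients, and `torch.linalg.solve` in place of the explicit inverse used above:

```python
import torch

B, H, W = 1, 8, 8
ys, xs = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                        torch.arange(W, dtype=torch.float32), indexing='ij')
x, y = xs.flatten()[None], ys.flatten()[None]             # B x L
M = torch.ones_like(x)                                     # one uniform mask
A = torch.stack([x * M, y * M, x * x * M, y * y * M,
                 x * y * M, torch.ones_like(x) * M], 2)    # B x L x 6

theta_true = torch.tensor([[0.1], [0.0], [0.01], [0.0], [0.0], [2.0]])  # made-up coefficients
F_u = torch.einsum('bln,nk->blk', A, theta_true)           # synthetic flow component, B x L x 1

Q, R = torch.linalg.qr(A)
theta_est = torch.linalg.solve(R, Q.transpose(1, 2) @ F_u)
print(torch.allclose(theta_est, theta_true.expand_as(theta_est), atol=1e-4))  # True
```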
losses/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ from .reconstruction_loss import ReconstructionLoss
2
+ import torch
3
+
4
+
5
+ class CriterionDict:
6
+ def __init__(self, dict):
7
+ self.criterions = dict
8
+
9
+ def __call__(self, sample, flow, masks_softmaxed, iteration, train=True, prefix=''):
10
+ loss = torch.tensor(0., device=masks_softmaxed.device, dtype=masks_softmaxed.dtype)
11
+ log_dict = {}
12
+ for name_i, (criterion_i, loss_multiplier_i, anneal_fn_i) in self.criterions.items():
13
+ loss_i = loss_multiplier_i * anneal_fn_i(iteration) * criterion_i(sample, flow, masks_softmaxed, iteration, train=train)
14
+ loss += loss_i
15
+ log_dict[f'loss_{name_i}'] = loss_i.item()
16
+
17
+ log_dict['loss_total'] = loss.item()
18
+ return loss, log_dict
19
+
20
+ def flow_reconstruction(self, sample, flow, masks_softmaxed):
21
+ return self.criterions['reconstruction'][0].rec_flow(sample, flow, masks_softmaxed)
22
+
23
+ def process_flow(self, sample, flow):
24
+ return self.criterions['reconstruction'][0].process_flow(sample, flow)
25
+
26
+ def viz_flow(self, flow):
27
+ return self.criterions['reconstruction'][0].viz_flow(flow)
28
+
losses/reconstruction_loss.py ADDED
@@ -0,0 +1,85 @@
1
+ import torch
2
+ import functools
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ import flow_reconstruction
10
+ import utils
11
+ from utils.visualisation import flow2rgb_torch
12
+
13
+ logger = utils.log.getLogger(__name__)
14
+
15
+ class ReconstructionLoss:
16
+ def __init__(self, cfg, model):
17
+ self.criterion = nn.MSELoss() if cfg.GWM.CRITERION == 'L2' else nn.L1Loss()
18
+ self.l1_optimize = cfg.GWM.L1_OPTIMIZE
19
+ self.homography = cfg.GWM.HOMOGRAPHY
20
+ self.device=model.device
21
+ self.cfg = cfg
22
+ self.grid_x, self.grid_y = utils.grid.get_meshgrid(cfg.GWM.RESOLUTION, model.device)
23
+ # self.mult_flow = cfg.GWM.USE_MULT_FLOW
24
+ self.flow_colorspace_rec = cfg.GWM.FLOW_COLORSPACE_REC
25
+ flow_reconstruction.set_subsample_skip(cfg.GWM.HOMOGRAPHY_SUBSAMPLE, cfg.GWM.HOMOGRAPHY_SKIP)
26
+ self.flow_u_low = cfg.GWM.FLOW_CLIP_U_LOW
27
+ self.flow_u_high = cfg.GWM.FLOW_CLIP_U_HIGH
28
+ self.flow_v_low = cfg.GWM.FLOW_CLIP_V_LOW
29
+ self.flow_v_high = cfg.GWM.FLOW_CLIP_V_HIGH
30
+
31
+ self._recon_fn = self.flow_quad
32
+ logger.info(f'Using reconstruction method {self._recon_fn.__name__}')
33
+ self.it = 0
34
+ self._extra_losses = []
35
+
36
+ def __call__(self, sample, flow, masks_softmaxed, it, train=True):
37
+ return self.loss(sample, flow, masks_softmaxed, it, train=train)
38
+
39
+ def loss(self, sample, flow, mask_softmaxed, it, train=True):
40
+ self.training = train
41
+ flow = self.process_flow(sample, flow)
42
+ self.it = it
43
+ self._extra_losses = []
44
+
45
+ if self.cfg.GWM.FLOW_RES is not None:
46
+ if flow.shape[-2:] != mask_softmaxed.shape[-2:]:
47
+ logger.debug_once(f'Resizing predicted masks to {self.cfg.GWM.FLOW_RES}')
48
+ mask_softmaxed = F.interpolate(mask_softmaxed, flow.shape[-2:], mode='bilinear', align_corners=False)
49
+
50
+ rec_flows = self.rec_flow(sample, flow, mask_softmaxed)
51
+ if not isinstance(rec_flows, (list, tuple)):
52
+ rec_flows = (rec_flows,)
53
+ k = len(rec_flows)
54
+ loss = sum(self.criterion(flow, rec_flow) / k for rec_flow in rec_flows)
55
+ if len(self._extra_losses):
56
+ loss = loss + sum(self._extra_losses, 0.) / len(self._extra_losses)
57
+ self._extra_losses = []
58
+ return loss
59
+
60
+ def flow_quad(self, sample, flow, masks_softmaxed, it, **_):
61
+ logger.debug_once(f'Reconstruction using quadratic. Masks shape {masks_softmaxed.shape} | '
62
+ f'Flow shape {flow.shape} | '
63
+ f'Grid shape {self.grid_x.shape, self.grid_y.shape}')
64
+ return flow_reconstruction.get_quad_flow(masks_softmaxed, flow, self.grid_x, self.grid_y)
65
+
66
+ def _clipped_recon_fn(self, *args, **kwargs):
67
+ flow = self._recon_fn(*args, **kwargs)
68
+ flow_o = flow[:, :-2]
69
+ flow_u = flow[:, -2:-1].clip(self.flow_u_low, self.flow_u_high)
70
+ flow_v = flow[:, -1:].clip(self.flow_v_low, self.flow_v_high)
71
+ return torch.cat([flow_o, flow_u, flow_v], dim=1)
72
+
73
+ def rec_flow(self, sample, flow, masks_softmaxed):
74
+ it = self.it
75
+ if self.cfg.GWM.FLOW_RES is not None and flow.shape[-2:] != self.grid_x.shape[-2:]:
76
+ logger.debug_once(f'Generating new grid predicted masks of {flow.shape[-2:]}')
77
+ self.grid_x, self.grid_y = utils.grid.get_meshgrid(flow.shape[-2:], self.device)
78
+ return [self._clipped_recon_fn(sample, flow, masks_softmaxed, it)]
79
+
80
+ def process_flow(self, sample, flow_cuda):
81
+ return flow_cuda
82
+
83
+ def viz_flow(self, flow):
84
+ return torch.stack([flow2rgb_torch(x) for x in flow])
85
+
main.py ADDED
@@ -0,0 +1,270 @@
1
+ import determinism # noqa
2
+
3
+ determinism.i_do_nothing_but_dont_remove_me_otherwise_things_break() # noqa
4
+
5
+ import argparse
6
+ import bisect
7
+ import copy
8
+ import os
9
+ import sys
10
+ import time
11
+ from argparse import ArgumentParser
12
+
13
+ import torch
14
+ import wandb
15
+ from detectron2.checkpoint import DetectionCheckpointer
16
+ from detectron2.engine import PeriodicCheckpointer
17
+ from torch.utils.tensorboard import SummaryWriter
18
+ from tqdm import tqdm
19
+
20
+ import config
21
+ import losses
22
+ import utils
23
+ from eval_utils import eval_unsupmf, get_unsup_image_viz, get_vis_header
24
+ from mask_former_trainer import setup, Trainer
25
+
26
+
27
+ logger = utils.log.getLogger('gwm')
28
+
29
+ def freeze(module, set=False):
30
+ for param in module.parameters():
31
+ param.requires_grad = set
32
+
33
+
34
+ def main(args):
35
+ cfg = setup(args)
36
+ logger.info(f"Called as {' '.join(sys.argv)}")
37
+ logger.info(f'Output dir {cfg.OUTPUT_DIR}')
38
+
39
+ random_state = utils.random_state.PytorchRNGState(seed=cfg.SEED).to(torch.device(cfg.MODEL.DEVICE))
40
+ random_state.seed_everything()
41
+ utils.log.checkpoint_code(cfg.OUTPUT_DIR)
42
+
43
+ if not cfg.SKIP_TB:
44
+ writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
45
+ else:
46
+ writer = None
47
+
48
+ # initialize model
49
+ model = Trainer.build_model(cfg)
50
+ optimizer = Trainer.build_optimizer(cfg, model)
51
+ scheduler = Trainer.build_lr_scheduler(cfg, optimizer)
52
+
53
+ logger.info(f'Optimiser is {type(optimizer)}')
54
+
55
+
56
+ checkpointer = DetectionCheckpointer(model,
57
+ save_dir=os.path.join(cfg.OUTPUT_DIR, 'checkpoints'),
58
+ random_state=random_state,
59
+ optimizer=optimizer,
60
+ scheduler=scheduler)
61
+ periodic_checkpointer = PeriodicCheckpointer(checkpointer=checkpointer,
62
+ period=cfg.SOLVER.CHECKPOINT_PERIOD,
63
+ max_iter=cfg.SOLVER.MAX_ITER,
64
+ max_to_keep=None if cfg.FLAGS.KEEP_ALL else 5,
65
+ file_prefix='checkpoint')
66
+ checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume_path is not None)
67
+ iteration = 0 if args.resume_path is None else checkpoint['iteration']
68
+
69
+ train_loader, val_loader = config.loaders(cfg)
70
+
71
+ # overfit single batch for debug
72
+ # sample = next(iter(loader))
73
+
74
+ criterions = {
75
+ 'reconstruction': (losses.ReconstructionLoss(cfg, model), cfg.GWM.LOSS_MULT.REC, lambda x: 1)}
76
+
77
+ criterion = losses.CriterionDict(criterions)
78
+
79
+ if args.eval_only:
80
+ if len(val_loader.dataset) == 0:
81
+ logger.error("Validation dataset: empty")
82
+ sys.exit(0)
83
+ model.eval()
84
+ iou = eval_unsupmf(cfg=cfg, val_loader=val_loader, model=model, criterion=criterion, writer=writer,
85
+ writer_iteration=iteration)
86
+ logger.info(f"Results: iteration: {iteration} IOU = {iou}")
87
+ return
88
+ if len(train_loader.dataset) == 0:
89
+ logger.error("Training dataset: empty")
90
+ sys.exit(0)
91
+
92
+ logger.info(
93
+ f'Start of training: dataset {cfg.GWM.DATASET},'
94
+ f' train {len(train_loader.dataset)}, val {len(val_loader.dataset)},'
95
+ f' device {model.device}, keys {cfg.GWM.SAMPLE_KEYS}, '
96
+ f'multiple flows {cfg.GWM.USE_MULT_FLOW}')
97
+
98
+ iou_best = 0
99
+ timestart = time.time()
100
+ dilate_kernel = torch.ones((2, 2), device=model.device)
101
+
102
+ total_iter = cfg.TOTAL_ITER if cfg.TOTAL_ITER else cfg.SOLVER.MAX_ITER # early stop
103
+ with torch.autograd.set_detect_anomaly(cfg.DEBUG) and \
104
+ tqdm(initial=iteration, total=total_iter, disable=utils.environment.is_slurm()) as pbar:
105
+ while iteration < total_iter:
106
+ for sample in train_loader:
107
+
108
+ if cfg.MODEL.META_ARCHITECTURE != 'UNET' and cfg.FLAGS.UNFREEZE_AT:
109
+ if hasattr(model.backbone, 'frozen_stages'):
110
+ assert cfg.MODEL.BACKBONE.FREEZE_AT == -1, f"MODEL initial parameters forced frozen"
111
+ stages = [s for s, m in cfg.FLAGS.UNFREEZE_AT]
112
+ milest = [m for s, m in cfg.FLAGS.UNFREEZE_AT]
113
+ pos = bisect.bisect_right(milest, iteration) - 1
114
+ if pos >= 0:
115
+ curr_setting = model.backbone.frozen_stages
116
+ if curr_setting != stages[pos]:
117
+ logger.info(f"Updating backbone freezing stages from {curr_setting} to {stages[pos]}")
118
+ model.backbone.frozen_stages = stages[pos]
119
+ model.train()
120
+ else:
121
+ assert cfg.MODEL.BACKBONE.FREEZE_AT == -1, f"MODEL initial parameters forced frozen"
122
+ stages = [s for s, m in cfg.FLAGS.UNFREEZE_AT]
123
+ milest = [m for s, m in cfg.FLAGS.UNFREEZE_AT]
124
+ pos = bisect.bisect_right(milest, iteration) - 1
125
+ freeze(model, set=False)
126
+ freeze(model.sem_seg_head.predictor, set=True)
127
+ if pos >= 0:
128
+ stage = stages[pos]
129
+ if stage <= 2:
130
+ freeze(model.sem_seg_head, set=True)
131
+ if stage <= 1:
132
+ freeze(model.backbone, set=True)
133
+ model.train()
134
+
135
+ else:
136
+ logger.debug_once(f'Unfreezing disabled schedule: {cfg.FLAGS.UNFREEZE_AT}')
137
+
138
+ sample = [e for s in sample for e in s]
139
+ flow_key = 'flow'
140
+ raw_sem_seg = False
141
+ if cfg.GWM.FLOW_RES is not None:
142
+ flow_key = 'flow_big'
143
+ raw_sem_seg = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME == 'MegaBigPixelDecoder'
144
+
145
+ flow = torch.stack([x[flow_key].to(model.device) for x in sample]).clip(-20, 20)
146
+ logger.debug_once(f'flow shape: {flow.shape}')
147
+ preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True, raw_sem_seg=raw_sem_seg)
148
+ masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)
149
+ logger.debug_once(f'mask shape: {masks_raw.shape}')
150
+ masks_softmaxed_list = [torch.softmax(masks_raw, dim=1)]
151
+
152
+
153
+ total_losses = []
154
+ log_dicts = []
155
+ for mask_idx, masks_softmaxed in enumerate(masks_softmaxed_list):
156
+
157
+ loss, log_dict = criterion(sample, flow, masks_softmaxed, iteration)
158
+
159
+ if cfg.GWM.USE_MULT_FLOW:
160
+ flow2 = torch.stack([x[flow_key + '_2'].to(model.device) for x in sample]).clip(-20, 20)
161
+ other_loss, other_log_dict = criterion(sample, flow2, masks_softmaxed, iteration)
162
+ loss = loss / 2 + other_loss / 2
163
+ for k, v in other_log_dict.items():
164
+ log_dict[k] = other_log_dict[k] / 2 + v / 2
165
+ total_losses.append(loss)
166
+ log_dicts.append(log_dict)
167
+
168
+ loss_ws = cfg.GWM.LOSS_MULT.HEIR_W
169
+ total_w = float(sum(loss_ws[:len(total_losses)]))
170
+ log_dict = {}
171
+ if len(total_losses) == 1:
172
+ log_dict = log_dicts[0]
173
+ loss = total_losses[0]
174
+ else:
175
+ loss = 0
176
+ for i, (tl, w, ld) in enumerate(zip(total_losses, loss_ws, log_dicts)):
177
+ for k, v in ld.items():
178
+ log_dict[f'{k}_{i}'] = v * w / total_w
179
+ loss += tl * w / total_w
180
+
181
+ train_log_dict = {f'train/{k}': v for k, v in log_dict.items()}
182
+ del log_dict
183
+ train_log_dict['train/learning_rate'] = optimizer.param_groups[-1]['lr']
184
+ train_log_dict['train/loss_total'] = loss.item()
185
+
186
+
187
+ optimizer.zero_grad()
188
+
189
+
190
+ loss.backward()
191
+ optimizer.step()
192
+ scheduler.step()
193
+
194
+ pbar.set_postfix(loss=loss.item())
195
+ pbar.update()
196
+
197
+ # Sanity check for RNG state
198
+ if (iteration + 1) % 1000 == 0 or iteration + 1 in {1, 50}:
199
+ logger.info(
200
+ f'Iteration {iteration + 1}. RNG outputs {utils.random_state.get_randstate_magic_numbers(model.device)}')
201
+
202
+ if cfg.DEBUG or (iteration + 1) % 100 == 0:
203
+ logger.info(
204
+ f'Iteration: {iteration + 1}, time: {time.time() - timestart:.01f}s, loss: {loss.item():.02f}.')
205
+
206
+ for k, v in train_log_dict.items():
207
+ if writer:
208
+ writer.add_scalar(k, v, iteration + 1)
209
+
210
+ if cfg.WANDB.ENABLE:
211
+ wandb.log(train_log_dict, step=iteration + 1)
212
+
213
+ if (iteration + 1) % cfg.LOG_FREQ == 0 or (iteration + 1) in [1, 50, 500]:
214
+ model.eval()
215
+ if writer:
216
+ flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20)
217
+ image_viz, header_text = get_unsup_image_viz(model, cfg, sample, criterion)
218
+ header = get_vis_header(image_viz.size(2), flow.size(3), header_text)
219
+ image_viz = torch.cat([header, image_viz], dim=1)
220
+ writer.add_image('train/images', image_viz, iteration + 1)
221
+ if cfg.WANDB.ENABLE and (iteration + 1) % 2500 == 0:
222
+ image_viz, _ = get_unsup_image_viz(model, cfg, sample, criterion)
223
+ wandb.log({'train/viz': wandb.Image(image_viz.float())}, step=iteration + 1)
224
+
225
+ if iou := eval_unsupmf(cfg=cfg, val_loader=val_loader, model=model, criterion=criterion,
226
+ writer=writer, writer_iteration=iteration + 1, use_wandb=cfg.WANDB.ENABLE):
227
+ if cfg.SOLVER.CHECKPOINT_PERIOD and iou > iou_best:
228
+ iou_best = iou
229
+ if not args.wandb_sweep_mode:
230
+ checkpointer.save(name='checkpoint_best', iteration=iteration + 1, loss=loss,
231
+ iou=iou_best)
232
+ logger.info(f'New best IoU {iou_best:.02f} after iteration {iteration + 1}')
233
+ if cfg.WANDB.ENABLE:
234
+ wandb.log({'eval/IoU_best': iou_best}, step=iteration + 1)
235
+ if writer:
236
+ writer.add_scalar('eval/IoU_best', iou_best, iteration + 1)
237
+
238
+
239
+ model.train()
240
+
241
+ periodic_checkpointer.step(iteration=iteration + 1, loss=loss)
242
+
243
+ iteration += 1
244
+ timestart = time.time()
245
+
246
+
247
+ def get_argparse_args():
248
+ parser = ArgumentParser()
249
+ parser.add_argument('--resume_path', type=str, default=None)
250
+ parser.add_argument('--use_wandb', dest='wandb_sweep_mode', action='store_true') # for sweep
251
+ parser.add_argument('--config-file', type=str,
252
+ default='configs/maskformer/maskformer_R50_bs16_160k_dino.yaml')
253
+ parser.add_argument('--eval_only', action='store_true')
254
+ parser.add_argument(
255
+ "opts",
256
+ help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. "
257
+ "See config references at "
258
+ "https://detectron2.readthedocs.io/modules/config.html#config-references",
259
+ default=None,
260
+ nargs=argparse.REMAINDER,
261
+ )
262
+ return parser
263
+
264
+
265
+ if __name__ == "__main__":
266
+ args = get_argparse_args().parse_args()
267
+ if args.resume_path:
268
+ args.config_file = "/".join(args.resume_path.split('/')[:-2]) + '/config.yaml'
269
+ print(args.config_file)
270
+ main(args)
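One detail of the training loop worth spelling out is the `cfg.FLAGS.UNFREEZE_AT` handling: each entry is read as a `(stage, milestone_iteration)` pair and `bisect_right` over the milestones selects the most recently reached stage. A small illustration with a made-up schedule (the values below do not come from any config in this commit):

```python
import bisect

UNFREEZE_AT = [(3, 0), (2, 1500), (1, 10000)]   # hypothetical (stage, start_iteration) pairs
stages = [s for s, m in UNFREEZE_AT]
milest = [m for s, m in UNFREEZE_AT]

for iteration in [0, 100, 1500, 20000]:
    pos = bisect.bisect_right(milest, iteration) - 1
    print(iteration, '->', stages[pos] if pos >= 0 else 'schedule not started')
# 0 -> 3, 100 -> 3, 1500 -> 2, 20000 -> 1
```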
mask_former/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from . import data # register all new datasets
3
+ from . import modeling
4
+
5
+ # config
6
+ from .config import add_mask_former_config
7
+
8
+ # dataset loading
9
+ from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper
10
+ from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import (
11
+ MaskFormerPanopticDatasetMapper,
12
+ )
13
+ from .data.dataset_mappers.mask_former_semantic_dataset_mapper import (
14
+ MaskFormerSemanticDatasetMapper,
15
+ )
16
+
17
+ # models
18
+ from .mask_former_model import MaskFormer
19
+ from .test_time_augmentation import SemanticSegmentorWithTTA
mask_former/config.py ADDED
@@ -0,0 +1,85 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ from detectron2.config import CfgNode as CN
4
+
5
+
6
+ def add_mask_former_config(cfg):
7
+ """
8
+ Add config for MASK_FORMER.
9
+ """
10
+ # data config
11
+ # select the dataset mapper
12
+ cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
13
+ # Color augmentation
14
+ cfg.INPUT.COLOR_AUG_SSD = False
15
+ # We retry random cropping until no single category in semantic segmentation GT occupies more
16
+ # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
17
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
18
+ # Pad image and segmentation GT in dataset mapper.
19
+ cfg.INPUT.SIZE_DIVISIBILITY = -1
20
+
21
+ # solver config
22
+ # weight decay on embedding
23
+ cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
24
+ # optimizer
25
+ cfg.SOLVER.OPTIMIZER = "ADAMW"
26
+ cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
27
+
28
+ # mask_former model config
29
+ cfg.MODEL.MASK_FORMER = CN()
30
+
31
+ # loss
32
+ cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
33
+ cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1
34
+ cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0
35
+ cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0
36
+
37
+ # transformer config
38
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
39
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.1
40
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
41
+ cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0
42
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6
43
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
44
+
45
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
46
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100
47
+
48
+ cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5"
49
+ cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False
50
+
51
+ # mask_former inference config
52
+ cfg.MODEL.MASK_FORMER.TEST = CN()
53
+ cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False
54
+ cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0
55
+ cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0
56
+ cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
57
+
58
+ # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
59
+ # you can use this config to override
60
+ cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
61
+
62
+ # pixel decoder config
63
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
64
+ # adding transformer in pixel decoder
65
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
66
+ # pixel decoder
67
+ cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
68
+
69
+ # swin transformer backbone
70
+ cfg.MODEL.SWIN = CN()
71
+ cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
72
+ cfg.MODEL.SWIN.PATCH_SIZE = 4
73
+ cfg.MODEL.SWIN.EMBED_DIM = 96
74
+ cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
75
+ cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
76
+ cfg.MODEL.SWIN.WINDOW_SIZE = 7
77
+ cfg.MODEL.SWIN.MLP_RATIO = 4.0
78
+ cfg.MODEL.SWIN.QKV_BIAS = True
79
+ cfg.MODEL.SWIN.QK_SCALE = None
80
+ cfg.MODEL.SWIN.DROP_RATE = 0.0
81
+ cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
82
+ cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
83
+ cfg.MODEL.SWIN.APE = False
84
+ cfg.MODEL.SWIN.PATCH_NORM = True
85
+ cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
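A minimal sketch of how this hook is normally attached to a detectron2 config; it assumes the plain detectron2 defaults are sufficient, whereas the repository's own `setup()` in `mask_former_trainer` is the real entry point and presumably layers further project-specific keys on top:

```python
from detectron2.config import get_cfg
from mask_former import add_mask_former_config

cfg = get_cfg()              # detectron2 defaults (already contain MODEL.SEM_SEG_HEAD)
add_mask_former_config(cfg)  # adds MODEL.MASK_FORMER, MODEL.SWIN and pixel-decoder keys
print(cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES)  # 100 by default
```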
mask_former/data/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from . import datasets
mask_former/data/dataset_mappers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
mask_former/data/dataset_mappers/detr_panoptic_dataset_mapper.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py
3
+ import copy
4
+ import logging
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.data import detection_utils as utils
11
+ from detectron2.data import transforms as T
12
+ from detectron2.data.transforms import TransformGen
13
+ from detectron2.structures import BitMasks, Instances
14
+
15
+ __all__ = ["DETRPanopticDatasetMapper"]
16
+
17
+
18
+ def build_transform_gen(cfg, is_train):
19
+ """
20
+ Create a list of :class:`TransformGen` from config.
21
+ Returns:
22
+ list[TransformGen]
23
+ """
24
+ if is_train:
25
+ min_size = cfg.INPUT.MIN_SIZE_TRAIN
26
+ max_size = cfg.INPUT.MAX_SIZE_TRAIN
27
+ sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
28
+ else:
29
+ min_size = cfg.INPUT.MIN_SIZE_TEST
30
+ max_size = cfg.INPUT.MAX_SIZE_TEST
31
+ sample_style = "choice"
32
+ if sample_style == "range":
33
+ assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(
34
+ len(min_size)
35
+ )
36
+
37
+ logger = logging.getLogger(__name__)
38
+ tfm_gens = []
39
+ if is_train:
40
+ tfm_gens.append(T.RandomFlip())
41
+ tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
42
+ if is_train:
43
+ logger.info("TransformGens used in training: " + str(tfm_gens))
44
+ return tfm_gens
45
+
46
+
47
+ # This is specifically designed for the COCO dataset.
48
+ class DETRPanopticDatasetMapper:
49
+ """
50
+ A callable which takes a dataset dict in Detectron2 Dataset format,
51
+ and maps it into a format used by MaskFormer.
52
+
53
+ This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
54
+
55
+ The callable currently does the following:
56
+
57
+ 1. Reads the image from "file_name"
58
+ 2. Applies geometric transforms to the image and annotation
59
+ 3. Finds and applies suitable cropping to the image and annotation
60
+ 4. Prepares the image and annotation as Tensors
61
+ """
62
+
63
+ @configurable
64
+ def __init__(
65
+ self,
66
+ is_train=True,
67
+ *,
68
+ crop_gen,
69
+ tfm_gens,
70
+ image_format,
71
+ ):
72
+ """
73
+ NOTE: this interface is experimental.
74
+ Args:
75
+ is_train: for training or inference
76
+ augmentations: a list of augmentations or deterministic transforms to apply
77
+ crop_gen: crop augmentation
78
+ tfm_gens: data augmentation
79
+ image_format: an image format supported by :func:`detection_utils.read_image`.
80
+ """
81
+ self.crop_gen = crop_gen
82
+ self.tfm_gens = tfm_gens
83
+ logging.getLogger(__name__).info(
84
+ "[DETRPanopticDatasetMapper] Full TransformGens used in training: {}, crop: {}".format(
85
+ str(self.tfm_gens), str(self.crop_gen)
86
+ )
87
+ )
88
+
89
+ self.img_format = image_format
90
+ self.is_train = is_train
91
+
92
+ @classmethod
93
+ def from_config(cls, cfg, is_train=True):
94
+ # Build augmentation
95
+ if cfg.INPUT.CROP.ENABLED and is_train:
96
+ crop_gen = [
97
+ T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
98
+ T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
99
+ ]
100
+ else:
101
+ crop_gen = None
102
+
103
+ tfm_gens = build_transform_gen(cfg, is_train)
104
+
105
+ ret = {
106
+ "is_train": is_train,
107
+ "crop_gen": crop_gen,
108
+ "tfm_gens": tfm_gens,
109
+ "image_format": cfg.INPUT.FORMAT,
110
+ }
111
+ return ret
112
+
113
+ def __call__(self, dataset_dict):
114
+ """
115
+ Args:
116
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
117
+
118
+ Returns:
119
+ dict: a format that builtin models in detectron2 accept
120
+ """
121
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
122
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
123
+ utils.check_image_size(dataset_dict, image)
124
+
125
+ if self.crop_gen is None:
126
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
127
+ else:
128
+ if np.random.rand() > 0.5:
129
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
130
+ else:
131
+ image, transforms = T.apply_transform_gens(
132
+ self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
133
+ )
134
+
135
+ image_shape = image.shape[:2] # h, w
136
+
137
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
138
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
139
+ # Therefore it's important to use torch.Tensor.
140
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
141
+
142
+ if not self.is_train:
143
+ # USER: Modify this if you want to keep them for some reason.
144
+ dataset_dict.pop("annotations", None)
145
+ return dataset_dict
146
+
147
+ if "pan_seg_file_name" in dataset_dict:
148
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
149
+ segments_info = dataset_dict["segments_info"]
150
+
151
+ # apply the same transformation to panoptic segmentation
152
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
153
+
154
+ from panopticapi.utils import rgb2id
155
+
156
+ pan_seg_gt = rgb2id(pan_seg_gt)
157
+
158
+ instances = Instances(image_shape)
159
+ classes = []
160
+ masks = []
161
+ for segment_info in segments_info:
162
+ class_id = segment_info["category_id"]
163
+ if not segment_info["iscrowd"]:
164
+ classes.append(class_id)
165
+ masks.append(pan_seg_gt == segment_info["id"])
166
+
167
+ classes = np.array(classes)
168
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
169
+ if len(masks) == 0:
170
+ # Some image does not have annotation (all ignored)
171
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
172
+ else:
173
+ masks = BitMasks(
174
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
175
+ )
176
+ instances.gt_masks = masks.tensor
177
+
178
+ dataset_dict["instances"] = instances
179
+
180
+ return dataset_dict
mask_former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.data import detection_utils as utils
11
+ from detectron2.data import transforms as T
12
+ from detectron2.structures import BitMasks, Instances
13
+
14
+ from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper
15
+
16
+ __all__ = ["MaskFormerPanopticDatasetMapper"]
17
+
18
+
19
+ class MaskFormerPanopticDatasetMapper(MaskFormerSemanticDatasetMapper):
20
+ """
21
+ A callable which takes a dataset dict in Detectron2 Dataset format,
22
+ and maps it into a format used by MaskFormer for panoptic segmentation.
23
+
24
+ The callable currently does the following:
25
+
26
+ 1. Reads the image from "file_name"
27
+ 2. Applies geometric transforms to the image and annotation
28
+ 3. Finds and applies suitable cropping to the image and annotation
29
+ 4. Prepares the image and annotation as Tensors
30
+ """
31
+
32
+ @configurable
33
+ def __init__(
34
+ self,
35
+ is_train=True,
36
+ *,
37
+ augmentations,
38
+ image_format,
39
+ ignore_label,
40
+ size_divisibility,
41
+ ):
42
+ """
43
+ NOTE: this interface is experimental.
44
+ Args:
45
+ is_train: for training or inference
46
+ augmentations: a list of augmentations or deterministic transforms to apply
47
+ image_format: an image format supported by :func:`detection_utils.read_image`.
48
+ ignore_label: the label that is ignored during evaluation
49
+ size_divisibility: pad image size to be divisible by this value
50
+ """
51
+ super().__init__(
52
+ is_train,
53
+ augmentations=augmentations,
54
+ image_format=image_format,
55
+ ignore_label=ignore_label,
56
+ size_divisibility=size_divisibility,
57
+ )
58
+
59
+ def __call__(self, dataset_dict):
60
+ """
61
+ Args:
62
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
63
+
64
+ Returns:
65
+ dict: a format that builtin models in detectron2 accept
66
+ """
67
+ assert self.is_train, "MaskFormerPanopticDatasetMapper should only be used for training!"
68
+
69
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
70
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
71
+ utils.check_image_size(dataset_dict, image)
72
+
73
+ # semantic segmentation
74
+ if "sem_seg_file_name" in dataset_dict:
75
+ # PyTorch transformation not implemented for uint16, so converting it to double first
76
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
77
+ else:
78
+ sem_seg_gt = None
79
+
80
+ # panoptic segmentation
81
+ if "pan_seg_file_name" in dataset_dict:
82
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
83
+ segments_info = dataset_dict["segments_info"]
84
+ else:
85
+ pan_seg_gt = None
86
+ segments_info = None
87
+
88
+ if pan_seg_gt is None:
89
+ raise ValueError(
90
+ "Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format(
91
+ dataset_dict["file_name"]
92
+ )
93
+ )
94
+
95
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
96
+ aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
97
+ image = aug_input.image
98
+ if sem_seg_gt is not None:
99
+ sem_seg_gt = aug_input.sem_seg
100
+
101
+ # apply the same transformation to panoptic segmentation
102
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
103
+
104
+ from panopticapi.utils import rgb2id
105
+
106
+ pan_seg_gt = rgb2id(pan_seg_gt)
107
+
108
+ # Pad image and segmentation label here!
109
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
110
+ if sem_seg_gt is not None:
111
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
112
+ pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long"))
113
+
114
+ if self.size_divisibility > 0:
115
+ image_size = (image.shape[-2], image.shape[-1])
116
+ padding_size = [
117
+ 0,
118
+ self.size_divisibility - image_size[1],
119
+ 0,
120
+ self.size_divisibility - image_size[0],
121
+ ]
122
+ image = F.pad(image, padding_size, value=128).contiguous()
123
+ if sem_seg_gt is not None:
124
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
125
+ pan_seg_gt = F.pad(
126
+ pan_seg_gt, padding_size, value=0
127
+ ).contiguous() # 0 is the VOID panoptic label
128
+
129
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
130
+
131
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
132
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
133
+ # Therefore it's important to use torch.Tensor.
134
+ dataset_dict["image"] = image
135
+ if sem_seg_gt is not None:
136
+ dataset_dict["sem_seg"] = sem_seg_gt.long()
137
+
138
+ if "annotations" in dataset_dict:
139
+ raise ValueError("Panoptic segmentation dataset should not have 'annotations'.")
140
+
141
+ # Prepare per-category binary masks
142
+ pan_seg_gt = pan_seg_gt.numpy()
143
+ instances = Instances(image_shape)
144
+ classes = []
145
+ masks = []
146
+ for segment_info in segments_info:
147
+ class_id = segment_info["category_id"]
148
+ if not segment_info["iscrowd"]:
149
+ classes.append(class_id)
150
+ masks.append(pan_seg_gt == segment_info["id"])
151
+
152
+ classes = np.array(classes)
153
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
154
+ if len(masks) == 0:
155
+ # Some image does not have annotation (all ignored)
156
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
157
+ else:
158
+ masks = BitMasks(
159
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
160
+ )
161
+ instances.gt_masks = masks.tensor
162
+
163
+ dataset_dict["instances"] = instances
164
+
165
+ return dataset_dict
mask_former/data/dataset_mappers/mask_former_semantic_dataset_mapper.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.data import MetadataCatalog
11
+ from detectron2.data import detection_utils as utils
12
+ from detectron2.data import transforms as T
13
+ from detectron2.projects.point_rend import ColorAugSSDTransform
14
+ from detectron2.structures import BitMasks, Instances
15
+
16
+ __all__ = ["MaskFormerSemanticDatasetMapper"]
17
+
18
+
19
+ class MaskFormerSemanticDatasetMapper:
20
+ """
21
+ A callable which takes a dataset dict in Detectron2 Dataset format,
22
+ and maps it into a format used by MaskFormer for semantic segmentation.
23
+
24
+ The callable currently does the following:
25
+
26
+ 1. Reads the image from "file_name"
27
+ 2. Applies geometric transforms to the image and annotation
28
+ 3. Finds and applies suitable cropping to the image and annotation
29
+ 4. Prepares the image and annotation as Tensors
30
+ """
31
+
32
+ @configurable
33
+ def __init__(
34
+ self,
35
+ is_train=True,
36
+ *,
37
+ augmentations,
38
+ image_format,
39
+ ignore_label,
40
+ size_divisibility,
41
+ ):
42
+ """
43
+ NOTE: this interface is experimental.
44
+ Args:
45
+ is_train: for training or inference
46
+ augmentations: a list of augmentations or deterministic transforms to apply
47
+ image_format: an image format supported by :func:`detection_utils.read_image`.
48
+ ignore_label: the label that is ignored during evaluation
49
+ size_divisibility: pad image size to be divisible by this value
50
+ """
51
+ self.is_train = is_train
52
+ self.tfm_gens = augmentations
53
+ self.img_format = image_format
54
+ self.ignore_label = ignore_label
55
+ self.size_divisibility = size_divisibility
56
+
57
+ logger = logging.getLogger(__name__)
58
+ mode = "training" if is_train else "inference"
59
+ logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}")
60
+
61
+ @classmethod
62
+ def from_config(cls, cfg, is_train=True):
63
+ # Build augmentation
64
+ augs = [
65
+ T.ResizeShortestEdge(
66
+ cfg.INPUT.MIN_SIZE_TRAIN,
67
+ cfg.INPUT.MAX_SIZE_TRAIN,
68
+ cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
69
+ )
70
+ ]
71
+ if cfg.INPUT.CROP.ENABLED:
72
+ augs.append(
73
+ T.RandomCrop_CategoryAreaConstraint(
74
+ cfg.INPUT.CROP.TYPE,
75
+ cfg.INPUT.CROP.SIZE,
76
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
77
+ cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
78
+ )
79
+ )
80
+ if cfg.INPUT.COLOR_AUG_SSD:
81
+ augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
82
+ augs.append(T.RandomFlip())
83
+
84
+ # Assume always applies to the training set.
85
+ dataset_names = cfg.DATASETS.TRAIN
86
+ meta = MetadataCatalog.get(dataset_names[0])
87
+ ignore_label = meta.ignore_label
88
+
89
+ ret = {
90
+ "is_train": is_train,
91
+ "augmentations": augs,
92
+ "image_format": cfg.INPUT.FORMAT,
93
+ "ignore_label": ignore_label,
94
+ "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY,
95
+ }
96
+ return ret
97
+
98
+ def __call__(self, dataset_dict):
99
+ """
100
+ Args:
101
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
102
+
103
+ Returns:
104
+ dict: a format that builtin models in detectron2 accept
105
+ """
106
+ assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!"
107
+
108
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
109
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
110
+ utils.check_image_size(dataset_dict, image)
111
+
112
+ if "sem_seg_file_name" in dataset_dict:
113
+ # PyTorch transformation not implemented for uint16, so converting it to double first
114
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
115
+ else:
116
+ sem_seg_gt = None
117
+
118
+ if sem_seg_gt is None:
119
+ raise ValueError(
120
+ "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
121
+ dataset_dict["file_name"]
122
+ )
123
+ )
124
+
125
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
126
+ aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
127
+ image = aug_input.image
128
+ sem_seg_gt = aug_input.sem_seg
129
+
130
+ # Pad image and segmentation label here!
131
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
132
+ if sem_seg_gt is not None:
133
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
134
+
135
+ if self.size_divisibility > 0:
136
+ image_size = (image.shape[-2], image.shape[-1])
137
+ padding_size = [
138
+ 0,
139
+ self.size_divisibility - image_size[1],
140
+ 0,
141
+ self.size_divisibility - image_size[0],
142
+ ]
143
+ image = F.pad(image, padding_size, value=128).contiguous()
144
+ if sem_seg_gt is not None:
145
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
146
+
147
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
148
+
149
+ # PyTorch's dataloader is efficient on torch.Tensor due to shared memory,
150
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
151
+ # Therefore it's important to use torch.Tensor.
152
+ dataset_dict["image"] = image
153
+
154
+ if sem_seg_gt is not None:
155
+ dataset_dict["sem_seg"] = sem_seg_gt.long()
156
+
157
+ if "annotations" in dataset_dict:
158
+ raise ValueError("Semantic segmentation dataset should not have 'annotations'.")
159
+
160
+ # Prepare per-category binary masks
161
+ if sem_seg_gt is not None:
162
+ sem_seg_gt = sem_seg_gt.numpy()
163
+ instances = Instances(image_shape)
164
+ classes = np.unique(sem_seg_gt)
165
+ # remove ignored region
166
+ classes = classes[classes != self.ignore_label]
167
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
168
+
169
+ masks = []
170
+ for class_id in classes:
171
+ masks.append(sem_seg_gt == class_id)
172
+
173
+ if len(masks) == 0:
174
+ # Some images do not have any annotation (all pixels ignored)
175
+ instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
176
+ else:
177
+ masks = BitMasks(
178
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
179
+ )
180
+ instances.gt_masks = masks.tensor
181
+
182
+ dataset_dict["instances"] = instances
183
+
184
+ return dataset_dict
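
The mapper above is wired for detectron2's @configurable pattern: from_config() translates config keys into constructor arguments, and __call__ turns each dataset record into "image", "sem_seg" and per-category binary masks under "instances". Below is a minimal usage sketch, not part of this commit; it assumes detectron2 is installed, that add_mask_former_config is exported from mask_former/config.py as in upstream MaskFormer, and that the dataset named in cfg.DATASETS.TRAIN is already registered and present on disk.

# Minimal usage sketch (not part of this commit); names outside this repo's
# files are assumptions noted above.
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader

from mask_former.config import add_mask_former_config  # assumed export, as in upstream MaskFormer
from mask_former.data.dataset_mappers.mask_former_semantic_dataset_mapper import (
    MaskFormerSemanticDatasetMapper,
)

cfg = get_cfg()
add_mask_former_config(cfg)  # adds keys such as INPUT.SIZE_DIVISIBILITY and INPUT.COLOR_AUG_SSD
cfg.DATASETS.TRAIN = ("ade20k_full_sem_seg_train",)  # any registered semantic-seg dataset

# @configurable lets the mapper be built straight from the config;
# from_config() above maps cfg keys to the constructor arguments.
mapper = MaskFormerSemanticDatasetMapper(cfg, is_train=True)
loader = build_detection_train_loader(cfg, mapper=mapper)

# Each element yielded by the loader carries "image", "sem_seg" and an
# "instances" object with one binary mask per category, as assembled in __call__.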
mask_former/data/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from . import (
3
+ register_ade20k_full,
4
+ register_ade20k_panoptic,
5
+ register_coco_stuff_10k,
6
+ register_mapillary_vistas,
7
+ )
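
Importing this package is what makes the datasets visible to detectron2: each register_* module is expected to register its splits at import time, as register_ade20k_full.py does at the bottom of the next file. A hypothetical check, not part of this commit:

# Hypothetical check (not part of this commit): registration only stores lazy
# loader callables, so no dataset files are needed just to list the names.
import mask_former.data.datasets  # noqa: F401  (import side effect registers datasets)

from detectron2.data import DatasetCatalog

print(sorted(n for n in DatasetCatalog.list() if n.startswith("ade20k_full")))
# expected: ['ade20k_full_sem_seg_train', 'ade20k_full_sem_seg_val']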
mask_former/data/datasets/register_ade20k_full.py ADDED
@@ -0,0 +1,964 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import os
3
+
4
+ from detectron2.data import DatasetCatalog, MetadataCatalog
5
+ from detectron2.data.datasets import load_sem_seg
6
+
7
+ ADE20K_SEM_SEG_FULL_CATEGORIES = [
8
+ {"name": "wall", "id": 2978, "trainId": 0},
9
+ {"name": "building, edifice", "id": 312, "trainId": 1},
10
+ {"name": "sky", "id": 2420, "trainId": 2},
11
+ {"name": "tree", "id": 2855, "trainId": 3},
12
+ {"name": "road, route", "id": 2131, "trainId": 4},
13
+ {"name": "floor, flooring", "id": 976, "trainId": 5},
14
+ {"name": "ceiling", "id": 447, "trainId": 6},
15
+ {"name": "bed", "id": 165, "trainId": 7},
16
+ {"name": "sidewalk, pavement", "id": 2377, "trainId": 8},
17
+ {"name": "earth, ground", "id": 838, "trainId": 9},
18
+ {"name": "cabinet", "id": 350, "trainId": 10},
19
+ {"name": "person, individual, someone, somebody, mortal, soul", "id": 1831, "trainId": 11},
20
+ {"name": "grass", "id": 1125, "trainId": 12},
21
+ {"name": "windowpane, window", "id": 3055, "trainId": 13},
22
+ {"name": "car, auto, automobile, machine, motorcar", "id": 401, "trainId": 14},
23
+ {"name": "mountain, mount", "id": 1610, "trainId": 15},
24
+ {"name": "plant, flora, plant life", "id": 1910, "trainId": 16},
25
+ {"name": "table", "id": 2684, "trainId": 17},
26
+ {"name": "chair", "id": 471, "trainId": 18},
27
+ {"name": "curtain, drape, drapery, mantle, pall", "id": 687, "trainId": 19},
28
+ {"name": "door", "id": 774, "trainId": 20},
29
+ {"name": "sofa, couch, lounge", "id": 2473, "trainId": 21},
30
+ {"name": "sea", "id": 2264, "trainId": 22},
31
+ {"name": "painting, picture", "id": 1735, "trainId": 23},
32
+ {"name": "water", "id": 2994, "trainId": 24},
33
+ {"name": "mirror", "id": 1564, "trainId": 25},
34
+ {"name": "house", "id": 1276, "trainId": 26},
35
+ {"name": "rug, carpet, carpeting", "id": 2178, "trainId": 27},
36
+ {"name": "shelf", "id": 2329, "trainId": 28},
37
+ {"name": "armchair", "id": 57, "trainId": 29},
38
+ {"name": "fence, fencing", "id": 907, "trainId": 30},
39
+ {"name": "field", "id": 913, "trainId": 31},
40
+ {"name": "lamp", "id": 1395, "trainId": 32},
41
+ {"name": "rock, stone", "id": 2138, "trainId": 33},
42
+ {"name": "seat", "id": 2272, "trainId": 34},
43
+ {"name": "river", "id": 2128, "trainId": 35},
44
+ {"name": "desk", "id": 724, "trainId": 36},
45
+ {"name": "bathtub, bathing tub, bath, tub", "id": 155, "trainId": 37},
46
+ {"name": "railing, rail", "id": 2053, "trainId": 38},
47
+ {"name": "signboard, sign", "id": 2380, "trainId": 39},
48
+ {"name": "cushion", "id": 689, "trainId": 40},
49
+ {"name": "path", "id": 1788, "trainId": 41},
50
+ {"name": "work surface", "id": 3087, "trainId": 42},
51
+ {"name": "stairs, steps", "id": 2530, "trainId": 43},
52
+ {"name": "column, pillar", "id": 581, "trainId": 44},
53
+ {"name": "sink", "id": 2388, "trainId": 45},
54
+ {"name": "wardrobe, closet, press", "id": 2985, "trainId": 46},
55
+ {"name": "snow", "id": 2454, "trainId": 47},
56
+ {"name": "refrigerator, icebox", "id": 2096, "trainId": 48},
57
+ {"name": "base, pedestal, stand", "id": 137, "trainId": 49},
58
+ {"name": "bridge, span", "id": 294, "trainId": 50},
59
+ {"name": "blind, screen", "id": 212, "trainId": 51},
60
+ {"name": "runway", "id": 2185, "trainId": 52},
61
+ {"name": "cliff, drop, drop-off", "id": 524, "trainId": 53},
62
+ {"name": "sand", "id": 2212, "trainId": 54},
63
+ {"name": "fireplace, hearth, open fireplace", "id": 943, "trainId": 55},
64
+ {"name": "pillow", "id": 1869, "trainId": 56},
65
+ {"name": "screen door, screen", "id": 2251, "trainId": 57},
66
+ {"name": "toilet, can, commode, crapper, pot, potty, stool, throne", "id": 2793, "trainId": 58},
67
+ {"name": "skyscraper", "id": 2423, "trainId": 59},
68
+ {"name": "grandstand, covered stand", "id": 1121, "trainId": 60},
69
+ {"name": "box", "id": 266, "trainId": 61},
70
+ {"name": "pool table, billiard table, snooker table", "id": 1948, "trainId": 62},
71
+ {"name": "palm, palm tree", "id": 1744, "trainId": 63},
72
+ {"name": "double door", "id": 783, "trainId": 64},
73
+ {"name": "coffee table, cocktail table", "id": 571, "trainId": 65},
74
+ {"name": "counter", "id": 627, "trainId": 66},
75
+ {"name": "countertop", "id": 629, "trainId": 67},
76
+ {"name": "chest of drawers, chest, bureau, dresser", "id": 491, "trainId": 68},
77
+ {"name": "kitchen island", "id": 1374, "trainId": 69},
78
+ {"name": "boat", "id": 223, "trainId": 70},
79
+ {"name": "waterfall, falls", "id": 3016, "trainId": 71},
80
+ {
81
+ "name": "stove, kitchen stove, range, kitchen range, cooking stove",
82
+ "id": 2598,
83
+ "trainId": 72,
84
+ },
85
+ {"name": "flower", "id": 978, "trainId": 73},
86
+ {"name": "bookcase", "id": 239, "trainId": 74},
87
+ {"name": "controls", "id": 608, "trainId": 75},
88
+ {"name": "book", "id": 236, "trainId": 76},
89
+ {"name": "stairway, staircase", "id": 2531, "trainId": 77},
90
+ {"name": "streetlight, street lamp", "id": 2616, "trainId": 78},
91
+ {
92
+ "name": "computer, computing machine, computing device, data processor, electronic computer, information processing system",
93
+ "id": 591,
94
+ "trainId": 79,
95
+ },
96
+ {
97
+ "name": "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle",
98
+ "id": 327,
99
+ "trainId": 80,
100
+ },
101
+ {"name": "swivel chair", "id": 2679, "trainId": 81},
102
+ {"name": "light, light source", "id": 1451, "trainId": 82},
103
+ {"name": "bench", "id": 181, "trainId": 83},
104
+ {"name": "case, display case, showcase, vitrine", "id": 420, "trainId": 84},
105
+ {"name": "towel", "id": 2821, "trainId": 85},
106
+ {"name": "fountain", "id": 1023, "trainId": 86},
107
+ {"name": "embankment", "id": 855, "trainId": 87},
108
+ {
109
+ "name": "television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box",
110
+ "id": 2733,
111
+ "trainId": 88,
112
+ },
113
+ {"name": "van", "id": 2928, "trainId": 89},
114
+ {"name": "hill", "id": 1240, "trainId": 90},
115
+ {"name": "awning, sunshade, sunblind", "id": 77, "trainId": 91},
116
+ {"name": "poster, posting, placard, notice, bill, card", "id": 1969, "trainId": 92},
117
+ {"name": "truck, motortruck", "id": 2880, "trainId": 93},
118
+ {"name": "airplane, aeroplane, plane", "id": 14, "trainId": 94},
119
+ {"name": "pole", "id": 1936, "trainId": 95},
120
+ {"name": "tower", "id": 2828, "trainId": 96},
121
+ {"name": "court", "id": 631, "trainId": 97},
122
+ {"name": "ball", "id": 103, "trainId": 98},
123
+ {
124
+ "name": "aircraft carrier, carrier, flattop, attack aircraft carrier",
125
+ "id": 3144,
126
+ "trainId": 99,
127
+ },
128
+ {"name": "buffet, counter, sideboard", "id": 308, "trainId": 100},
129
+ {"name": "hovel, hut, hutch, shack, shanty", "id": 1282, "trainId": 101},
130
+ {"name": "apparel, wearing apparel, dress, clothes", "id": 38, "trainId": 102},
131
+ {"name": "minibike, motorbike", "id": 1563, "trainId": 103},
132
+ {"name": "animal, animate being, beast, brute, creature, fauna", "id": 29, "trainId": 104},
133
+ {"name": "chandelier, pendant, pendent", "id": 480, "trainId": 105},
134
+ {"name": "step, stair", "id": 2569, "trainId": 106},
135
+ {"name": "booth, cubicle, stall, kiosk", "id": 247, "trainId": 107},
136
+ {"name": "bicycle, bike, wheel, cycle", "id": 187, "trainId": 108},
137
+ {"name": "doorframe, doorcase", "id": 778, "trainId": 109},
138
+ {"name": "sconce", "id": 2243, "trainId": 110},
139
+ {"name": "pond", "id": 1941, "trainId": 111},
140
+ {"name": "trade name, brand name, brand, marque", "id": 2833, "trainId": 112},
141
+ {"name": "bannister, banister, balustrade, balusters, handrail", "id": 120, "trainId": 113},
142
+ {"name": "bag", "id": 95, "trainId": 114},
143
+ {"name": "traffic light, traffic signal, stoplight", "id": 2836, "trainId": 115},
144
+ {"name": "gazebo", "id": 1087, "trainId": 116},
145
+ {"name": "escalator, moving staircase, moving stairway", "id": 868, "trainId": 117},
146
+ {"name": "land, ground, soil", "id": 1401, "trainId": 118},
147
+ {"name": "board, plank", "id": 220, "trainId": 119},
148
+ {"name": "arcade machine", "id": 47, "trainId": 120},
149
+ {"name": "eiderdown, duvet, continental quilt", "id": 843, "trainId": 121},
150
+ {"name": "bar", "id": 123, "trainId": 122},
151
+ {"name": "stall, stand, sales booth", "id": 2537, "trainId": 123},
152
+ {"name": "playground", "id": 1927, "trainId": 124},
153
+ {"name": "ship", "id": 2337, "trainId": 125},
154
+ {"name": "ottoman, pouf, pouffe, puff, hassock", "id": 1702, "trainId": 126},
155
+ {
156
+ "name": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
157
+ "id": 64,
158
+ "trainId": 127,
159
+ },
160
+ {"name": "bottle", "id": 249, "trainId": 128},
161
+ {"name": "cradle", "id": 642, "trainId": 129},
162
+ {"name": "pot, flowerpot", "id": 1981, "trainId": 130},
163
+ {
164
+ "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
165
+ "id": 609,
166
+ "trainId": 131,
167
+ },
168
+ {"name": "train, railroad train", "id": 2840, "trainId": 132},
169
+ {"name": "stool", "id": 2586, "trainId": 133},
170
+ {"name": "lake", "id": 1393, "trainId": 134},
171
+ {"name": "tank, storage tank", "id": 2704, "trainId": 135},
172
+ {"name": "ice, water ice", "id": 1304, "trainId": 136},
173
+ {"name": "basket, handbasket", "id": 146, "trainId": 137},
174
+ {"name": "manhole", "id": 1494, "trainId": 138},
175
+ {"name": "tent, collapsible shelter", "id": 2739, "trainId": 139},
176
+ {"name": "canopy", "id": 389, "trainId": 140},
177
+ {"name": "microwave, microwave oven", "id": 1551, "trainId": 141},
178
+ {"name": "barrel, cask", "id": 131, "trainId": 142},
179
+ {"name": "dirt track", "id": 738, "trainId": 143},
180
+ {"name": "beam", "id": 161, "trainId": 144},
181
+ {"name": "dishwasher, dish washer, dishwashing machine", "id": 747, "trainId": 145},
182
+ {"name": "plate", "id": 1919, "trainId": 146},
183
+ {"name": "screen, crt screen", "id": 3109, "trainId": 147},
184
+ {"name": "ruins", "id": 2179, "trainId": 148},
185
+ {"name": "washer, automatic washer, washing machine", "id": 2989, "trainId": 149},
186
+ {"name": "blanket, cover", "id": 206, "trainId": 150},
187
+ {"name": "plaything, toy", "id": 1930, "trainId": 151},
188
+ {"name": "food, solid food", "id": 1002, "trainId": 152},
189
+ {"name": "screen, silver screen, projection screen", "id": 2254, "trainId": 153},
190
+ {"name": "oven", "id": 1708, "trainId": 154},
191
+ {"name": "stage", "id": 2526, "trainId": 155},
192
+ {"name": "beacon, lighthouse, beacon light, pharos", "id": 160, "trainId": 156},
193
+ {"name": "umbrella", "id": 2901, "trainId": 157},
194
+ {"name": "sculpture", "id": 2262, "trainId": 158},
195
+ {"name": "aqueduct", "id": 44, "trainId": 159},
196
+ {"name": "container", "id": 597, "trainId": 160},
197
+ {"name": "scaffolding, staging", "id": 2235, "trainId": 161},
198
+ {"name": "hood, exhaust hood", "id": 1260, "trainId": 162},
199
+ {"name": "curb, curbing, kerb", "id": 682, "trainId": 163},
200
+ {"name": "roller coaster", "id": 2151, "trainId": 164},
201
+ {"name": "horse, equus caballus", "id": 3107, "trainId": 165},
202
+ {"name": "catwalk", "id": 432, "trainId": 166},
203
+ {"name": "glass, drinking glass", "id": 1098, "trainId": 167},
204
+ {"name": "vase", "id": 2932, "trainId": 168},
205
+ {"name": "central reservation", "id": 461, "trainId": 169},
206
+ {"name": "carousel", "id": 410, "trainId": 170},
207
+ {"name": "radiator", "id": 2046, "trainId": 171},
208
+ {"name": "closet", "id": 533, "trainId": 172},
209
+ {"name": "machine", "id": 1481, "trainId": 173},
210
+ {"name": "pier, wharf, wharfage, dock", "id": 1858, "trainId": 174},
211
+ {"name": "fan", "id": 894, "trainId": 175},
212
+ {"name": "inflatable bounce game", "id": 1322, "trainId": 176},
213
+ {"name": "pitch", "id": 1891, "trainId": 177},
214
+ {"name": "paper", "id": 1756, "trainId": 178},
215
+ {"name": "arcade, colonnade", "id": 49, "trainId": 179},
216
+ {"name": "hot tub", "id": 1272, "trainId": 180},
217
+ {"name": "helicopter", "id": 1229, "trainId": 181},
218
+ {"name": "tray", "id": 2850, "trainId": 182},
219
+ {"name": "partition, divider", "id": 1784, "trainId": 183},
220
+ {"name": "vineyard", "id": 2962, "trainId": 184},
221
+ {"name": "bowl", "id": 259, "trainId": 185},
222
+ {"name": "bullring", "id": 319, "trainId": 186},
223
+ {"name": "flag", "id": 954, "trainId": 187},
224
+ {"name": "pot", "id": 1974, "trainId": 188},
225
+ {"name": "footbridge, overcrossing, pedestrian bridge", "id": 1013, "trainId": 189},
226
+ {"name": "shower", "id": 2356, "trainId": 190},
227
+ {"name": "bag, traveling bag, travelling bag, grip, suitcase", "id": 97, "trainId": 191},
228
+ {"name": "bulletin board, notice board", "id": 318, "trainId": 192},
229
+ {"name": "confessional booth", "id": 592, "trainId": 193},
230
+ {"name": "trunk, tree trunk, bole", "id": 2885, "trainId": 194},
231
+ {"name": "forest", "id": 1017, "trainId": 195},
232
+ {"name": "elevator door", "id": 851, "trainId": 196},
233
+ {"name": "laptop, laptop computer", "id": 1407, "trainId": 197},
234
+ {"name": "instrument panel", "id": 1332, "trainId": 198},
235
+ {"name": "bucket, pail", "id": 303, "trainId": 199},
236
+ {"name": "tapestry, tapis", "id": 2714, "trainId": 200},
237
+ {"name": "platform", "id": 1924, "trainId": 201},
238
+ {"name": "jacket", "id": 1346, "trainId": 202},
239
+ {"name": "gate", "id": 1081, "trainId": 203},
240
+ {"name": "monitor, monitoring device", "id": 1583, "trainId": 204},
241
+ {
242
+ "name": "telephone booth, phone booth, call box, telephone box, telephone kiosk",
243
+ "id": 2727,
244
+ "trainId": 205,
245
+ },
246
+ {"name": "spotlight, spot", "id": 2509, "trainId": 206},
247
+ {"name": "ring", "id": 2123, "trainId": 207},
248
+ {"name": "control panel", "id": 602, "trainId": 208},
249
+ {"name": "blackboard, chalkboard", "id": 202, "trainId": 209},
250
+ {"name": "air conditioner, air conditioning", "id": 10, "trainId": 210},
251
+ {"name": "chest", "id": 490, "trainId": 211},
252
+ {"name": "clock", "id": 530, "trainId": 212},
253
+ {"name": "sand dune", "id": 2213, "trainId": 213},
254
+ {"name": "pipe, pipage, piping", "id": 1884, "trainId": 214},
255
+ {"name": "vault", "id": 2934, "trainId": 215},
256
+ {"name": "table football", "id": 2687, "trainId": 216},
257
+ {"name": "cannon", "id": 387, "trainId": 217},
258
+ {"name": "swimming pool, swimming bath, natatorium", "id": 2668, "trainId": 218},
259
+ {"name": "fluorescent, fluorescent fixture", "id": 982, "trainId": 219},
260
+ {"name": "statue", "id": 2547, "trainId": 220},
261
+ {
262
+ "name": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
263
+ "id": 1474,
264
+ "trainId": 221,
265
+ },
266
+ {"name": "exhibitor", "id": 877, "trainId": 222},
267
+ {"name": "ladder", "id": 1391, "trainId": 223},
268
+ {"name": "carport", "id": 414, "trainId": 224},
269
+ {"name": "dam", "id": 698, "trainId": 225},
270
+ {"name": "pulpit", "id": 2019, "trainId": 226},
271
+ {"name": "skylight, fanlight", "id": 2422, "trainId": 227},
272
+ {"name": "water tower", "id": 3010, "trainId": 228},
273
+ {"name": "grill, grille, grillwork", "id": 1139, "trainId": 229},
274
+ {"name": "display board", "id": 753, "trainId": 230},
275
+ {"name": "pane, pane of glass, window glass", "id": 1747, "trainId": 231},
276
+ {"name": "rubbish, trash, scrap", "id": 2175, "trainId": 232},
277
+ {"name": "ice rink", "id": 1301, "trainId": 233},
278
+ {"name": "fruit", "id": 1033, "trainId": 234},
279
+ {"name": "patio", "id": 1789, "trainId": 235},
280
+ {"name": "vending machine", "id": 2939, "trainId": 236},
281
+ {"name": "telephone, phone, telephone set", "id": 2730, "trainId": 237},
282
+ {"name": "net", "id": 1652, "trainId": 238},
283
+ {
284
+ "name": "backpack, back pack, knapsack, packsack, rucksack, haversack",
285
+ "id": 90,
286
+ "trainId": 239,
287
+ },
288
+ {"name": "jar", "id": 1349, "trainId": 240},
289
+ {"name": "track", "id": 2830, "trainId": 241},
290
+ {"name": "magazine", "id": 1485, "trainId": 242},
291
+ {"name": "shutter", "id": 2370, "trainId": 243},
292
+ {"name": "roof", "id": 2155, "trainId": 244},
293
+ {"name": "banner, streamer", "id": 118, "trainId": 245},
294
+ {"name": "landfill", "id": 1402, "trainId": 246},
295
+ {"name": "post", "id": 1957, "trainId": 247},
296
+ {"name": "altarpiece, reredos", "id": 3130, "trainId": 248},
297
+ {"name": "hat, chapeau, lid", "id": 1197, "trainId": 249},
298
+ {"name": "arch, archway", "id": 52, "trainId": 250},
299
+ {"name": "table game", "id": 2688, "trainId": 251},
300
+ {"name": "bag, handbag, pocketbook, purse", "id": 96, "trainId": 252},
301
+ {"name": "document, written document, papers", "id": 762, "trainId": 253},
302
+ {"name": "dome", "id": 772, "trainId": 254},
303
+ {"name": "pier", "id": 1857, "trainId": 255},
304
+ {"name": "shanties", "id": 2315, "trainId": 256},
305
+ {"name": "forecourt", "id": 1016, "trainId": 257},
306
+ {"name": "crane", "id": 643, "trainId": 258},
307
+ {"name": "dog, domestic dog, canis familiaris", "id": 3105, "trainId": 259},
308
+ {"name": "piano, pianoforte, forte-piano", "id": 1849, "trainId": 260},
309
+ {"name": "drawing", "id": 791, "trainId": 261},
310
+ {"name": "cabin", "id": 349, "trainId": 262},
311
+ {
312
+ "name": "ad, advertisement, advertizement, advertising, advertizing, advert",
313
+ "id": 6,
314
+ "trainId": 263,
315
+ },
316
+ {"name": "amphitheater, amphitheatre, coliseum", "id": 3114, "trainId": 264},
317
+ {"name": "monument", "id": 1587, "trainId": 265},
318
+ {"name": "henhouse", "id": 1233, "trainId": 266},
319
+ {"name": "cockpit", "id": 559, "trainId": 267},
320
+ {"name": "heater, warmer", "id": 1223, "trainId": 268},
321
+ {"name": "windmill, aerogenerator, wind generator", "id": 3049, "trainId": 269},
322
+ {"name": "pool", "id": 1943, "trainId": 270},
323
+ {"name": "elevator, lift", "id": 853, "trainId": 271},
324
+ {"name": "decoration, ornament, ornamentation", "id": 709, "trainId": 272},
325
+ {"name": "labyrinth", "id": 1390, "trainId": 273},
326
+ {"name": "text, textual matter", "id": 2748, "trainId": 274},
327
+ {"name": "printer", "id": 2007, "trainId": 275},
328
+ {"name": "mezzanine, first balcony", "id": 1546, "trainId": 276},
329
+ {"name": "mattress", "id": 1513, "trainId": 277},
330
+ {"name": "straw", "id": 2600, "trainId": 278},
331
+ {"name": "stalls", "id": 2538, "trainId": 279},
332
+ {"name": "patio, terrace", "id": 1790, "trainId": 280},
333
+ {"name": "billboard, hoarding", "id": 194, "trainId": 281},
334
+ {"name": "bus stop", "id": 326, "trainId": 282},
335
+ {"name": "trouser, pant", "id": 2877, "trainId": 283},
336
+ {"name": "console table, console", "id": 594, "trainId": 284},
337
+ {"name": "rack", "id": 2036, "trainId": 285},
338
+ {"name": "notebook", "id": 1662, "trainId": 286},
339
+ {"name": "shrine", "id": 2366, "trainId": 287},
340
+ {"name": "pantry", "id": 1754, "trainId": 288},
341
+ {"name": "cart", "id": 418, "trainId": 289},
342
+ {"name": "steam shovel", "id": 2553, "trainId": 290},
343
+ {"name": "porch", "id": 1951, "trainId": 291},
344
+ {"name": "postbox, mailbox, letter box", "id": 1963, "trainId": 292},
345
+ {"name": "figurine, statuette", "id": 918, "trainId": 293},
346
+ {"name": "recycling bin", "id": 2086, "trainId": 294},
347
+ {"name": "folding screen", "id": 997, "trainId": 295},
348
+ {"name": "telescope", "id": 2731, "trainId": 296},
349
+ {"name": "deck chair, beach chair", "id": 704, "trainId": 297},
350
+ {"name": "kennel", "id": 1365, "trainId": 298},
351
+ {"name": "coffee maker", "id": 569, "trainId": 299},
352
+ {"name": "altar, communion table, lord's table", "id": 3108, "trainId": 300},
353
+ {"name": "fish", "id": 948, "trainId": 301},
354
+ {"name": "easel", "id": 839, "trainId": 302},
355
+ {"name": "artificial golf green", "id": 63, "trainId": 303},
356
+ {"name": "iceberg", "id": 1305, "trainId": 304},
357
+ {"name": "candlestick, candle holder", "id": 378, "trainId": 305},
358
+ {"name": "shower stall, shower bath", "id": 2362, "trainId": 306},
359
+ {"name": "television stand", "id": 2734, "trainId": 307},
360
+ {
361
+ "name": "wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle",
362
+ "id": 2982,
363
+ "trainId": 308,
364
+ },
365
+ {"name": "skeleton", "id": 2398, "trainId": 309},
366
+ {"name": "grand piano, grand", "id": 1119, "trainId": 310},
367
+ {"name": "candy, confect", "id": 382, "trainId": 311},
368
+ {"name": "grille door", "id": 1141, "trainId": 312},
369
+ {"name": "pedestal, plinth, footstall", "id": 1805, "trainId": 313},
370
+ {"name": "jersey, t-shirt, tee shirt", "id": 3102, "trainId": 314},
371
+ {"name": "shoe", "id": 2341, "trainId": 315},
372
+ {"name": "gravestone, headstone, tombstone", "id": 1131, "trainId": 316},
373
+ {"name": "shanty", "id": 2316, "trainId": 317},
374
+ {"name": "structure", "id": 2626, "trainId": 318},
375
+ {"name": "rocking chair, rocker", "id": 3104, "trainId": 319},
376
+ {"name": "bird", "id": 198, "trainId": 320},
377
+ {"name": "place mat", "id": 1896, "trainId": 321},
378
+ {"name": "tomb", "id": 2800, "trainId": 322},
379
+ {"name": "big top", "id": 190, "trainId": 323},
380
+ {"name": "gas pump, gasoline pump, petrol pump, island dispenser", "id": 3131, "trainId": 324},
381
+ {"name": "lockers", "id": 1463, "trainId": 325},
382
+ {"name": "cage", "id": 357, "trainId": 326},
383
+ {"name": "finger", "id": 929, "trainId": 327},
384
+ {"name": "bleachers", "id": 209, "trainId": 328},
385
+ {"name": "ferris wheel", "id": 912, "trainId": 329},
386
+ {"name": "hairdresser chair", "id": 1164, "trainId": 330},
387
+ {"name": "mat", "id": 1509, "trainId": 331},
388
+ {"name": "stands", "id": 2539, "trainId": 332},
389
+ {"name": "aquarium, fish tank, marine museum", "id": 3116, "trainId": 333},
390
+ {"name": "streetcar, tram, tramcar, trolley, trolley car", "id": 2615, "trainId": 334},
391
+ {"name": "napkin, table napkin, serviette", "id": 1644, "trainId": 335},
392
+ {"name": "dummy", "id": 818, "trainId": 336},
393
+ {"name": "booklet, brochure, folder, leaflet, pamphlet", "id": 242, "trainId": 337},
394
+ {"name": "sand trap", "id": 2217, "trainId": 338},
395
+ {"name": "shop, store", "id": 2347, "trainId": 339},
396
+ {"name": "table cloth", "id": 2686, "trainId": 340},
397
+ {"name": "service station", "id": 2300, "trainId": 341},
398
+ {"name": "coffin", "id": 572, "trainId": 342},
399
+ {"name": "drawer", "id": 789, "trainId": 343},
400
+ {"name": "cages", "id": 358, "trainId": 344},
401
+ {"name": "slot machine, coin machine", "id": 2443, "trainId": 345},
402
+ {"name": "balcony", "id": 101, "trainId": 346},
403
+ {"name": "volleyball court", "id": 2969, "trainId": 347},
404
+ {"name": "table tennis", "id": 2692, "trainId": 348},
405
+ {"name": "control table", "id": 606, "trainId": 349},
406
+ {"name": "shirt", "id": 2339, "trainId": 350},
407
+ {"name": "merchandise, ware, product", "id": 1533, "trainId": 351},
408
+ {"name": "railway", "id": 2060, "trainId": 352},
409
+ {"name": "parterre", "id": 1782, "trainId": 353},
410
+ {"name": "chimney", "id": 495, "trainId": 354},
411
+ {"name": "can, tin, tin can", "id": 371, "trainId": 355},
412
+ {"name": "tanks", "id": 2707, "trainId": 356},
413
+ {"name": "fabric, cloth, material, textile", "id": 889, "trainId": 357},
414
+ {"name": "alga, algae", "id": 3156, "trainId": 358},
415
+ {"name": "system", "id": 2683, "trainId": 359},
416
+ {"name": "map", "id": 1499, "trainId": 360},
417
+ {"name": "greenhouse", "id": 1135, "trainId": 361},
418
+ {"name": "mug", "id": 1619, "trainId": 362},
419
+ {"name": "barbecue", "id": 125, "trainId": 363},
420
+ {"name": "trailer", "id": 2838, "trainId": 364},
421
+ {"name": "toilet tissue, toilet paper, bathroom tissue", "id": 2792, "trainId": 365},
422
+ {"name": "organ", "id": 1695, "trainId": 366},
423
+ {"name": "dishrag, dishcloth", "id": 746, "trainId": 367},
424
+ {"name": "island", "id": 1343, "trainId": 368},
425
+ {"name": "keyboard", "id": 1370, "trainId": 369},
426
+ {"name": "trench", "id": 2858, "trainId": 370},
427
+ {"name": "basket, basketball hoop, hoop", "id": 145, "trainId": 371},
428
+ {"name": "steering wheel, wheel", "id": 2565, "trainId": 372},
429
+ {"name": "pitcher, ewer", "id": 1892, "trainId": 373},
430
+ {"name": "goal", "id": 1103, "trainId": 374},
431
+ {"name": "bread, breadstuff, staff of life", "id": 286, "trainId": 375},
432
+ {"name": "beds", "id": 170, "trainId": 376},
433
+ {"name": "wood", "id": 3073, "trainId": 377},
434
+ {"name": "file cabinet", "id": 922, "trainId": 378},
435
+ {"name": "newspaper, paper", "id": 1655, "trainId": 379},
436
+ {"name": "motorboat", "id": 1602, "trainId": 380},
437
+ {"name": "rope", "id": 2160, "trainId": 381},
438
+ {"name": "guitar", "id": 1151, "trainId": 382},
439
+ {"name": "rubble", "id": 2176, "trainId": 383},
440
+ {"name": "scarf", "id": 2239, "trainId": 384},
441
+ {"name": "barrels", "id": 132, "trainId": 385},
442
+ {"name": "cap", "id": 394, "trainId": 386},
443
+ {"name": "leaves", "id": 1424, "trainId": 387},
444
+ {"name": "control tower", "id": 607, "trainId": 388},
445
+ {"name": "dashboard", "id": 700, "trainId": 389},
446
+ {"name": "bandstand", "id": 116, "trainId": 390},
447
+ {"name": "lectern", "id": 1425, "trainId": 391},
448
+ {"name": "switch, electric switch, electrical switch", "id": 2676, "trainId": 392},
449
+ {"name": "baseboard, mopboard, skirting board", "id": 141, "trainId": 393},
450
+ {"name": "shower room", "id": 2360, "trainId": 394},
451
+ {"name": "smoke", "id": 2449, "trainId": 395},
452
+ {"name": "faucet, spigot", "id": 897, "trainId": 396},
453
+ {"name": "bulldozer", "id": 317, "trainId": 397},
454
+ {"name": "saucepan", "id": 2228, "trainId": 398},
455
+ {"name": "shops", "id": 2351, "trainId": 399},
456
+ {"name": "meter", "id": 1543, "trainId": 400},
457
+ {"name": "crevasse", "id": 656, "trainId": 401},
458
+ {"name": "gear", "id": 1088, "trainId": 402},
459
+ {"name": "candelabrum, candelabra", "id": 373, "trainId": 403},
460
+ {"name": "sofa bed", "id": 2472, "trainId": 404},
461
+ {"name": "tunnel", "id": 2892, "trainId": 405},
462
+ {"name": "pallet", "id": 1740, "trainId": 406},
463
+ {"name": "wire, conducting wire", "id": 3067, "trainId": 407},
464
+ {"name": "kettle, boiler", "id": 1367, "trainId": 408},
465
+ {"name": "bidet", "id": 188, "trainId": 409},
466
+ {
467
+ "name": "baby buggy, baby carriage, carriage, perambulator, pram, stroller, go-cart, pushchair, pusher",
468
+ "id": 79,
469
+ "trainId": 410,
470
+ },
471
+ {"name": "music stand", "id": 1633, "trainId": 411},
472
+ {"name": "pipe, tube", "id": 1885, "trainId": 412},
473
+ {"name": "cup", "id": 677, "trainId": 413},
474
+ {"name": "parking meter", "id": 1779, "trainId": 414},
475
+ {"name": "ice hockey rink", "id": 1297, "trainId": 415},
476
+ {"name": "shelter", "id": 2334, "trainId": 416},
477
+ {"name": "weeds", "id": 3027, "trainId": 417},
478
+ {"name": "temple", "id": 2735, "trainId": 418},
479
+ {"name": "patty, cake", "id": 1791, "trainId": 419},
480
+ {"name": "ski slope", "id": 2405, "trainId": 420},
481
+ {"name": "panel", "id": 1748, "trainId": 421},
482
+ {"name": "wallet", "id": 2983, "trainId": 422},
483
+ {"name": "wheel", "id": 3035, "trainId": 423},
484
+ {"name": "towel rack, towel horse", "id": 2824, "trainId": 424},
485
+ {"name": "roundabout", "id": 2168, "trainId": 425},
486
+ {"name": "canister, cannister, tin", "id": 385, "trainId": 426},
487
+ {"name": "rod", "id": 2148, "trainId": 427},
488
+ {"name": "soap dispenser", "id": 2465, "trainId": 428},
489
+ {"name": "bell", "id": 175, "trainId": 429},
490
+ {"name": "canvas", "id": 390, "trainId": 430},
491
+ {"name": "box office, ticket office, ticket booth", "id": 268, "trainId": 431},
492
+ {"name": "teacup", "id": 2722, "trainId": 432},
493
+ {"name": "trellis", "id": 2857, "trainId": 433},
494
+ {"name": "workbench", "id": 3088, "trainId": 434},
495
+ {"name": "valley, vale", "id": 2926, "trainId": 435},
496
+ {"name": "toaster", "id": 2782, "trainId": 436},
497
+ {"name": "knife", "id": 1378, "trainId": 437},
498
+ {"name": "podium", "id": 1934, "trainId": 438},
499
+ {"name": "ramp", "id": 2072, "trainId": 439},
500
+ {"name": "tumble dryer", "id": 2889, "trainId": 440},
501
+ {"name": "fireplug, fire hydrant, plug", "id": 944, "trainId": 441},
502
+ {"name": "gym shoe, sneaker, tennis shoe", "id": 1158, "trainId": 442},
503
+ {"name": "lab bench", "id": 1383, "trainId": 443},
504
+ {"name": "equipment", "id": 867, "trainId": 444},
505
+ {"name": "rocky formation", "id": 2145, "trainId": 445},
506
+ {"name": "plastic", "id": 1915, "trainId": 446},
507
+ {"name": "calendar", "id": 361, "trainId": 447},
508
+ {"name": "caravan", "id": 402, "trainId": 448},
509
+ {"name": "check-in-desk", "id": 482, "trainId": 449},
510
+ {"name": "ticket counter", "id": 2761, "trainId": 450},
511
+ {"name": "brush", "id": 300, "trainId": 451},
512
+ {"name": "mill", "id": 1554, "trainId": 452},
513
+ {"name": "covered bridge", "id": 636, "trainId": 453},
514
+ {"name": "bowling alley", "id": 260, "trainId": 454},
515
+ {"name": "hanger", "id": 1186, "trainId": 455},
516
+ {"name": "excavator", "id": 871, "trainId": 456},
517
+ {"name": "trestle", "id": 2859, "trainId": 457},
518
+ {"name": "revolving door", "id": 2103, "trainId": 458},
519
+ {"name": "blast furnace", "id": 208, "trainId": 459},
520
+ {"name": "scale, weighing machine", "id": 2236, "trainId": 460},
521
+ {"name": "projector", "id": 2012, "trainId": 461},
522
+ {"name": "soap", "id": 2462, "trainId": 462},
523
+ {"name": "locker", "id": 1462, "trainId": 463},
524
+ {"name": "tractor", "id": 2832, "trainId": 464},
525
+ {"name": "stretcher", "id": 2617, "trainId": 465},
526
+ {"name": "frame", "id": 1024, "trainId": 466},
527
+ {"name": "grating", "id": 1129, "trainId": 467},
528
+ {"name": "alembic", "id": 18, "trainId": 468},
529
+ {"name": "candle, taper, wax light", "id": 376, "trainId": 469},
530
+ {"name": "barrier", "id": 134, "trainId": 470},
531
+ {"name": "cardboard", "id": 407, "trainId": 471},
532
+ {"name": "cave", "id": 434, "trainId": 472},
533
+ {"name": "puddle", "id": 2017, "trainId": 473},
534
+ {"name": "tarp", "id": 2717, "trainId": 474},
535
+ {"name": "price tag", "id": 2005, "trainId": 475},
536
+ {"name": "watchtower", "id": 2993, "trainId": 476},
537
+ {"name": "meters", "id": 1545, "trainId": 477},
538
+ {
539
+ "name": "light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb",
540
+ "id": 1445,
541
+ "trainId": 478,
542
+ },
543
+ {"name": "tracks", "id": 2831, "trainId": 479},
544
+ {"name": "hair dryer", "id": 1161, "trainId": 480},
545
+ {"name": "skirt", "id": 2411, "trainId": 481},
546
+ {"name": "viaduct", "id": 2949, "trainId": 482},
547
+ {"name": "paper towel", "id": 1769, "trainId": 483},
548
+ {"name": "coat", "id": 552, "trainId": 484},
549
+ {"name": "sheet", "id": 2327, "trainId": 485},
550
+ {"name": "fire extinguisher, extinguisher, asphyxiator", "id": 939, "trainId": 486},
551
+ {"name": "water wheel", "id": 3013, "trainId": 487},
552
+ {"name": "pottery, clayware", "id": 1986, "trainId": 488},
553
+ {"name": "magazine rack", "id": 1486, "trainId": 489},
554
+ {"name": "teapot", "id": 2723, "trainId": 490},
555
+ {"name": "microphone, mike", "id": 1549, "trainId": 491},
556
+ {"name": "support", "id": 2649, "trainId": 492},
557
+ {"name": "forklift", "id": 1020, "trainId": 493},
558
+ {"name": "canyon", "id": 392, "trainId": 494},
559
+ {"name": "cash register, register", "id": 422, "trainId": 495},
560
+ {"name": "leaf, leafage, foliage", "id": 1419, "trainId": 496},
561
+ {"name": "remote control, remote", "id": 2099, "trainId": 497},
562
+ {"name": "soap dish", "id": 2464, "trainId": 498},
563
+ {"name": "windshield, windscreen", "id": 3058, "trainId": 499},
564
+ {"name": "cat", "id": 430, "trainId": 500},
565
+ {"name": "cue, cue stick, pool cue, pool stick", "id": 675, "trainId": 501},
566
+ {"name": "vent, venthole, vent-hole, blowhole", "id": 2941, "trainId": 502},
567
+ {"name": "videos", "id": 2955, "trainId": 503},
568
+ {"name": "shovel", "id": 2355, "trainId": 504},
569
+ {"name": "eaves", "id": 840, "trainId": 505},
570
+ {"name": "antenna, aerial, transmitting aerial", "id": 32, "trainId": 506},
571
+ {"name": "shipyard", "id": 2338, "trainId": 507},
572
+ {"name": "hen, biddy", "id": 1232, "trainId": 508},
573
+ {"name": "traffic cone", "id": 2834, "trainId": 509},
574
+ {"name": "washing machines", "id": 2991, "trainId": 510},
575
+ {"name": "truck crane", "id": 2879, "trainId": 511},
576
+ {"name": "cds", "id": 444, "trainId": 512},
577
+ {"name": "niche", "id": 1657, "trainId": 513},
578
+ {"name": "scoreboard", "id": 2246, "trainId": 514},
579
+ {"name": "briefcase", "id": 296, "trainId": 515},
580
+ {"name": "boot", "id": 245, "trainId": 516},
581
+ {"name": "sweater, jumper", "id": 2661, "trainId": 517},
582
+ {"name": "hay", "id": 1202, "trainId": 518},
583
+ {"name": "pack", "id": 1714, "trainId": 519},
584
+ {"name": "bottle rack", "id": 251, "trainId": 520},
585
+ {"name": "glacier", "id": 1095, "trainId": 521},
586
+ {"name": "pergola", "id": 1828, "trainId": 522},
587
+ {"name": "building materials", "id": 311, "trainId": 523},
588
+ {"name": "television camera", "id": 2732, "trainId": 524},
589
+ {"name": "first floor", "id": 947, "trainId": 525},
590
+ {"name": "rifle", "id": 2115, "trainId": 526},
591
+ {"name": "tennis table", "id": 2738, "trainId": 527},
592
+ {"name": "stadium", "id": 2525, "trainId": 528},
593
+ {"name": "safety belt", "id": 2194, "trainId": 529},
594
+ {"name": "cover", "id": 634, "trainId": 530},
595
+ {"name": "dish rack", "id": 740, "trainId": 531},
596
+ {"name": "synthesizer", "id": 2682, "trainId": 532},
597
+ {"name": "pumpkin", "id": 2020, "trainId": 533},
598
+ {"name": "gutter", "id": 1156, "trainId": 534},
599
+ {"name": "fruit stand", "id": 1036, "trainId": 535},
600
+ {"name": "ice floe, floe", "id": 1295, "trainId": 536},
601
+ {"name": "handle, grip, handgrip, hold", "id": 1181, "trainId": 537},
602
+ {"name": "wheelchair", "id": 3037, "trainId": 538},
603
+ {"name": "mousepad, mouse mat", "id": 1614, "trainId": 539},
604
+ {"name": "diploma", "id": 736, "trainId": 540},
605
+ {"name": "fairground ride", "id": 893, "trainId": 541},
606
+ {"name": "radio", "id": 2047, "trainId": 542},
607
+ {"name": "hotplate", "id": 1274, "trainId": 543},
608
+ {"name": "junk", "id": 1361, "trainId": 544},
609
+ {"name": "wheelbarrow", "id": 3036, "trainId": 545},
610
+ {"name": "stream", "id": 2606, "trainId": 546},
611
+ {"name": "toll plaza", "id": 2797, "trainId": 547},
612
+ {"name": "punching bag", "id": 2022, "trainId": 548},
613
+ {"name": "trough", "id": 2876, "trainId": 549},
614
+ {"name": "throne", "id": 2758, "trainId": 550},
615
+ {"name": "chair desk", "id": 472, "trainId": 551},
616
+ {"name": "weighbridge", "id": 3028, "trainId": 552},
617
+ {"name": "extractor fan", "id": 882, "trainId": 553},
618
+ {"name": "hanging clothes", "id": 1189, "trainId": 554},
619
+ {"name": "dish, dish aerial, dish antenna, saucer", "id": 743, "trainId": 555},
620
+ {"name": "alarm clock, alarm", "id": 3122, "trainId": 556},
621
+ {"name": "ski lift", "id": 2401, "trainId": 557},
622
+ {"name": "chain", "id": 468, "trainId": 558},
623
+ {"name": "garage", "id": 1061, "trainId": 559},
624
+ {"name": "mechanical shovel", "id": 1523, "trainId": 560},
625
+ {"name": "wine rack", "id": 3059, "trainId": 561},
626
+ {"name": "tramway", "id": 2843, "trainId": 562},
627
+ {"name": "treadmill", "id": 2853, "trainId": 563},
628
+ {"name": "menu", "id": 1529, "trainId": 564},
629
+ {"name": "block", "id": 214, "trainId": 565},
630
+ {"name": "well", "id": 3032, "trainId": 566},
631
+ {"name": "witness stand", "id": 3071, "trainId": 567},
632
+ {"name": "branch", "id": 277, "trainId": 568},
633
+ {"name": "duck", "id": 813, "trainId": 569},
634
+ {"name": "casserole", "id": 426, "trainId": 570},
635
+ {"name": "frying pan", "id": 1039, "trainId": 571},
636
+ {"name": "desk organizer", "id": 727, "trainId": 572},
637
+ {"name": "mast", "id": 1508, "trainId": 573},
638
+ {"name": "spectacles, specs, eyeglasses, glasses", "id": 2490, "trainId": 574},
639
+ {"name": "service elevator", "id": 2299, "trainId": 575},
640
+ {"name": "dollhouse", "id": 768, "trainId": 576},
641
+ {"name": "hammock", "id": 1172, "trainId": 577},
642
+ {"name": "clothes hanging", "id": 537, "trainId": 578},
643
+ {"name": "photocopier", "id": 1847, "trainId": 579},
644
+ {"name": "notepad", "id": 1664, "trainId": 580},
645
+ {"name": "golf cart", "id": 1110, "trainId": 581},
646
+ {"name": "footpath", "id": 1014, "trainId": 582},
647
+ {"name": "cross", "id": 662, "trainId": 583},
648
+ {"name": "baptismal font", "id": 121, "trainId": 584},
649
+ {"name": "boiler", "id": 227, "trainId": 585},
650
+ {"name": "skip", "id": 2410, "trainId": 586},
651
+ {"name": "rotisserie", "id": 2165, "trainId": 587},
652
+ {"name": "tables", "id": 2696, "trainId": 588},
653
+ {"name": "water mill", "id": 3005, "trainId": 589},
654
+ {"name": "helmet", "id": 1231, "trainId": 590},
655
+ {"name": "cover curtain", "id": 635, "trainId": 591},
656
+ {"name": "brick", "id": 292, "trainId": 592},
657
+ {"name": "table runner", "id": 2690, "trainId": 593},
658
+ {"name": "ashtray", "id": 65, "trainId": 594},
659
+ {"name": "street box", "id": 2607, "trainId": 595},
660
+ {"name": "stick", "id": 2574, "trainId": 596},
661
+ {"name": "hangers", "id": 1188, "trainId": 597},
662
+ {"name": "cells", "id": 456, "trainId": 598},
663
+ {"name": "urinal", "id": 2913, "trainId": 599},
664
+ {"name": "centerpiece", "id": 459, "trainId": 600},
665
+ {"name": "portable fridge", "id": 1955, "trainId": 601},
666
+ {"name": "dvds", "id": 827, "trainId": 602},
667
+ {"name": "golf club", "id": 1111, "trainId": 603},
668
+ {"name": "skirting board", "id": 2412, "trainId": 604},
669
+ {"name": "water cooler", "id": 2997, "trainId": 605},
670
+ {"name": "clipboard", "id": 528, "trainId": 606},
671
+ {"name": "camera, photographic camera", "id": 366, "trainId": 607},
672
+ {"name": "pigeonhole", "id": 1863, "trainId": 608},
673
+ {"name": "chips", "id": 500, "trainId": 609},
674
+ {"name": "food processor", "id": 1001, "trainId": 610},
675
+ {"name": "post box", "id": 1958, "trainId": 611},
676
+ {"name": "lid", "id": 1441, "trainId": 612},
677
+ {"name": "drum", "id": 809, "trainId": 613},
678
+ {"name": "blender", "id": 210, "trainId": 614},
679
+ {"name": "cave entrance", "id": 435, "trainId": 615},
680
+ {"name": "dental chair", "id": 718, "trainId": 616},
681
+ {"name": "obelisk", "id": 1674, "trainId": 617},
682
+ {"name": "canoe", "id": 388, "trainId": 618},
683
+ {"name": "mobile", "id": 1572, "trainId": 619},
684
+ {"name": "monitors", "id": 1584, "trainId": 620},
685
+ {"name": "pool ball", "id": 1944, "trainId": 621},
686
+ {"name": "cue rack", "id": 674, "trainId": 622},
687
+ {"name": "baggage carts", "id": 99, "trainId": 623},
688
+ {"name": "shore", "id": 2352, "trainId": 624},
689
+ {"name": "fork", "id": 1019, "trainId": 625},
690
+ {"name": "paper filer", "id": 1763, "trainId": 626},
691
+ {"name": "bicycle rack", "id": 185, "trainId": 627},
692
+ {"name": "coat rack", "id": 554, "trainId": 628},
693
+ {"name": "garland", "id": 1066, "trainId": 629},
694
+ {"name": "sports bag", "id": 2508, "trainId": 630},
695
+ {"name": "fish tank", "id": 951, "trainId": 631},
696
+ {"name": "towel dispenser", "id": 2822, "trainId": 632},
697
+ {"name": "carriage", "id": 415, "trainId": 633},
698
+ {"name": "brochure", "id": 297, "trainId": 634},
699
+ {"name": "plaque", "id": 1914, "trainId": 635},
700
+ {"name": "stringer", "id": 2619, "trainId": 636},
701
+ {"name": "iron", "id": 1338, "trainId": 637},
702
+ {"name": "spoon", "id": 2505, "trainId": 638},
703
+ {"name": "flag pole", "id": 955, "trainId": 639},
704
+ {"name": "toilet brush", "id": 2786, "trainId": 640},
705
+ {"name": "book stand", "id": 238, "trainId": 641},
706
+ {"name": "water faucet, water tap, tap, hydrant", "id": 3000, "trainId": 642},
707
+ {"name": "ticket office", "id": 2763, "trainId": 643},
708
+ {"name": "broom", "id": 299, "trainId": 644},
709
+ {"name": "dvd", "id": 822, "trainId": 645},
710
+ {"name": "ice bucket", "id": 1288, "trainId": 646},
711
+ {"name": "carapace, shell, cuticle, shield", "id": 3101, "trainId": 647},
712
+ {"name": "tureen", "id": 2894, "trainId": 648},
713
+ {"name": "folders", "id": 992, "trainId": 649},
714
+ {"name": "chess", "id": 489, "trainId": 650},
715
+ {"name": "root", "id": 2157, "trainId": 651},
716
+ {"name": "sewing machine", "id": 2309, "trainId": 652},
717
+ {"name": "model", "id": 1576, "trainId": 653},
718
+ {"name": "pen", "id": 1810, "trainId": 654},
719
+ {"name": "violin", "id": 2964, "trainId": 655},
720
+ {"name": "sweatshirt", "id": 2662, "trainId": 656},
721
+ {"name": "recycling materials", "id": 2087, "trainId": 657},
722
+ {"name": "mitten", "id": 1569, "trainId": 658},
723
+ {"name": "chopping board, cutting board", "id": 503, "trainId": 659},
724
+ {"name": "mask", "id": 1505, "trainId": 660},
725
+ {"name": "log", "id": 1468, "trainId": 661},
726
+ {"name": "mouse, computer mouse", "id": 1613, "trainId": 662},
727
+ {"name": "grill", "id": 1138, "trainId": 663},
728
+ {"name": "hole", "id": 1256, "trainId": 664},
729
+ {"name": "target", "id": 2715, "trainId": 665},
730
+ {"name": "trash bag", "id": 2846, "trainId": 666},
731
+ {"name": "chalk", "id": 477, "trainId": 667},
732
+ {"name": "sticks", "id": 2576, "trainId": 668},
733
+ {"name": "balloon", "id": 108, "trainId": 669},
734
+ {"name": "score", "id": 2245, "trainId": 670},
735
+ {"name": "hair spray", "id": 1162, "trainId": 671},
736
+ {"name": "roll", "id": 2149, "trainId": 672},
737
+ {"name": "runner", "id": 2183, "trainId": 673},
738
+ {"name": "engine", "id": 858, "trainId": 674},
739
+ {"name": "inflatable glove", "id": 1324, "trainId": 675},
740
+ {"name": "games", "id": 1055, "trainId": 676},
741
+ {"name": "pallets", "id": 1741, "trainId": 677},
742
+ {"name": "baskets", "id": 149, "trainId": 678},
743
+ {"name": "coop", "id": 615, "trainId": 679},
744
+ {"name": "dvd player", "id": 825, "trainId": 680},
745
+ {"name": "rocking horse", "id": 2143, "trainId": 681},
746
+ {"name": "buckets", "id": 304, "trainId": 682},
747
+ {"name": "bread rolls", "id": 283, "trainId": 683},
748
+ {"name": "shawl", "id": 2322, "trainId": 684},
749
+ {"name": "watering can", "id": 3017, "trainId": 685},
750
+ {"name": "spotlights", "id": 2510, "trainId": 686},
751
+ {"name": "post-it", "id": 1960, "trainId": 687},
752
+ {"name": "bowls", "id": 265, "trainId": 688},
753
+ {"name": "security camera", "id": 2282, "trainId": 689},
754
+ {"name": "runner cloth", "id": 2184, "trainId": 690},
755
+ {"name": "lock", "id": 1461, "trainId": 691},
756
+ {"name": "alarm, warning device, alarm system", "id": 3113, "trainId": 692},
757
+ {"name": "side", "id": 2372, "trainId": 693},
758
+ {"name": "roulette", "id": 2166, "trainId": 694},
759
+ {"name": "bone", "id": 232, "trainId": 695},
760
+ {"name": "cutlery", "id": 693, "trainId": 696},
761
+ {"name": "pool balls", "id": 1945, "trainId": 697},
762
+ {"name": "wheels", "id": 3039, "trainId": 698},
763
+ {"name": "spice rack", "id": 2494, "trainId": 699},
764
+ {"name": "plant pots", "id": 1908, "trainId": 700},
765
+ {"name": "towel ring", "id": 2827, "trainId": 701},
766
+ {"name": "bread box", "id": 280, "trainId": 702},
767
+ {"name": "video", "id": 2950, "trainId": 703},
768
+ {"name": "funfair", "id": 1044, "trainId": 704},
769
+ {"name": "breads", "id": 288, "trainId": 705},
770
+ {"name": "tripod", "id": 2863, "trainId": 706},
771
+ {"name": "ironing board", "id": 1342, "trainId": 707},
772
+ {"name": "skimmer", "id": 2409, "trainId": 708},
773
+ {"name": "hollow", "id": 1258, "trainId": 709},
774
+ {"name": "scratching post", "id": 2249, "trainId": 710},
775
+ {"name": "tricycle", "id": 2862, "trainId": 711},
776
+ {"name": "file box", "id": 920, "trainId": 712},
777
+ {"name": "mountain pass", "id": 1607, "trainId": 713},
778
+ {"name": "tombstones", "id": 2802, "trainId": 714},
779
+ {"name": "cooker", "id": 610, "trainId": 715},
780
+ {"name": "card game, cards", "id": 3129, "trainId": 716},
781
+ {"name": "golf bag", "id": 1108, "trainId": 717},
782
+ {"name": "towel paper", "id": 2823, "trainId": 718},
783
+ {"name": "chaise lounge", "id": 476, "trainId": 719},
784
+ {"name": "sun", "id": 2641, "trainId": 720},
785
+ {"name": "toilet paper holder", "id": 2788, "trainId": 721},
786
+ {"name": "rake", "id": 2070, "trainId": 722},
787
+ {"name": "key", "id": 1368, "trainId": 723},
788
+ {"name": "umbrella stand", "id": 2903, "trainId": 724},
789
+ {"name": "dartboard", "id": 699, "trainId": 725},
790
+ {"name": "transformer", "id": 2844, "trainId": 726},
791
+ {"name": "fireplace utensils", "id": 942, "trainId": 727},
792
+ {"name": "sweatshirts", "id": 2663, "trainId": 728},
793
+ {
794
+ "name": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
795
+ "id": 457,
796
+ "trainId": 729,
797
+ },
798
+ {"name": "tallboy", "id": 2701, "trainId": 730},
799
+ {"name": "stapler", "id": 2540, "trainId": 731},
800
+ {"name": "sauna", "id": 2231, "trainId": 732},
801
+ {"name": "test tube", "id": 2746, "trainId": 733},
802
+ {"name": "palette", "id": 1738, "trainId": 734},
803
+ {"name": "shopping carts", "id": 2350, "trainId": 735},
804
+ {"name": "tools", "id": 2808, "trainId": 736},
805
+ {"name": "push button, push, button", "id": 2025, "trainId": 737},
806
+ {"name": "star", "id": 2541, "trainId": 738},
807
+ {"name": "roof rack", "id": 2156, "trainId": 739},
808
+ {"name": "barbed wire", "id": 126, "trainId": 740},
809
+ {"name": "spray", "id": 2512, "trainId": 741},
810
+ {"name": "ear", "id": 831, "trainId": 742},
811
+ {"name": "sponge", "id": 2503, "trainId": 743},
812
+ {"name": "racket", "id": 2039, "trainId": 744},
813
+ {"name": "tins", "id": 2774, "trainId": 745},
814
+ {"name": "eyeglasses", "id": 886, "trainId": 746},
815
+ {"name": "file", "id": 919, "trainId": 747},
816
+ {"name": "scarfs", "id": 2240, "trainId": 748},
817
+ {"name": "sugar bowl", "id": 2636, "trainId": 749},
818
+ {"name": "flip flop", "id": 963, "trainId": 750},
819
+ {"name": "headstones", "id": 1218, "trainId": 751},
820
+ {"name": "laptop bag", "id": 1406, "trainId": 752},
821
+ {"name": "leash", "id": 1420, "trainId": 753},
822
+ {"name": "climbing frame", "id": 526, "trainId": 754},
823
+ {"name": "suit hanger", "id": 2639, "trainId": 755},
824
+ {"name": "floor spotlight", "id": 975, "trainId": 756},
825
+ {"name": "plate rack", "id": 1921, "trainId": 757},
826
+ {"name": "sewer", "id": 2305, "trainId": 758},
827
+ {"name": "hard drive", "id": 1193, "trainId": 759},
828
+ {"name": "sprinkler", "id": 2517, "trainId": 760},
829
+ {"name": "tools box", "id": 2809, "trainId": 761},
830
+ {"name": "necklace", "id": 1647, "trainId": 762},
831
+ {"name": "bulbs", "id": 314, "trainId": 763},
832
+ {"name": "steel industry", "id": 2560, "trainId": 764},
833
+ {"name": "club", "id": 545, "trainId": 765},
834
+ {"name": "jack", "id": 1345, "trainId": 766},
835
+ {"name": "door bars", "id": 775, "trainId": 767},
836
+ {
837
+ "name": "control panel, instrument panel, control board, board, panel",
838
+ "id": 603,
839
+ "trainId": 768,
840
+ },
841
+ {"name": "hairbrush", "id": 1163, "trainId": 769},
842
+ {"name": "napkin holder", "id": 1641, "trainId": 770},
843
+ {"name": "office", "id": 1678, "trainId": 771},
844
+ {"name": "smoke detector", "id": 2450, "trainId": 772},
845
+ {"name": "utensils", "id": 2915, "trainId": 773},
846
+ {"name": "apron", "id": 42, "trainId": 774},
847
+ {"name": "scissors", "id": 2242, "trainId": 775},
848
+ {"name": "terminal", "id": 2741, "trainId": 776},
849
+ {"name": "grinder", "id": 1143, "trainId": 777},
850
+ {"name": "entry phone", "id": 862, "trainId": 778},
851
+ {"name": "newspaper stand", "id": 1654, "trainId": 779},
852
+ {"name": "pepper shaker", "id": 1826, "trainId": 780},
853
+ {"name": "onions", "id": 1689, "trainId": 781},
854
+ {
855
+ "name": "central processing unit, cpu, c p u , central processor, processor, mainframe",
856
+ "id": 3124,
857
+ "trainId": 782,
858
+ },
859
+ {"name": "tape", "id": 2710, "trainId": 783},
860
+ {"name": "bat", "id": 152, "trainId": 784},
861
+ {"name": "coaster", "id": 549, "trainId": 785},
862
+ {"name": "calculator", "id": 360, "trainId": 786},
863
+ {"name": "potatoes", "id": 1982, "trainId": 787},
864
+ {"name": "luggage rack", "id": 1478, "trainId": 788},
865
+ {"name": "salt", "id": 2203, "trainId": 789},
866
+ {"name": "street number", "id": 2612, "trainId": 790},
867
+ {"name": "viewpoint", "id": 2956, "trainId": 791},
868
+ {"name": "sword", "id": 2681, "trainId": 792},
869
+ {"name": "cd", "id": 437, "trainId": 793},
870
+ {"name": "rowing machine", "id": 2171, "trainId": 794},
871
+ {"name": "plug", "id": 1933, "trainId": 795},
872
+ {"name": "andiron, firedog, dog, dog-iron", "id": 3110, "trainId": 796},
873
+ {"name": "pepper", "id": 1824, "trainId": 797},
874
+ {"name": "tongs", "id": 2803, "trainId": 798},
875
+ {"name": "bonfire", "id": 234, "trainId": 799},
876
+ {"name": "dog dish", "id": 764, "trainId": 800},
877
+ {"name": "belt", "id": 177, "trainId": 801},
878
+ {"name": "dumbbells", "id": 817, "trainId": 802},
879
+ {"name": "videocassette recorder, vcr", "id": 3145, "trainId": 803},
880
+ {"name": "hook", "id": 1262, "trainId": 804},
881
+ {"name": "envelopes", "id": 864, "trainId": 805},
882
+ {"name": "shower faucet", "id": 2359, "trainId": 806},
883
+ {"name": "watch", "id": 2992, "trainId": 807},
884
+ {"name": "padlock", "id": 1725, "trainId": 808},
885
+ {"name": "swimming pool ladder", "id": 2667, "trainId": 809},
886
+ {"name": "spanners", "id": 2484, "trainId": 810},
887
+ {"name": "gravy boat", "id": 1133, "trainId": 811},
888
+ {"name": "notice board", "id": 1667, "trainId": 812},
889
+ {"name": "trash bags", "id": 2847, "trainId": 813},
890
+ {"name": "fire alarm", "id": 932, "trainId": 814},
891
+ {"name": "ladle", "id": 1392, "trainId": 815},
892
+ {"name": "stethoscope", "id": 2573, "trainId": 816},
893
+ {"name": "rocket", "id": 2140, "trainId": 817},
894
+ {"name": "funnel", "id": 1046, "trainId": 818},
895
+ {"name": "bowling pins", "id": 264, "trainId": 819},
896
+ {"name": "valve", "id": 2927, "trainId": 820},
897
+ {"name": "thermometer", "id": 2752, "trainId": 821},
898
+ {"name": "cups", "id": 679, "trainId": 822},
899
+ {"name": "spice jar", "id": 2493, "trainId": 823},
900
+ {"name": "night light", "id": 1658, "trainId": 824},
901
+ {"name": "soaps", "id": 2466, "trainId": 825},
902
+ {"name": "games table", "id": 1057, "trainId": 826},
903
+ {"name": "slotted spoon", "id": 2444, "trainId": 827},
904
+ {"name": "reel", "id": 2093, "trainId": 828},
905
+ {"name": "scourer", "id": 2248, "trainId": 829},
906
+ {"name": "sleeping robe", "id": 2432, "trainId": 830},
907
+ {"name": "desk mat", "id": 726, "trainId": 831},
908
+ {"name": "dumbbell", "id": 816, "trainId": 832},
909
+ {"name": "hammer", "id": 1171, "trainId": 833},
910
+ {"name": "tie", "id": 2766, "trainId": 834},
911
+ {"name": "typewriter", "id": 2900, "trainId": 835},
912
+ {"name": "shaker", "id": 2313, "trainId": 836},
913
+ {"name": "cheese dish", "id": 488, "trainId": 837},
914
+ {"name": "sea star", "id": 2265, "trainId": 838},
915
+ {"name": "racquet", "id": 2043, "trainId": 839},
916
+ {"name": "butane gas cylinder", "id": 332, "trainId": 840},
917
+ {"name": "paper weight", "id": 1771, "trainId": 841},
918
+ {"name": "shaving brush", "id": 2320, "trainId": 842},
919
+ {"name": "sunglasses", "id": 2646, "trainId": 843},
920
+ {"name": "gear shift", "id": 1089, "trainId": 844},
921
+ {"name": "towel rail", "id": 2826, "trainId": 845},
922
+ {"name": "adding machine, totalizer, totaliser", "id": 3148, "trainId": 846},
923
+ ]
924
+
925
+
926
+ def _get_ade20k_full_meta():
927
+ # Id 0 is reserved for ignore_label; we change ignore_label from 0
928
+ # to 255 in our pre-processing, so all ids are shifted by 1.
929
+ stuff_ids = [k["id"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
930
+ assert len(stuff_ids) == 847, len(stuff_ids)
931
+
932
+ # For semantic segmentation, this mapping maps from contiguous stuff id
933
+ # (in [0, 847), used in models) to ids in the dataset (used for processing results)
934
+ stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
935
+ stuff_classes = [k["name"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
936
+
937
+ ret = {
938
+ "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
939
+ "stuff_classes": stuff_classes,
940
+ }
941
+ return ret
942
+
943
+
944
+ def register_all_ade20k_full(root):
945
+ root = os.path.join(root, "ADE20K_2021_17_01")
946
+ meta = _get_ade20k_full_meta()
947
+ for name, dirname in [("train", "training"), ("val", "validation")]:
948
+ image_dir = os.path.join(root, "images_detectron2", dirname)
949
+ gt_dir = os.path.join(root, "annotations_detectron2", dirname)
950
+ name = f"ade20k_full_sem_seg_{name}"
951
+ DatasetCatalog.register(
952
+ name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="tif", image_ext="jpg")
953
+ )
954
+ MetadataCatalog.get(name).set(
955
+ stuff_classes=meta["stuff_classes"][:],
956
+ image_root=image_dir,
957
+ sem_seg_root=gt_dir,
958
+ evaluator_type="sem_seg",
959
+ ignore_label=65535, # NOTE: gt is saved in 16-bit TIFF images
960
+ )
961
+
962
+
963
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
964
+ register_all_ade20k_full(_root)
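Editor's note: a minimal usage sketch, not part of this commit, of how the split registered above can be consumed once the module is imported. It assumes `DETECTRON2_DATASETS` points at a directory containing `ADE20K_2021_17_01` prepared with the `images_detectron2`/`annotations_detectron2` layout expected by the code above.

```python
# Illustrative sketch only; not part of the committed file.
from detectron2.data import DatasetCatalog, MetadataCatalog

# Importing the module runs register_all_ade20k_full(_root) above.
import mask_former.data.datasets.register_ade20k_full  # noqa: F401

meta = MetadataCatalog.get("ade20k_full_sem_seg_val")
print(len(meta.stuff_classes), meta.ignore_label)        # 847 classes, ignore_label 65535
records = DatasetCatalog.get("ade20k_full_sem_seg_val")  # this call reads the dataset dir on disk
print(records[0]["file_name"], records[0]["sem_seg_file_name"])
```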
mask_former/data/datasets/register_ade20k_panoptic.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import json
3
+ import os
4
+
5
+ from detectron2.data import DatasetCatalog, MetadataCatalog
6
+ from detectron2.utils.file_io import PathManager
7
+
8
+ ADE20K_150_CATEGORIES = [
9
+ {"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"},
10
+ {"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"},
11
+ {"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"},
12
+ {"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"},
13
+ {"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"},
14
+ {"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"},
15
+ {"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"},
16
+ {"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"},
17
+ {"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "},
18
+ {"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"},
19
+ {"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"},
20
+ {"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"},
21
+ {"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"},
22
+ {"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"},
23
+ {"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"},
24
+ {"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"},
25
+ {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"},
26
+ {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"},
27
+ {"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"},
28
+ {"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"},
29
+ {"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"},
30
+ {"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"},
31
+ {"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"},
32
+ {"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"},
33
+ {"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"},
34
+ {"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"},
35
+ {"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"},
36
+ {"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"},
37
+ {"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"},
38
+ {"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"},
39
+ {"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"},
40
+ {"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"},
41
+ {"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"},
42
+ {"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"},
43
+ {"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"},
44
+ {"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"},
45
+ {"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"},
46
+ {"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"},
47
+ {"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"},
48
+ {"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"},
49
+ {"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"},
50
+ {"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"},
51
+ {"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"},
52
+ {"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"},
53
+ {
54
+ "color": [6, 51, 255],
55
+ "id": 44,
56
+ "isthing": 1,
57
+ "name": "chest of drawers, chest, bureau, dresser",
58
+ },
59
+ {"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"},
60
+ {"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"},
61
+ {"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"},
62
+ {"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"},
63
+ {"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"},
64
+ {"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"},
65
+ {"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"},
66
+ {"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"},
67
+ {"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"},
68
+ {"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"},
69
+ {"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"},
70
+ {
71
+ "color": [255, 71, 0],
72
+ "id": 56,
73
+ "isthing": 1,
74
+ "name": "pool table, billiard table, snooker table",
75
+ },
76
+ {"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"},
77
+ {"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"},
78
+ {"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"},
79
+ {"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"},
80
+ {"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"},
81
+ {"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"},
82
+ {"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"},
83
+ {"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"},
84
+ {
85
+ "color": [0, 255, 133],
86
+ "id": 65,
87
+ "isthing": 1,
88
+ "name": "toilet, can, commode, crapper, pot, potty, stool, throne",
89
+ },
90
+ {"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"},
91
+ {"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"},
92
+ {"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"},
93
+ {"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"},
94
+ {"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"},
95
+ {"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"},
96
+ {"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"},
97
+ {"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"},
98
+ {"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"},
99
+ {"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"},
100
+ {"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"},
101
+ {"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"},
102
+ {"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"},
103
+ {"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"},
104
+ {"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"},
105
+ {"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"},
106
+ {"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"},
107
+ {"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"},
108
+ {"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"},
109
+ {"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"},
110
+ {"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"},
111
+ {"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"},
112
+ {"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"},
113
+ {"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"},
114
+ {"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"},
115
+ {"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"},
116
+ {"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"},
117
+ {"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"},
118
+ {"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"},
119
+ {
120
+ "color": [0, 122, 255],
121
+ "id": 95,
122
+ "isthing": 1,
123
+ "name": "bannister, banister, balustrade, balusters, handrail",
124
+ },
125
+ {
126
+ "color": [0, 255, 163],
127
+ "id": 96,
128
+ "isthing": 0,
129
+ "name": "escalator, moving staircase, moving stairway",
130
+ },
131
+ {
132
+ "color": [255, 153, 0],
133
+ "id": 97,
134
+ "isthing": 1,
135
+ "name": "ottoman, pouf, pouffe, puff, hassock",
136
+ },
137
+ {"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"},
138
+ {"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"},
139
+ {
140
+ "color": [143, 255, 0],
141
+ "id": 100,
142
+ "isthing": 0,
143
+ "name": "poster, posting, placard, notice, bill, card",
144
+ },
145
+ {"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"},
146
+ {"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"},
147
+ {"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"},
148
+ {"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"},
149
+ {
150
+ "color": [133, 0, 255],
151
+ "id": 105,
152
+ "isthing": 0,
153
+ "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
154
+ },
155
+ {"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"},
156
+ {
157
+ "color": [184, 0, 255],
158
+ "id": 107,
159
+ "isthing": 1,
160
+ "name": "washer, automatic washer, washing machine",
161
+ },
162
+ {"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"},
163
+ {"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"},
164
+ {"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"},
165
+ {"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"},
166
+ {"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"},
167
+ {"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"},
168
+ {"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"},
169
+ {"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"},
170
+ {"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"},
171
+ {"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"},
172
+ {"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"},
173
+ {"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"},
174
+ {"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"},
175
+ {"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"},
176
+ {"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"},
177
+ {"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"},
178
+ {"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"},
179
+ {"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"},
180
+ {"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"},
181
+ {"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"},
182
+ {"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"},
183
+ {"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"},
184
+ {"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"},
185
+ {"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"},
186
+ {"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"},
187
+ {"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"},
188
+ {"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"},
189
+ {"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"},
190
+ {"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"},
191
+ {"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"},
192
+ {"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"},
193
+ {"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"},
194
+ {"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"},
195
+ {"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"},
196
+ {"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"},
197
+ {"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"},
198
+ {"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"},
199
+ {"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"},
200
+ {"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"},
201
+ {"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"},
202
+ {"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"},
203
+ {"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"},
204
+ ]
205
+
206
+ ADE20k_COLORS = [k["color"] for k in ADE20K_150_CATEGORIES]
207
+
208
+ MetadataCatalog.get("ade20k_sem_seg_train").set(
209
+ stuff_colors=ADE20k_COLORS[:],
210
+ )
211
+
212
+ MetadataCatalog.get("ade20k_sem_seg_val").set(
213
+ stuff_colors=ADE20k_COLORS[:],
214
+ )
215
+
216
+
217
+ def load_ade20k_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, meta):
218
+ """
219
+ Args:
220
+ image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
221
+ gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
222
+ json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
223
+ Returns:
224
+ list[dict]: a list of dicts in Detectron2 standard format. (See
225
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
226
+ """
227
+
228
+ def _convert_category_id(segment_info, meta):
229
+ if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
230
+ segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
231
+ segment_info["category_id"]
232
+ ]
233
+ segment_info["isthing"] = True
234
+ else:
235
+ segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
236
+ segment_info["category_id"]
237
+ ]
238
+ segment_info["isthing"] = False
239
+ return segment_info
240
+
241
+ with PathManager.open(json_file) as f:
242
+ json_info = json.load(f)
243
+
244
+ ret = []
245
+ for ann in json_info["annotations"]:
246
+ image_id = ann["image_id"]
247
+ # TODO: currently we assume image and label have the same filename but
248
+ # different extension, and images have extension ".jpg" for COCO. Need
249
+ # to make image extension a user-provided argument if we extend this
250
+ # function to support other COCO-like datasets.
251
+ image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
252
+ label_file = os.path.join(gt_dir, ann["file_name"])
253
+ sem_label_file = os.path.join(semseg_dir, ann["file_name"])
254
+ segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
255
+ ret.append(
256
+ {
257
+ "file_name": image_file,
258
+ "image_id": image_id,
259
+ "pan_seg_file_name": label_file,
260
+ "sem_seg_file_name": sem_label_file,
261
+ "segments_info": segments_info,
262
+ }
263
+ )
264
+ assert len(ret), f"No images found in {image_dir}!"
265
+ assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
266
+ assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
267
+ assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"]
268
+ return ret
269
+
270
+
271
+ def register_ade20k_panoptic(
272
+ name, metadata, image_root, panoptic_root, semantic_root, panoptic_json, instances_json=None
273
+ ):
274
+ """
275
+ Register a "standard" version of ADE20k panoptic segmentation dataset named `name`.
276
+ The dictionaries in this registered dataset follow detectron2's standard format.
277
+ Hence it's called "standard".
278
+ Args:
279
+ name (str): the name that identifies a dataset,
280
+ e.g. "ade20k_panoptic_train"
281
+ metadata (dict): extra metadata associated with this dataset.
282
+ image_root (str): directory which contains all the images
283
+ panoptic_root (str): directory which contains panoptic annotation images in COCO format
284
+ panoptic_json (str): path to the json panoptic annotation file in COCO format
285
+ semantic_root (str): directory which contains the per-pixel semantic
286
+ annotation images (used to fill "sem_seg_file_name")
287
+ instances_json (str): path to the json instance annotation file
288
+ """
289
+ panoptic_name = name
290
+ DatasetCatalog.register(
291
+ panoptic_name,
292
+ lambda: load_ade20k_panoptic_json(
293
+ panoptic_json, image_root, panoptic_root, semantic_root, metadata
294
+ ),
295
+ )
296
+ MetadataCatalog.get(panoptic_name).set(
297
+ panoptic_root=panoptic_root,
298
+ image_root=image_root,
299
+ panoptic_json=panoptic_json,
300
+ json_file=instances_json,
301
+ evaluator_type="ade20k_panoptic_seg",
302
+ ignore_label=255,
303
+ label_divisor=1000,
304
+ **metadata,
305
+ )
306
+
307
+
308
+ _PREDEFINED_SPLITS_ADE20K_PANOPTIC = {
309
+ "ade20k_panoptic_train": (
310
+ "ADEChallengeData2016/images/training",
311
+ "ADEChallengeData2016/ade20k_panoptic_train",
312
+ "ADEChallengeData2016/ade20k_panoptic_train.json",
313
+ "ADEChallengeData2016/annotations_detectron2/training",
314
+ ),
315
+ "ade20k_panoptic_val": (
316
+ "ADEChallengeData2016/images/validation",
317
+ "ADEChallengeData2016/ade20k_panoptic_val",
318
+ "ADEChallengeData2016/ade20k_panoptic_val.json",
319
+ "ADEChallengeData2016/annotations_detectron2/validation",
320
+ ),
321
+ }
322
+
323
+
324
+ def get_metadata():
325
+ meta = {}
326
+ # The following metadata maps contiguous id from [0, #thing categories +
327
+ # #stuff categories) to their names and colors. We keep two replicas of the
328
+ # same name and color under "thing_*" and "stuff_*" because the current
329
+ # visualization function in D2 handles thing and stuff classes differently
330
+ # due to some heuristic used in Panoptic FPN. We keep the same naming to
331
+ # enable reusing existing visualization functions.
332
+ thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES]
333
+ thing_colors = [k["color"] for k in ADE20K_150_CATEGORIES]
334
+ stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES]
335
+ stuff_colors = [k["color"] for k in ADE20K_150_CATEGORIES]
336
+
337
+ meta["thing_classes"] = thing_classes
338
+ meta["thing_colors"] = thing_colors
339
+ meta["stuff_classes"] = stuff_classes
340
+ meta["stuff_colors"] = stuff_colors
341
+
342
+ # Convert category id for training:
343
+ # category id: like semantic segmentation, it is the class id for each
344
+ # pixel. Since there are some classes not used in evaluation, the category
345
+ # id is not always contiguous and thus we have two sets of category ids:
346
+ # - original category id: category id in the original dataset, mainly
347
+ # used for evaluation.
348
+ # - contiguous category id: [0, #classes), in order to train the linear
349
+ # softmax classifier.
350
+ thing_dataset_id_to_contiguous_id = {}
351
+ stuff_dataset_id_to_contiguous_id = {}
352
+
353
+ for i, cat in enumerate(ADE20K_150_CATEGORIES):
354
+ if cat["isthing"]:
355
+ thing_dataset_id_to_contiguous_id[cat["id"]] = i
356
+ # else:
357
+ # stuff_dataset_id_to_contiguous_id[cat["id"]] = i
358
+
359
+ # in order to use sem_seg evaluator
360
+ stuff_dataset_id_to_contiguous_id[cat["id"]] = i
361
+
362
+ meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
363
+ meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
364
+
365
+ return meta
366
+
367
+
368
+ def register_all_ade20k_panoptic(root):
369
+ metadata = get_metadata()
370
+ for (
371
+ prefix,
372
+ (image_root, panoptic_root, panoptic_json, semantic_root),
373
+ ) in _PREDEFINED_SPLITS_ADE20K_PANOPTIC.items():
374
+ # The "standard" version of COCO panoptic segmentation dataset,
375
+ # e.g. used by Panoptic-DeepLab
376
+ register_ade20k_panoptic(
377
+ prefix,
378
+ metadata,
379
+ os.path.join(root, image_root),
380
+ os.path.join(root, panoptic_root),
381
+ os.path.join(root, semantic_root),
382
+ os.path.join(root, panoptic_json),
383
+ )
384
+
385
+
386
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
387
+ register_all_ade20k_panoptic(_root)
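Editor's note: a small sketch, not part of this commit, showing what the registration above exposes through Detectron2's catalogs. It assumes `ADEChallengeData2016` has been prepared under `$DETECTRON2_DATASETS` with the panoptic JSON and PNG annotations listed in `_PREDEFINED_SPLITS_ADE20K_PANOPTIC`.

```python
# Illustrative sketch only; not part of the committed file.
from detectron2.data import DatasetCatalog, MetadataCatalog

import mask_former.data.datasets.register_ade20k_panoptic  # noqa: F401

meta = MetadataCatalog.get("ade20k_panoptic_val")
print(meta.label_divisor, meta.ignore_label)            # 1000, 255
print(len(meta.thing_dataset_id_to_contiguous_id),      # "thing" categories only
      len(meta.stuff_dataset_id_to_contiguous_id))      # all 150 categories (see get_metadata)
record = DatasetCatalog.get("ade20k_panoptic_val")[0]   # this call reads the panoptic JSON from disk
print(record["pan_seg_file_name"], len(record["segments_info"]))
```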
mask_former/data/datasets/register_coco_stuff_10k.py ADDED
@@ -0,0 +1,223 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import os
3
+
4
+ from detectron2.data import DatasetCatalog, MetadataCatalog
5
+ from detectron2.data.datasets import load_sem_seg
6
+
7
+ COCO_CATEGORIES = [
8
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
9
+ {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
10
+ {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
11
+ {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
12
+ {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
13
+ {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
14
+ {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
15
+ {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
16
+ {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
17
+ {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
18
+ {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
19
+ {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
20
+ {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
21
+ {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
22
+ {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
23
+ {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
24
+ {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
25
+ {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
26
+ {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
27
+ {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
28
+ {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
29
+ {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
30
+ {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
31
+ {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
32
+ {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
33
+ {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
34
+ {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
35
+ {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
36
+ {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
37
+ {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
38
+ {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
39
+ {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
40
+ {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
41
+ {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
42
+ {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
43
+ {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
44
+ {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
45
+ {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
46
+ {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
47
+ {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
48
+ {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
49
+ {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
50
+ {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
51
+ {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
52
+ {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
53
+ {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
54
+ {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
55
+ {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
56
+ {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
57
+ {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
58
+ {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
59
+ {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
60
+ {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
61
+ {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
62
+ {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
63
+ {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
64
+ {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
65
+ {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
66
+ {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
67
+ {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
68
+ {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
69
+ {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
70
+ {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
71
+ {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
72
+ {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
73
+ {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
74
+ {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
75
+ {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
76
+ {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
77
+ {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
78
+ {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
79
+ {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
80
+ {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
81
+ {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
82
+ {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
83
+ {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
84
+ {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
85
+ {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
86
+ {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
87
+ {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
88
+ {"id": 92, "name": "banner", "supercategory": "textile"},
89
+ {"id": 93, "name": "blanket", "supercategory": "textile"},
90
+ {"id": 94, "name": "branch", "supercategory": "plant"},
91
+ {"id": 95, "name": "bridge", "supercategory": "building"},
92
+ {"id": 96, "name": "building-other", "supercategory": "building"},
93
+ {"id": 97, "name": "bush", "supercategory": "plant"},
94
+ {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
95
+ {"id": 99, "name": "cage", "supercategory": "structural"},
96
+ {"id": 100, "name": "cardboard", "supercategory": "raw-material"},
97
+ {"id": 101, "name": "carpet", "supercategory": "floor"},
98
+ {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
99
+ {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
100
+ {"id": 104, "name": "cloth", "supercategory": "textile"},
101
+ {"id": 105, "name": "clothes", "supercategory": "textile"},
102
+ {"id": 106, "name": "clouds", "supercategory": "sky"},
103
+ {"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
104
+ {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
105
+ {"id": 109, "name": "curtain", "supercategory": "textile"},
106
+ {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
107
+ {"id": 111, "name": "dirt", "supercategory": "ground"},
108
+ {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
109
+ {"id": 113, "name": "fence", "supercategory": "structural"},
110
+ {"id": 114, "name": "floor-marble", "supercategory": "floor"},
111
+ {"id": 115, "name": "floor-other", "supercategory": "floor"},
112
+ {"id": 116, "name": "floor-stone", "supercategory": "floor"},
113
+ {"id": 117, "name": "floor-tile", "supercategory": "floor"},
114
+ {"id": 118, "name": "floor-wood", "supercategory": "floor"},
115
+ {"id": 119, "name": "flower", "supercategory": "plant"},
116
+ {"id": 120, "name": "fog", "supercategory": "water"},
117
+ {"id": 121, "name": "food-other", "supercategory": "food-stuff"},
118
+ {"id": 122, "name": "fruit", "supercategory": "food-stuff"},
119
+ {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
120
+ {"id": 124, "name": "grass", "supercategory": "plant"},
121
+ {"id": 125, "name": "gravel", "supercategory": "ground"},
122
+ {"id": 126, "name": "ground-other", "supercategory": "ground"},
123
+ {"id": 127, "name": "hill", "supercategory": "solid"},
124
+ {"id": 128, "name": "house", "supercategory": "building"},
125
+ {"id": 129, "name": "leaves", "supercategory": "plant"},
126
+ {"id": 130, "name": "light", "supercategory": "furniture-stuff"},
127
+ {"id": 131, "name": "mat", "supercategory": "textile"},
128
+ {"id": 132, "name": "metal", "supercategory": "raw-material"},
129
+ {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
130
+ {"id": 134, "name": "moss", "supercategory": "plant"},
131
+ {"id": 135, "name": "mountain", "supercategory": "solid"},
132
+ {"id": 136, "name": "mud", "supercategory": "ground"},
133
+ {"id": 137, "name": "napkin", "supercategory": "textile"},
134
+ {"id": 138, "name": "net", "supercategory": "structural"},
135
+ {"id": 139, "name": "paper", "supercategory": "raw-material"},
136
+ {"id": 140, "name": "pavement", "supercategory": "ground"},
137
+ {"id": 141, "name": "pillow", "supercategory": "textile"},
138
+ {"id": 142, "name": "plant-other", "supercategory": "plant"},
139
+ {"id": 143, "name": "plastic", "supercategory": "raw-material"},
140
+ {"id": 144, "name": "platform", "supercategory": "ground"},
141
+ {"id": 145, "name": "playingfield", "supercategory": "ground"},
142
+ {"id": 146, "name": "railing", "supercategory": "structural"},
143
+ {"id": 147, "name": "railroad", "supercategory": "ground"},
144
+ {"id": 148, "name": "river", "supercategory": "water"},
145
+ {"id": 149, "name": "road", "supercategory": "ground"},
146
+ {"id": 150, "name": "rock", "supercategory": "solid"},
147
+ {"id": 151, "name": "roof", "supercategory": "building"},
148
+ {"id": 152, "name": "rug", "supercategory": "textile"},
149
+ {"id": 153, "name": "salad", "supercategory": "food-stuff"},
150
+ {"id": 154, "name": "sand", "supercategory": "ground"},
151
+ {"id": 155, "name": "sea", "supercategory": "water"},
152
+ {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
153
+ {"id": 157, "name": "sky-other", "supercategory": "sky"},
154
+ {"id": 158, "name": "skyscraper", "supercategory": "building"},
155
+ {"id": 159, "name": "snow", "supercategory": "ground"},
156
+ {"id": 160, "name": "solid-other", "supercategory": "solid"},
157
+ {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
158
+ {"id": 162, "name": "stone", "supercategory": "solid"},
159
+ {"id": 163, "name": "straw", "supercategory": "plant"},
160
+ {"id": 164, "name": "structural-other", "supercategory": "structural"},
161
+ {"id": 165, "name": "table", "supercategory": "furniture-stuff"},
162
+ {"id": 166, "name": "tent", "supercategory": "building"},
163
+ {"id": 167, "name": "textile-other", "supercategory": "textile"},
164
+ {"id": 168, "name": "towel", "supercategory": "textile"},
165
+ {"id": 169, "name": "tree", "supercategory": "plant"},
166
+ {"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
167
+ {"id": 171, "name": "wall-brick", "supercategory": "wall"},
168
+ {"id": 172, "name": "wall-concrete", "supercategory": "wall"},
169
+ {"id": 173, "name": "wall-other", "supercategory": "wall"},
170
+ {"id": 174, "name": "wall-panel", "supercategory": "wall"},
171
+ {"id": 175, "name": "wall-stone", "supercategory": "wall"},
172
+ {"id": 176, "name": "wall-tile", "supercategory": "wall"},
173
+ {"id": 177, "name": "wall-wood", "supercategory": "wall"},
174
+ {"id": 178, "name": "water-other", "supercategory": "water"},
175
+ {"id": 179, "name": "waterdrops", "supercategory": "water"},
176
+ {"id": 180, "name": "window-blind", "supercategory": "window"},
177
+ {"id": 181, "name": "window-other", "supercategory": "window"},
178
+ {"id": 182, "name": "wood", "supercategory": "solid"},
179
+ ]
180
+
181
+
182
+ def _get_coco_stuff_meta():
183
+ # Id 0 is reserved for ignore_label; we change ignore_label from 0
184
+ # to 255 in our pre-processing.
185
+ stuff_ids = [k["id"] for k in COCO_CATEGORIES]
186
+ assert len(stuff_ids) == 171, len(stuff_ids)
187
+
188
+ # For semantic segmentation, this mapping maps from contiguous stuff id
189
+ # (in [0, 171), used in models) to ids in the dataset (used for processing results)
190
+ stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
191
+ stuff_classes = [k["name"] for k in COCO_CATEGORIES]
192
+
193
+ ret = {
194
+ "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
195
+ "stuff_classes": stuff_classes,
196
+ }
197
+ return ret
198
+
199
+
200
+ def register_all_coco_stuff_10k(root):
201
+ root = os.path.join(root, "coco", "coco_stuff_10k")
202
+ meta = _get_coco_stuff_meta()
203
+ for name, image_dirname, sem_seg_dirname in [
204
+ ("train", "images_detectron2/train", "annotations_detectron2/train"),
205
+ ("test", "images_detectron2/test", "annotations_detectron2/test"),
206
+ ]:
207
+ image_dir = os.path.join(root, image_dirname)
208
+ gt_dir = os.path.join(root, sem_seg_dirname)
209
+ name = f"coco_2017_{name}_stuff_10k_sem_seg"
210
+ DatasetCatalog.register(
211
+ name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg")
212
+ )
213
+ MetadataCatalog.get(name).set(
214
+ image_root=image_dir,
215
+ sem_seg_root=gt_dir,
216
+ evaluator_type="sem_seg",
217
+ ignore_label=255,
218
+ **meta,
219
+ )
220
+
221
+
222
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
223
+ register_all_coco_stuff_10k(_root)
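Editor's note: a short sketch, not part of this commit, of the id remapping built by `_get_coco_stuff_meta()`: the non-contiguous COCO-Stuff ids above are mapped to contiguous train ids in `[0, 171)`.

```python
# Illustrative sketch only; not part of the committed file.
from mask_former.data.datasets.register_coco_stuff_10k import _get_coco_stuff_meta

meta = _get_coco_stuff_meta()
mapping = meta["stuff_dataset_id_to_contiguous_id"]
assert len(mapping) == 171
print(mapping[1])                          # "person" (dataset id 1) -> contiguous id 0
print(meta["stuff_classes"][mapping[92]])  # dataset id 92 -> "banner", the first stuff-only class
```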
mask_former/data/datasets/register_mapillary_vistas.py ADDED
@@ -0,0 +1,507 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import os
3
+
4
+ from detectron2.data import DatasetCatalog, MetadataCatalog
5
+ from detectron2.data.datasets import load_sem_seg
6
+
7
+ MAPILLARY_VISTAS_SEM_SEG_CATEGORIES = [
8
+ {
9
+ "color": [165, 42, 42],
10
+ "instances": True,
11
+ "readable": "Bird",
12
+ "name": "animal--bird",
13
+ "evaluate": True,
14
+ },
15
+ {
16
+ "color": [0, 192, 0],
17
+ "instances": True,
18
+ "readable": "Ground Animal",
19
+ "name": "animal--ground-animal",
20
+ "evaluate": True,
21
+ },
22
+ {
23
+ "color": [196, 196, 196],
24
+ "instances": False,
25
+ "readable": "Curb",
26
+ "name": "construction--barrier--curb",
27
+ "evaluate": True,
28
+ },
29
+ {
30
+ "color": [190, 153, 153],
31
+ "instances": False,
32
+ "readable": "Fence",
33
+ "name": "construction--barrier--fence",
34
+ "evaluate": True,
35
+ },
36
+ {
37
+ "color": [180, 165, 180],
38
+ "instances": False,
39
+ "readable": "Guard Rail",
40
+ "name": "construction--barrier--guard-rail",
41
+ "evaluate": True,
42
+ },
43
+ {
44
+ "color": [90, 120, 150],
45
+ "instances": False,
46
+ "readable": "Barrier",
47
+ "name": "construction--barrier--other-barrier",
48
+ "evaluate": True,
49
+ },
50
+ {
51
+ "color": [102, 102, 156],
52
+ "instances": False,
53
+ "readable": "Wall",
54
+ "name": "construction--barrier--wall",
55
+ "evaluate": True,
56
+ },
57
+ {
58
+ "color": [128, 64, 255],
59
+ "instances": False,
60
+ "readable": "Bike Lane",
61
+ "name": "construction--flat--bike-lane",
62
+ "evaluate": True,
63
+ },
64
+ {
65
+ "color": [140, 140, 200],
66
+ "instances": True,
67
+ "readable": "Crosswalk - Plain",
68
+ "name": "construction--flat--crosswalk-plain",
69
+ "evaluate": True,
70
+ },
71
+ {
72
+ "color": [170, 170, 170],
73
+ "instances": False,
74
+ "readable": "Curb Cut",
75
+ "name": "construction--flat--curb-cut",
76
+ "evaluate": True,
77
+ },
78
+ {
79
+ "color": [250, 170, 160],
80
+ "instances": False,
81
+ "readable": "Parking",
82
+ "name": "construction--flat--parking",
83
+ "evaluate": True,
84
+ },
85
+ {
86
+ "color": [96, 96, 96],
87
+ "instances": False,
88
+ "readable": "Pedestrian Area",
89
+ "name": "construction--flat--pedestrian-area",
90
+ "evaluate": True,
91
+ },
92
+ {
93
+ "color": [230, 150, 140],
94
+ "instances": False,
95
+ "readable": "Rail Track",
96
+ "name": "construction--flat--rail-track",
97
+ "evaluate": True,
98
+ },
99
+ {
100
+ "color": [128, 64, 128],
101
+ "instances": False,
102
+ "readable": "Road",
103
+ "name": "construction--flat--road",
104
+ "evaluate": True,
105
+ },
106
+ {
107
+ "color": [110, 110, 110],
108
+ "instances": False,
109
+ "readable": "Service Lane",
110
+ "name": "construction--flat--service-lane",
111
+ "evaluate": True,
112
+ },
113
+ {
114
+ "color": [244, 35, 232],
115
+ "instances": False,
116
+ "readable": "Sidewalk",
117
+ "name": "construction--flat--sidewalk",
118
+ "evaluate": True,
119
+ },
120
+ {
121
+ "color": [150, 100, 100],
122
+ "instances": False,
123
+ "readable": "Bridge",
124
+ "name": "construction--structure--bridge",
125
+ "evaluate": True,
126
+ },
127
+ {
128
+ "color": [70, 70, 70],
129
+ "instances": False,
130
+ "readable": "Building",
131
+ "name": "construction--structure--building",
132
+ "evaluate": True,
133
+ },
134
+ {
135
+ "color": [150, 120, 90],
136
+ "instances": False,
137
+ "readable": "Tunnel",
138
+ "name": "construction--structure--tunnel",
139
+ "evaluate": True,
140
+ },
141
+ {
142
+ "color": [220, 20, 60],
143
+ "instances": True,
144
+ "readable": "Person",
145
+ "name": "human--person",
146
+ "evaluate": True,
147
+ },
148
+ {
149
+ "color": [255, 0, 0],
150
+ "instances": True,
151
+ "readable": "Bicyclist",
152
+ "name": "human--rider--bicyclist",
153
+ "evaluate": True,
154
+ },
155
+ {
156
+ "color": [255, 0, 100],
157
+ "instances": True,
158
+ "readable": "Motorcyclist",
159
+ "name": "human--rider--motorcyclist",
160
+ "evaluate": True,
161
+ },
162
+ {
163
+ "color": [255, 0, 200],
164
+ "instances": True,
165
+ "readable": "Other Rider",
166
+ "name": "human--rider--other-rider",
167
+ "evaluate": True,
168
+ },
169
+ {
170
+ "color": [200, 128, 128],
171
+ "instances": True,
172
+ "readable": "Lane Marking - Crosswalk",
173
+ "name": "marking--crosswalk-zebra",
174
+ "evaluate": True,
175
+ },
176
+ {
177
+ "color": [255, 255, 255],
178
+ "instances": False,
179
+ "readable": "Lane Marking - General",
180
+ "name": "marking--general",
181
+ "evaluate": True,
182
+ },
183
+ {
184
+ "color": [64, 170, 64],
185
+ "instances": False,
186
+ "readable": "Mountain",
187
+ "name": "nature--mountain",
188
+ "evaluate": True,
189
+ },
190
+ {
191
+ "color": [230, 160, 50],
192
+ "instances": False,
193
+ "readable": "Sand",
194
+ "name": "nature--sand",
195
+ "evaluate": True,
196
+ },
197
+ {
198
+ "color": [70, 130, 180],
199
+ "instances": False,
200
+ "readable": "Sky",
201
+ "name": "nature--sky",
202
+ "evaluate": True,
203
+ },
204
+ {
205
+ "color": [190, 255, 255],
206
+ "instances": False,
207
+ "readable": "Snow",
208
+ "name": "nature--snow",
209
+ "evaluate": True,
210
+ },
211
+ {
212
+ "color": [152, 251, 152],
213
+ "instances": False,
214
+ "readable": "Terrain",
215
+ "name": "nature--terrain",
216
+ "evaluate": True,
217
+ },
218
+ {
219
+ "color": [107, 142, 35],
220
+ "instances": False,
221
+ "readable": "Vegetation",
222
+ "name": "nature--vegetation",
223
+ "evaluate": True,
224
+ },
225
+ {
226
+ "color": [0, 170, 30],
227
+ "instances": False,
228
+ "readable": "Water",
229
+ "name": "nature--water",
230
+ "evaluate": True,
231
+ },
232
+ {
233
+ "color": [255, 255, 128],
234
+ "instances": True,
235
+ "readable": "Banner",
236
+ "name": "object--banner",
237
+ "evaluate": True,
238
+ },
239
+ {
240
+ "color": [250, 0, 30],
241
+ "instances": True,
242
+ "readable": "Bench",
243
+ "name": "object--bench",
244
+ "evaluate": True,
245
+ },
246
+ {
247
+ "color": [100, 140, 180],
248
+ "instances": True,
249
+ "readable": "Bike Rack",
250
+ "name": "object--bike-rack",
251
+ "evaluate": True,
252
+ },
253
+ {
254
+ "color": [220, 220, 220],
255
+ "instances": True,
256
+ "readable": "Billboard",
257
+ "name": "object--billboard",
258
+ "evaluate": True,
259
+ },
260
+ {
261
+ "color": [220, 128, 128],
262
+ "instances": True,
263
+ "readable": "Catch Basin",
264
+ "name": "object--catch-basin",
265
+ "evaluate": True,
266
+ },
267
+ {
268
+ "color": [222, 40, 40],
269
+ "instances": True,
270
+ "readable": "CCTV Camera",
271
+ "name": "object--cctv-camera",
272
+ "evaluate": True,
273
+ },
274
+ {
275
+ "color": [100, 170, 30],
276
+ "instances": True,
277
+ "readable": "Fire Hydrant",
278
+ "name": "object--fire-hydrant",
279
+ "evaluate": True,
280
+ },
281
+ {
282
+ "color": [40, 40, 40],
283
+ "instances": True,
284
+ "readable": "Junction Box",
285
+ "name": "object--junction-box",
286
+ "evaluate": True,
287
+ },
288
+ {
289
+ "color": [33, 33, 33],
290
+ "instances": True,
291
+ "readable": "Mailbox",
292
+ "name": "object--mailbox",
293
+ "evaluate": True,
294
+ },
295
+ {
296
+ "color": [100, 128, 160],
297
+ "instances": True,
298
+ "readable": "Manhole",
299
+ "name": "object--manhole",
300
+ "evaluate": True,
301
+ },
302
+ {
303
+ "color": [142, 0, 0],
304
+ "instances": True,
305
+ "readable": "Phone Booth",
306
+ "name": "object--phone-booth",
307
+ "evaluate": True,
308
+ },
309
+ {
310
+ "color": [70, 100, 150],
311
+ "instances": False,
312
+ "readable": "Pothole",
313
+ "name": "object--pothole",
314
+ "evaluate": True,
315
+ },
316
+ {
317
+ "color": [210, 170, 100],
318
+ "instances": True,
319
+ "readable": "Street Light",
320
+ "name": "object--street-light",
321
+ "evaluate": True,
322
+ },
323
+ {
324
+ "color": [153, 153, 153],
325
+ "instances": True,
326
+ "readable": "Pole",
327
+ "name": "object--support--pole",
328
+ "evaluate": True,
329
+ },
330
+ {
331
+ "color": [128, 128, 128],
332
+ "instances": True,
333
+ "readable": "Traffic Sign Frame",
334
+ "name": "object--support--traffic-sign-frame",
335
+ "evaluate": True,
336
+ },
337
+ {
338
+ "color": [0, 0, 80],
339
+ "instances": True,
340
+ "readable": "Utility Pole",
341
+ "name": "object--support--utility-pole",
342
+ "evaluate": True,
343
+ },
344
+ {
345
+ "color": [250, 170, 30],
346
+ "instances": True,
347
+ "readable": "Traffic Light",
348
+ "name": "object--traffic-light",
349
+ "evaluate": True,
350
+ },
351
+ {
352
+ "color": [192, 192, 192],
353
+ "instances": True,
354
+ "readable": "Traffic Sign (Back)",
355
+ "name": "object--traffic-sign--back",
356
+ "evaluate": True,
357
+ },
358
+ {
359
+ "color": [220, 220, 0],
360
+ "instances": True,
361
+ "readable": "Traffic Sign (Front)",
362
+ "name": "object--traffic-sign--front",
363
+ "evaluate": True,
364
+ },
365
+ {
366
+ "color": [140, 140, 20],
367
+ "instances": True,
368
+ "readable": "Trash Can",
369
+ "name": "object--trash-can",
370
+ "evaluate": True,
371
+ },
372
+ {
373
+ "color": [119, 11, 32],
374
+ "instances": True,
375
+ "readable": "Bicycle",
376
+ "name": "object--vehicle--bicycle",
377
+ "evaluate": True,
378
+ },
379
+ {
380
+ "color": [150, 0, 255],
381
+ "instances": True,
382
+ "readable": "Boat",
383
+ "name": "object--vehicle--boat",
384
+ "evaluate": True,
385
+ },
386
+ {
387
+ "color": [0, 60, 100],
388
+ "instances": True,
389
+ "readable": "Bus",
390
+ "name": "object--vehicle--bus",
391
+ "evaluate": True,
392
+ },
393
+ {
394
+ "color": [0, 0, 142],
395
+ "instances": True,
396
+ "readable": "Car",
397
+ "name": "object--vehicle--car",
398
+ "evaluate": True,
399
+ },
400
+ {
401
+ "color": [0, 0, 90],
402
+ "instances": True,
403
+ "readable": "Caravan",
404
+ "name": "object--vehicle--caravan",
405
+ "evaluate": True,
406
+ },
407
+ {
408
+ "color": [0, 0, 230],
409
+ "instances": True,
410
+ "readable": "Motorcycle",
411
+ "name": "object--vehicle--motorcycle",
412
+ "evaluate": True,
413
+ },
414
+ {
415
+ "color": [0, 80, 100],
416
+ "instances": False,
417
+ "readable": "On Rails",
418
+ "name": "object--vehicle--on-rails",
419
+ "evaluate": True,
420
+ },
421
+ {
422
+ "color": [128, 64, 64],
423
+ "instances": True,
424
+ "readable": "Other Vehicle",
425
+ "name": "object--vehicle--other-vehicle",
426
+ "evaluate": True,
427
+ },
428
+ {
429
+ "color": [0, 0, 110],
430
+ "instances": True,
431
+ "readable": "Trailer",
432
+ "name": "object--vehicle--trailer",
433
+ "evaluate": True,
434
+ },
435
+ {
436
+ "color": [0, 0, 70],
437
+ "instances": True,
438
+ "readable": "Truck",
439
+ "name": "object--vehicle--truck",
440
+ "evaluate": True,
441
+ },
442
+ {
443
+ "color": [0, 0, 192],
444
+ "instances": True,
445
+ "readable": "Wheeled Slow",
446
+ "name": "object--vehicle--wheeled-slow",
447
+ "evaluate": True,
448
+ },
449
+ {
450
+ "color": [32, 32, 32],
451
+ "instances": False,
452
+ "readable": "Car Mount",
453
+ "name": "void--car-mount",
454
+ "evaluate": True,
455
+ },
456
+ {
457
+ "color": [120, 10, 10],
458
+ "instances": False,
459
+ "readable": "Ego Vehicle",
460
+ "name": "void--ego-vehicle",
461
+ "evaluate": True,
462
+ },
463
+ {
464
+ "color": [0, 0, 0],
465
+ "instances": False,
466
+ "readable": "Unlabeled",
467
+ "name": "void--unlabeled",
468
+ "evaluate": False,
469
+ },
470
+ ]
471
+
472
+
473
+ def _get_mapillary_vistas_meta():
474
+ stuff_classes = [k["readable"] for k in MAPILLARY_VISTAS_SEM_SEG_CATEGORIES if k["evaluate"]]
475
+ assert len(stuff_classes) == 65
476
+
477
+ stuff_colors = [k["color"] for k in MAPILLARY_VISTAS_SEM_SEG_CATEGORIES if k["evaluate"]]
478
+ assert len(stuff_colors) == 65
479
+
480
+ ret = {
481
+ "stuff_classes": stuff_classes,
482
+ "stuff_colors": stuff_colors,
483
+ }
484
+ return ret
485
+
486
+
487
+ def register_all_mapillary_vistas(root):
488
+ root = os.path.join(root, "mapillary_vistas")
489
+ meta = _get_mapillary_vistas_meta()
490
+ for name, dirname in [("train", "training"), ("val", "validation")]:
491
+ image_dir = os.path.join(root, dirname, "images")
492
+ gt_dir = os.path.join(root, dirname, "labels")
493
+ name = f"mapillary_vistas_sem_seg_{name}"
494
+ DatasetCatalog.register(
495
+ name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg")
496
+ )
497
+ MetadataCatalog.get(name).set(
498
+ image_root=image_dir,
499
+ sem_seg_root=gt_dir,
500
+ evaluator_type="sem_seg",
501
+ ignore_label=65, # different from other datasets, Mapillary Vistas sets ignore_label to 65
502
+ **meta,
503
+ )
504
+
505
+
506
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
507
+ register_all_mapillary_vistas(_root)
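Editor's note: a brief sketch, not part of this commit, of how `_get_mapillary_vistas_meta()` filters the category table above: only entries with `"evaluate": True` survive, dropping `"Unlabeled"` and leaving the 65 classes the asserts check for.

```python
# Illustrative sketch only; not part of the committed file.
from mask_former.data.datasets.register_mapillary_vistas import (
    MAPILLARY_VISTAS_SEM_SEG_CATEGORIES,
    _get_mapillary_vistas_meta,
)

meta = _get_mapillary_vistas_meta()
assert len(meta["stuff_classes"]) == 65
skipped = [k["readable"] for k in MAPILLARY_VISTAS_SEM_SEG_CATEGORIES if not k["evaluate"]]
print(skipped)                    # ['Unlabeled']
print(meta["stuff_classes"][:3])  # ['Bird', 'Ground Animal', 'Curb']
```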
mask_former/mask_former_model.py ADDED
@@ -0,0 +1,355 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from detectron2.config import configurable
6
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
7
+ from detectron2.modeling.backbone import Backbone
8
+ from detectron2.structures import ImageList
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from torchvision.transforms import functional as Ftv
12
+
13
+ from utils.log import getLogger
14
+ from .modeling.criterion import SetCriterion
15
+ from .modeling.matcher import HungarianMatcher
16
+
17
+ logger = getLogger(__name__)
18
+
19
+
20
+ def interpolate_or_crop(img,
21
+ size=(128, 128),
22
+ mode="bilinear",
23
+ align_corners=False,
24
+ tol=1.1):
25
+ h, w = img.shape[-2:]
26
+ H, W = size
27
+ if h == H and w == W:
28
+ return img
29
+ if H <= h < tol * H and W <= w < tol * W:
30
+ logger.info_once("Using center cropping instead of interpolation")
31
+ return Ftv.center_crop(img, output_size=size)
32
+ return F.interpolate(img, size=size, mode=mode, align_corners=align_corners)
33
+
34
+
35
+ @META_ARCH_REGISTRY.register()
36
+ class MaskFormer(nn.Module):
37
+ """
38
+ Main class for mask classification semantic segmentation architectures.
39
+ """
40
+
41
+ @configurable
42
+ def __init__(
43
+ self,
44
+ *,
45
+ backbone: Backbone,
46
+ sem_seg_head: nn.Module,
47
+ criterion: nn.Module,
48
+ num_queries: int,
49
+ panoptic_on: bool,
50
+ object_mask_threshold: float,
51
+ overlap_threshold: float,
52
+ metadata,
53
+ size_divisibility: int,
54
+ sem_seg_postprocess_before_inference: bool,
55
+ pixel_mean: Tuple[float],
56
+ pixel_std: Tuple[float],
57
+ crop_not_upsample: bool=True
58
+ ):
59
+ """
60
+ Args:
61
+ backbone: a backbone module, must follow detectron2's backbone interface
62
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
63
+ criterion: a module that defines the loss
64
+ num_queries: int, number of queries
65
+ panoptic_on: bool, whether to output panoptic segmentation prediction
66
+ object_mask_threshold: float, threshold to filter query based on classification score
67
+ for panoptic segmentation inference
68
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
69
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
70
+ segmentation inference
71
+ size_divisibility: Some backbones require the input height and width to be divisible by a
72
+ specific integer. We can use this to override such requirement.
73
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
74
+ to original input size before semantic segmentation inference or after.
75
+ For high-resolution dataset like Mapillary, resizing predictions before
76
+ inference will cause OOM error.
77
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
78
+ the per-channel mean and std to be used to normalize the input image
79
+ """
80
+ super().__init__()
81
+ self.crop_not_upsample = crop_not_upsample
82
+ self.backbone = backbone
83
+ self.sem_seg_head = sem_seg_head
84
+ self.criterion = criterion
85
+ self.num_queries = num_queries
86
+ self.overlap_threshold = overlap_threshold
87
+ self.panoptic_on = panoptic_on
88
+ self.object_mask_threshold = object_mask_threshold
89
+ self.metadata = metadata
90
+ if size_divisibility < 0:
91
+ # use backbone size_divisibility if not set
92
+ size_divisibility = self.backbone.size_divisibility
93
+ self.size_divisibility = size_divisibility
94
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
95
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
96
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
97
+
98
+ @classmethod
99
+ def from_config(cls, cfg):
100
+ backbone = build_backbone(cfg)
101
+ out_shape = backbone.output_shape()
102
+ if len(cfg.GWM.SAMPLE_KEYS) > 1:
103
+ for k, v in out_shape.items():
104
+ out_shape[k] = v._replace(channels=v.channels * len(cfg.GWM.SAMPLE_KEYS))
105
+ sem_seg_head = build_sem_seg_head(cfg, out_shape)
106
+
107
+ # Loss parameters:
108
+ deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
109
+ no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
110
+ dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
111
+ mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
112
+
113
+ # building criterion
114
+ matcher = HungarianMatcher(
115
+ cost_class=1,
116
+ cost_mask=mask_weight,
117
+ cost_dice=dice_weight,
118
+ )
119
+
120
+ weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight}
121
+ if deep_supervision:
122
+ dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
123
+ aux_weight_dict = {}
124
+ for i in range(dec_layers - 1):
125
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
126
+ weight_dict.update(aux_weight_dict)
127
+
128
+ losses = ["labels", "masks"]
129
+
130
+ criterion = SetCriterion(
131
+ sem_seg_head.num_classes,
132
+ matcher=matcher,
133
+ weight_dict=weight_dict,
134
+ eos_coef=no_object_weight,
135
+ losses=losses,
136
+ )
137
+
138
+ return {
139
+ "backbone": backbone,
140
+ "sem_seg_head": sem_seg_head,
141
+ "criterion": criterion,
142
+ "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
143
+ "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
144
+ "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
145
+ "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
146
+ "metadata": None, # MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
147
+ "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
148
+ "sem_seg_postprocess_before_inference": (
149
+ cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
150
+ or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
151
+ ),
152
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
153
+ "pixel_std": cfg.MODEL.PIXEL_STD,
154
+ 'crop_not_upsample': cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME != 'BasePixelDecoder'
155
+ }
156
+
157
+ @property
158
+ def device(self):
159
+ return self.pixel_mean.device
160
+
161
+ def forward(self, batched_inputs):
162
+ """
163
+ Args:
164
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
165
+ Each item in the list contains the inputs for one image.
166
+ For now, each item in the list is a dict that contains:
167
+ * "image": Tensor, image in (C, H, W) format.
168
+ * "instances": per-region ground truth
169
+ * Other information that's included in the original dicts, such as:
170
+ "height", "width" (int): the output resolution of the model (may be different
171
+ from input resolution), used in inference.
172
+ Returns:
173
+ list[dict]:
174
+ each dict has the results for one image. The dict contains the following keys:
175
+
176
+ * "sem_seg":
177
+ A Tensor that represents the
178
+ per-pixel segmentation predicted by the head.
179
+ The prediction has shape KxHxW that represents the logits of
180
+ each class for each pixel.
181
+ * "panoptic_seg":
182
+ A tuple that represents the panoptic output
183
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
184
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
185
+ Each dict contains keys "id", "category_id", "isthing".
186
+ """
187
+ return self.forward_base(batched_inputs, keys=["image"], get_train=not self.training,
188
+ get_eval=not self.training)
189
+
190
+ def forward_base(self, batched_inputs, keys, get_train=False, get_eval=False, raw_sem_seg=False):
191
+ for i, key in enumerate(keys):
192
+ images = [x[key].to(self.device) for x in batched_inputs]
193
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
194
+ images = ImageList.from_tensors(images, self.size_divisibility)
195
+ logger.debug_once(f"Maskformer input {key} shape: {images.tensor.shape}")
196
+ out = self.backbone(images.tensor)
197
+ if i == 0:
198
+ features = out
199
+ else:
200
+ features = {k: torch.cat([features[k], v], 1) for k, v in out.items()}
201
+ outputs = self.sem_seg_head(features)
202
+
203
+ if get_train:
204
+ # mask classification target
205
+ if "instances" in batched_inputs[0]:
206
+ gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
207
+ targets = self.prepare_targets(gt_instances, images)
208
+ else:
209
+ targets = None
210
+
211
+ # bipartite matching-based loss
212
+ losses = self.criterion(outputs, targets)
213
+
214
+ for k in list(losses.keys()):
215
+ if k in self.criterion.weight_dict:
216
+ losses[k] *= self.criterion.weight_dict[k]
217
+ else:
218
+ # remove this loss if not specified in `weight_dict`
219
+ losses.pop(k)
220
+ if not get_eval:
221
+ return losses
222
+
223
+ if get_eval:
224
+ # mask_cls_results = outputs["pred_logits"]
225
+ mask_pred_results = outputs["pred_masks"]
226
+ mask_cls_results = mask_pred_results
227
+ logger.debug_once(f"Maskformer mask_pred_results shape: {mask_pred_results.shape}")
228
+ # upsample masks
229
+ # mask_pred_results = interpolate_or_crop(
230
+ # mask_pred_results,
231
+ # size=(images.tensor.shape[-2], images.tensor.shape[-1]),
232
+ # mode="bilinear",
233
+ # align_corners=False,
234
+ # )
235
+
236
+ processed_results = []
237
+ for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
238
+ mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
239
+ ):
240
+
241
+ if raw_sem_seg:
242
+ processed_results.append({"sem_seg": mask_pred_result})
243
+ continue
244
+
245
+ height = input_per_image.get("height", image_size[0])
246
+ width = input_per_image.get("width", image_size[1])
247
+ logger.debug_once(f"Maskformer mask_pred_results target HW: {height, width}")
248
+ r = interpolate_or_crop(mask_pred_result[None], size=(height, width), mode="bilinear", align_corners=False)[0]
249
+
250
+ processed_results.append({"sem_seg": r})
251
+
252
+ # panoptic segmentation inference
253
+ # if self.panoptic_on:
254
+ # panoptic_r = self.panoptic_inference(mask_cls_result, mask_pred_result)
255
+ # processed_results[-1]["panoptic_seg"] = panoptic_r
256
+
257
+ # if 'features' in outputs:
258
+ # features = outputs['features']
259
+ # features = interpolate_or_crop(
260
+ # features,
261
+ # size=(images.tensor.shape[-2], images.tensor.shape[-1]),
262
+ # mode="bilinear",
263
+ # align_corners=False,
264
+ # )
265
+ # for res, f in zip(processed_results, features):
266
+ # res['features'] = f
267
+ del outputs
268
+
269
+ if not get_train:
270
+ return processed_results
271
+
272
+ return losses, processed_results
273
+
274
+
275
+ def prepare_targets(self, targets, images):
276
+ h, w = images.tensor.shape[-2:]
277
+ new_targets = []
278
+ for targets_per_image in targets:
279
+ # pad gt
280
+ gt_masks = targets_per_image.gt_masks
281
+ padded_masks = torch.zeros((gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device)
282
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
283
+ new_targets.append(
284
+ {
285
+ "labels": targets_per_image.gt_classes,
286
+ "masks": padded_masks,
287
+ }
288
+ )
289
+ return new_targets
290
+
291
+
292
+ def semantic_inference(self, mask_cls, mask_pred):
293
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
294
+ mask_pred = mask_pred.sigmoid()
295
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
296
+ return semseg
297
+
298
+
299
+ def panoptic_inference(self, mask_cls, mask_pred):
300
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
301
+ mask_pred = mask_pred.sigmoid()
302
+
303
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
304
+ cur_scores = scores[keep]
305
+ cur_classes = labels[keep]
306
+ cur_masks = mask_pred[keep]
307
+ cur_mask_cls = mask_cls[keep]
308
+ cur_mask_cls = cur_mask_cls[:, :-1]
309
+
310
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
311
+
312
+ h, w = cur_masks.shape[-2:]
313
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
314
+ segments_info = []
315
+
316
+ current_segment_id = 0
317
+
318
+ if cur_masks.shape[0] == 0:
319
+ # We didn't detect any mask :(
320
+ return panoptic_seg, segments_info
321
+ else:
322
+ # take argmax
323
+ cur_mask_ids = cur_prob_masks.argmax(0)
324
+ stuff_memory_list = {}
325
+ for k in range(cur_classes.shape[0]):
326
+ pred_class = cur_classes[k].item()
327
+ isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
328
+ mask = cur_mask_ids == k
329
+ mask_area = mask.sum().item()
330
+ original_area = (cur_masks[k] >= 0.5).sum().item()
331
+
332
+ if mask_area > 0 and original_area > 0:
333
+ if mask_area / original_area < self.overlap_threshold:
334
+ continue
335
+
336
+ # merge stuff regions
337
+ if not isthing:
338
+ if int(pred_class) in stuff_memory_list.keys():
339
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
340
+ continue
341
+ else:
342
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
343
+
344
+ current_segment_id += 1
345
+ panoptic_seg[mask] = current_segment_id
346
+
347
+ segments_info.append(
348
+ {
349
+ "id": current_segment_id,
350
+ "isthing": bool(isthing),
351
+ "category_id": int(pred_class),
352
+ }
353
+ )
354
+
355
+ return panoptic_seg, segments_info
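Note: `semantic_inference` above fuses per-query class probabilities with per-query mask logits through a single einsum. The following is a minimal standalone sketch of that fusion, not part of this commit; all tensor sizes are illustrative.

# Sketch only: the class/mask fusion performed by semantic_inference.
# Q queries, C real classes plus one trailing "no object" logit; sizes are illustrative.
import torch
import torch.nn.functional as F

Q, C, H, W = 100, 20, 64, 64
mask_cls = torch.randn(Q, C + 1)              # per-query class logits (incl. no-object)
mask_pred = torch.randn(Q, H, W)              # per-query mask logits

probs = F.softmax(mask_cls, dim=-1)[..., :-1] # drop the no-object column -> (Q, C)
masks = mask_pred.sigmoid()                   # (Q, H, W)
semseg = torch.einsum("qc,qhw->chw", probs, masks)
print(semseg.shape)                           # torch.Size([20, 64, 64])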
mask_former/modeling/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .backbone.swin import D2SwinTransformer
3
+ from .backbone.vit import D2ViTTransformer
4
+ from .heads.mask_former_head import MaskFormerHead
5
+ from .heads.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead
6
+ from .heads.pixel_decoder import BasePixelDecoder
7
+ from .heads.big_pixel_decoder import BigPixelDecoder
8
+ from .heads.mega_big_pixel_decoder import MegaBigPixelDecoder
9
+ from .heads.mask_former_head_baseline import MaskFormerBaselineHead
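Importing this package registers the backbones and heads above with detectron2's registries, so `MODEL.BACKBONE.NAME` and `MODEL.SEM_SEG_HEAD.NAME` in the YAML configs resolve to these classes. A rough usage sketch follows, assuming this fork keeps the upstream MaskFormer helper `add_mask_former_config` exported from the `mask_former` package; the config values are illustrative, not the project's defaults.

# Sketch only: resolving the registered Swin backbone by name through a
# detectron2 config. `add_mask_former_config` is assumed to be the upstream
# MaskFormer config helper bundled in mask_former/config.py.
from detectron2.config import get_cfg
from detectron2.modeling import build_backbone

from mask_former import add_mask_former_config

cfg = get_cfg()
add_mask_former_config(cfg)                    # adds MODEL.SWIN / MODEL.MASK_FORMER keys
cfg.MODEL.BACKBONE.NAME = "D2SwinTransformer"  # looked up in BACKBONE_REGISTRY
backbone = build_backbone(cfg)
print(backbone.output_shape())                 # ShapeSpecs for res2..res5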
mask_former/modeling/backbone/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
mask_former/modeling/backbone/swin.py ADDED
@@ -0,0 +1,772 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
10
+ import logging
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import torch.utils.checkpoint as checkpoint
17
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18
+
19
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
20
+ logger = logging.getLogger('gwm')
21
+
22
+ class Mlp(nn.Module):
23
+ """Multilayer perceptron."""
24
+
25
+ def __init__(
26
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
27
+ ):
28
+ super().__init__()
29
+ out_features = out_features or in_features
30
+ hidden_features = hidden_features or in_features
31
+ self.fc1 = nn.Linear(in_features, hidden_features)
32
+ self.act = act_layer()
33
+ self.fc2 = nn.Linear(hidden_features, out_features)
34
+ self.drop = nn.Dropout(drop)
35
+
36
+ def forward(self, x):
37
+ x = self.fc1(x)
38
+ x = self.act(x)
39
+ x = self.drop(x)
40
+ x = self.fc2(x)
41
+ x = self.drop(x)
42
+ return x
43
+
44
+
45
+ def window_partition(x, window_size):
46
+ """
47
+ Args:
48
+ x: (B, H, W, C)
49
+ window_size (int): window size
50
+ Returns:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ """
53
+ B, H, W, C = x.shape
54
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
55
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
56
+ return windows
57
+
58
+
59
+ def window_reverse(windows, window_size, H, W):
60
+ """
61
+ Args:
62
+ windows: (num_windows*B, window_size, window_size, C)
63
+ window_size (int): Window size
64
+ H (int): Height of image
65
+ W (int): Width of image
66
+ Returns:
67
+ x: (B, H, W, C)
68
+ """
69
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
70
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
71
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
72
+ return x
73
+
74
+
75
+ class WindowAttention(nn.Module):
76
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
77
+ It supports both of shifted and non-shifted window.
78
+ Args:
79
+ dim (int): Number of input channels.
80
+ window_size (tuple[int]): The height and width of the window.
81
+ num_heads (int): Number of attention heads.
82
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
83
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
84
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
85
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ dim,
91
+ window_size,
92
+ num_heads,
93
+ qkv_bias=True,
94
+ qk_scale=None,
95
+ attn_drop=0.0,
96
+ proj_drop=0.0,
97
+ ):
98
+
99
+ super().__init__()
100
+ self.dim = dim
101
+ self.window_size = window_size # Wh, Ww
102
+ self.num_heads = num_heads
103
+ head_dim = dim // num_heads
104
+ self.scale = qk_scale or head_dim ** -0.5
105
+
106
+ # define a parameter table of relative position bias
107
+ self.relative_position_bias_table = nn.Parameter(
108
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
109
+ ) # 2*Wh-1 * 2*Ww-1, nH
110
+
111
+ # get pair-wise relative position index for each token inside the window
112
+ coords_h = torch.arange(self.window_size[0])
113
+ coords_w = torch.arange(self.window_size[1])
114
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
115
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
116
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
117
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
118
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
119
+ relative_coords[:, :, 1] += self.window_size[1] - 1
120
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
121
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
122
+ self.register_buffer("relative_position_index", relative_position_index)
123
+
124
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
125
+ self.attn_drop = nn.Dropout(attn_drop)
126
+ self.proj = nn.Linear(dim, dim)
127
+ self.proj_drop = nn.Dropout(proj_drop)
128
+
129
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
130
+ self.softmax = nn.Softmax(dim=-1)
131
+
132
+ def forward(self, x, mask=None):
133
+ """Forward function.
134
+ Args:
135
+ x: input features with shape of (num_windows*B, N, C)
136
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
137
+ """
138
+ B_, N, C = x.shape
139
+ qkv = (
140
+ self.qkv(x)
141
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
142
+ .permute(2, 0, 3, 1, 4)
143
+ )
144
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
145
+
146
+ q = q * self.scale
147
+ attn = q @ k.transpose(-2, -1)
148
+
149
+ relative_position_bias = self.relative_position_bias_table[
150
+ self.relative_position_index.view(-1)
151
+ ].view(
152
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
153
+ ) # Wh*Ww,Wh*Ww,nH
154
+ relative_position_bias = relative_position_bias.permute(
155
+ 2, 0, 1
156
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
157
+ attn = attn + relative_position_bias.unsqueeze(0)
158
+
159
+ if mask is not None:
160
+ nW = mask.shape[0]
161
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
162
+ attn = attn.view(-1, self.num_heads, N, N)
163
+ attn = self.softmax(attn)
164
+ else:
165
+ attn = self.softmax(attn)
166
+
167
+ attn = self.attn_drop(attn)
168
+
169
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
170
+ x = self.proj(x)
171
+ x = self.proj_drop(x)
172
+ return x
173
+
174
+
175
+ class SwinTransformerBlock(nn.Module):
176
+ """Swin Transformer Block.
177
+ Args:
178
+ dim (int): Number of input channels.
179
+ num_heads (int): Number of attention heads.
180
+ window_size (int): Window size.
181
+ shift_size (int): Shift size for SW-MSA.
182
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
183
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
184
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
185
+ drop (float, optional): Dropout rate. Default: 0.0
186
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
187
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
188
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
189
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ dim,
195
+ num_heads,
196
+ window_size=7,
197
+ shift_size=0,
198
+ mlp_ratio=4.0,
199
+ qkv_bias=True,
200
+ qk_scale=None,
201
+ drop=0.0,
202
+ attn_drop=0.0,
203
+ drop_path=0.0,
204
+ act_layer=nn.GELU,
205
+ norm_layer=nn.LayerNorm,
206
+ ):
207
+ super().__init__()
208
+ self.dim = dim
209
+ self.num_heads = num_heads
210
+ self.window_size = window_size
211
+ self.shift_size = shift_size
212
+ self.mlp_ratio = mlp_ratio
213
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
214
+
215
+ self.norm1 = norm_layer(dim)
216
+ self.attn = WindowAttention(
217
+ dim,
218
+ window_size=to_2tuple(self.window_size),
219
+ num_heads=num_heads,
220
+ qkv_bias=qkv_bias,
221
+ qk_scale=qk_scale,
222
+ attn_drop=attn_drop,
223
+ proj_drop=drop,
224
+ )
225
+
226
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
227
+ self.norm2 = norm_layer(dim)
228
+ mlp_hidden_dim = int(dim * mlp_ratio)
229
+ self.mlp = Mlp(
230
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
231
+ )
232
+
233
+ self.H = None
234
+ self.W = None
235
+
236
+ def forward(self, x, mask_matrix):
237
+ """Forward function.
238
+ Args:
239
+ x: Input feature, tensor size (B, H*W, C).
240
+ H, W: Spatial resolution of the input feature.
241
+ mask_matrix: Attention mask for cyclic shift.
242
+ """
243
+ B, L, C = x.shape
244
+ H, W = self.H, self.W
245
+ assert L == H * W, "input feature has wrong size"
246
+
247
+ shortcut = x
248
+ x = self.norm1(x)
249
+ x = x.view(B, H, W, C)
250
+
251
+ # pad feature maps to multiples of window size
252
+ pad_l = pad_t = 0
253
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
254
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
255
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
256
+ _, Hp, Wp, _ = x.shape
257
+
258
+ # cyclic shift
259
+ if self.shift_size > 0:
260
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
261
+ attn_mask = mask_matrix
262
+ else:
263
+ shifted_x = x
264
+ attn_mask = None
265
+
266
+ # partition windows
267
+ x_windows = window_partition(
268
+ shifted_x, self.window_size
269
+ ) # nW*B, window_size, window_size, C
270
+ x_windows = x_windows.view(
271
+ -1, self.window_size * self.window_size, C
272
+ ) # nW*B, window_size*window_size, C
273
+
274
+ # W-MSA/SW-MSA
275
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
276
+
277
+ # merge windows
278
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
279
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
280
+
281
+ # reverse cyclic shift
282
+ if self.shift_size > 0:
283
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
284
+ else:
285
+ x = shifted_x
286
+
287
+ if pad_r > 0 or pad_b > 0:
288
+ x = x[:, :H, :W, :].contiguous()
289
+
290
+ x = x.view(B, H * W, C)
291
+
292
+ # FFN
293
+ x = shortcut + self.drop_path(x)
294
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
295
+
296
+ return x
297
+
298
+
299
+ class PatchMerging(nn.Module):
300
+ """Patch Merging Layer
301
+ Args:
302
+ dim (int): Number of input channels.
303
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
304
+ """
305
+
306
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
307
+ super().__init__()
308
+ self.dim = dim
309
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
310
+ self.norm = norm_layer(4 * dim)
311
+
312
+ def forward(self, x, H, W):
313
+ """Forward function.
314
+ Args:
315
+ x: Input feature, tensor size (B, H*W, C).
316
+ H, W: Spatial resolution of the input feature.
317
+ """
318
+ B, L, C = x.shape
319
+ assert L == H * W, "input feature has wrong size"
320
+
321
+ x = x.view(B, H, W, C)
322
+
323
+ # padding
324
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
325
+ if pad_input:
326
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
327
+
328
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
329
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
330
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
331
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
332
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
333
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
334
+
335
+ x = self.norm(x)
336
+ x = self.reduction(x)
337
+
338
+ return x
339
+
340
+
341
+ class BasicLayer(nn.Module):
342
+ """A basic Swin Transformer layer for one stage.
343
+ Args:
344
+ dim (int): Number of feature channels
345
+ depth (int): Depths of this stage.
346
+ num_heads (int): Number of attention head.
347
+ window_size (int): Local window size. Default: 7.
348
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
349
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
350
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
351
+ drop (float, optional): Dropout rate. Default: 0.0
352
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
353
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
354
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
355
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
356
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
357
+ """
358
+
359
+ def __init__(
360
+ self,
361
+ dim,
362
+ depth,
363
+ num_heads,
364
+ window_size=7,
365
+ mlp_ratio=4.0,
366
+ qkv_bias=True,
367
+ qk_scale=None,
368
+ drop=0.0,
369
+ attn_drop=0.0,
370
+ drop_path=0.0,
371
+ norm_layer=nn.LayerNorm,
372
+ downsample=None,
373
+ use_checkpoint=False,
374
+ ):
375
+ super().__init__()
376
+ self.window_size = window_size
377
+ self.shift_size = window_size // 2
378
+ self.depth = depth
379
+ self.use_checkpoint = use_checkpoint
380
+
381
+ # build blocks
382
+ self.blocks = nn.ModuleList(
383
+ [
384
+ SwinTransformerBlock(
385
+ dim=dim,
386
+ num_heads=num_heads,
387
+ window_size=window_size,
388
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
389
+ mlp_ratio=mlp_ratio,
390
+ qkv_bias=qkv_bias,
391
+ qk_scale=qk_scale,
392
+ drop=drop,
393
+ attn_drop=attn_drop,
394
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
395
+ norm_layer=norm_layer,
396
+ )
397
+ for i in range(depth)
398
+ ]
399
+ )
400
+
401
+ # patch merging layer
402
+ if downsample is not None:
403
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
404
+ else:
405
+ self.downsample = None
406
+
407
+ def forward(self, x, H, W):
408
+ """Forward function.
409
+ Args:
410
+ x: Input feature, tensor size (B, H*W, C).
411
+ H, W: Spatial resolution of the input feature.
412
+ """
413
+
414
+ # calculate attention mask for SW-MSA
415
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
416
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
417
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
418
+ h_slices = (
419
+ slice(0, -self.window_size),
420
+ slice(-self.window_size, -self.shift_size),
421
+ slice(-self.shift_size, None),
422
+ )
423
+ w_slices = (
424
+ slice(0, -self.window_size),
425
+ slice(-self.window_size, -self.shift_size),
426
+ slice(-self.shift_size, None),
427
+ )
428
+ cnt = 0
429
+ for h in h_slices:
430
+ for w in w_slices:
431
+ img_mask[:, h, w, :] = cnt
432
+ cnt += 1
433
+
434
+ mask_windows = window_partition(
435
+ img_mask, self.window_size
436
+ ) # nW, window_size, window_size, 1
437
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
438
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
439
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
440
+ attn_mask == 0, float(0.0)
441
+ )
442
+
443
+ for blk in self.blocks:
444
+ blk.H, blk.W = H, W
445
+ if self.use_checkpoint:
446
+ x = checkpoint.checkpoint(blk, x, attn_mask)
447
+ else:
448
+ x = blk(x, attn_mask)
449
+ if self.downsample is not None:
450
+ x_down = self.downsample(x, H, W)
451
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
452
+ return x, H, W, x_down, Wh, Ww
453
+ else:
454
+ return x, H, W, x, H, W
455
+
456
+
457
+ class PatchEmbed(nn.Module):
458
+ """Image to Patch Embedding
459
+ Args:
460
+ patch_size (int): Patch token size. Default: 4.
461
+ in_chans (int): Number of input image channels. Default: 3.
462
+ embed_dim (int): Number of linear projection output channels. Default: 96.
463
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
464
+ """
465
+
466
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
467
+ super().__init__()
468
+ patch_size = to_2tuple(patch_size)
469
+ self.patch_size = patch_size
470
+
471
+ self.in_chans = in_chans
472
+ self.embed_dim = embed_dim
473
+
474
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
475
+ if norm_layer is not None:
476
+ self.norm = norm_layer(embed_dim)
477
+ else:
478
+ self.norm = None
479
+
480
+ def forward(self, x):
481
+ """Forward function."""
482
+ # padding
483
+ _, _, H, W = x.size()
484
+ if W % self.patch_size[1] != 0:
485
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
486
+ if H % self.patch_size[0] != 0:
487
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
488
+
489
+ x = self.proj(x) # B C Wh Ww
490
+ if self.norm is not None:
491
+ Wh, Ww = x.size(2), x.size(3)
492
+ x = x.flatten(2).transpose(1, 2)
493
+ x = self.norm(x)
494
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
495
+
496
+ return x
497
+
498
+
499
+ class SwinTransformer(nn.Module):
500
+ """Swin Transformer backbone.
501
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
502
+ https://arxiv.org/pdf/2103.14030
503
+ Args:
504
+ pretrain_img_size (int): Input image size for training the pretrained model,
505
+ used in absolute postion embedding. Default 224.
506
+ patch_size (int | tuple(int)): Patch size. Default: 4.
507
+ in_chans (int): Number of input image channels. Default: 3.
508
+ embed_dim (int): Number of linear projection output channels. Default: 96.
509
+ depths (tuple[int]): Depths of each Swin Transformer stage.
510
+ num_heads (tuple[int]): Number of attention head of each stage.
511
+ window_size (int): Window size. Default: 7.
512
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
513
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
514
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
515
+ drop_rate (float): Dropout rate.
516
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
517
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
518
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
519
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
520
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
521
+ out_indices (Sequence[int]): Output from which stages.
522
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
523
+ -1 means not freezing any parameters.
524
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
525
+ """
526
+
527
+ def __init__(
528
+ self,
529
+ pretrain_img_size=224,
530
+ patch_size=4,
531
+ in_chans=3,
532
+ embed_dim=96,
533
+ depths=[2, 2, 6, 2],
534
+ num_heads=[3, 6, 12, 24],
535
+ window_size=7,
536
+ mlp_ratio=4.0,
537
+ qkv_bias=True,
538
+ qk_scale=None,
539
+ drop_rate=0.0,
540
+ attn_drop_rate=0.0,
541
+ drop_path_rate=0.2,
542
+ norm_layer=nn.LayerNorm,
543
+ ape=False,
544
+ patch_norm=True,
545
+ out_indices=(0, 1, 2, 3),
546
+ frozen_stages=-1,
547
+ use_checkpoint=False,
548
+ ):
549
+ super().__init__()
550
+
551
+ self.pretrain_img_size = pretrain_img_size
552
+ self.num_layers = len(depths)
553
+ self.embed_dim = embed_dim
554
+ self.ape = ape
555
+ self.patch_norm = patch_norm
556
+ self.out_indices = out_indices
557
+ self.frozen_stages = frozen_stages
558
+
559
+ # split image into non-overlapping patches
560
+ self.patch_embed = PatchEmbed(
561
+ patch_size=patch_size,
562
+ in_chans=in_chans,
563
+ embed_dim=embed_dim,
564
+ norm_layer=norm_layer if self.patch_norm else None,
565
+ )
566
+
567
+ # absolute position embedding
568
+ if self.ape:
569
+ pretrain_img_size = to_2tuple(pretrain_img_size)
570
+ patch_size = to_2tuple(patch_size)
571
+ patches_resolution = [
572
+ pretrain_img_size[0] // patch_size[0],
573
+ pretrain_img_size[1] // patch_size[1],
574
+ ]
575
+
576
+ self.absolute_pos_embed = nn.Parameter(
577
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
578
+ )
579
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
580
+
581
+ self.pos_drop = nn.Dropout(p=drop_rate)
582
+
583
+ # stochastic depth
584
+ dpr = [
585
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
586
+ ] # stochastic depth decay rule
587
+
588
+ # build layers
589
+ self.layers = nn.ModuleList()
590
+ for i_layer in range(self.num_layers):
591
+ layer = BasicLayer(
592
+ dim=int(embed_dim * 2 ** i_layer),
593
+ depth=depths[i_layer],
594
+ num_heads=num_heads[i_layer],
595
+ window_size=window_size,
596
+ mlp_ratio=mlp_ratio,
597
+ qkv_bias=qkv_bias,
598
+ qk_scale=qk_scale,
599
+ drop=drop_rate,
600
+ attn_drop=attn_drop_rate,
601
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
602
+ norm_layer=norm_layer,
603
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
604
+ use_checkpoint=use_checkpoint,
605
+ )
606
+ self.layers.append(layer)
607
+
608
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
609
+ self.num_features = num_features
610
+
611
+ # add a norm layer for each output
612
+ for i_layer in out_indices:
613
+ layer = norm_layer(num_features[i_layer])
614
+ layer_name = f"norm{i_layer}"
615
+ self.add_module(layer_name, layer)
616
+
617
+ self._freeze_stages()
618
+ logger.info(f"Freezing {self.frozen_stages} Layers")
619
+
620
+ def _freeze_stages(self):
621
+ if self.frozen_stages >= 0:
622
+ self.patch_embed.eval()
623
+ for param in self.patch_embed.parameters():
624
+ param.requires_grad = False
625
+
626
+ if self.frozen_stages >= 1 and self.ape:
627
+ self.absolute_pos_embed.requires_grad = False
628
+
629
+ if self.frozen_stages >= 2:
630
+ self.pos_drop.eval()
631
+ for i in range(0, self.frozen_stages - 1):
632
+ m = self.layers[i]
633
+ m.eval()
634
+ for param in m.parameters():
635
+ param.requires_grad = False
636
+
637
+ def init_weights(self, pretrained=None):
638
+ """Initialize the weights in backbone.
639
+ Args:
640
+ pretrained (str, optional): Path to pre-trained weights.
641
+ Defaults to None.
642
+ """
643
+
644
+ def _init_weights(m):
645
+ if isinstance(m, nn.Linear):
646
+ trunc_normal_(m.weight, std=0.02)
647
+ if isinstance(m, nn.Linear) and m.bias is not None:
648
+ nn.init.constant_(m.bias, 0)
649
+ elif isinstance(m, nn.LayerNorm):
650
+ nn.init.constant_(m.bias, 0)
651
+ nn.init.constant_(m.weight, 1.0)
652
+
653
+ def forward(self, x):
654
+ """Forward function."""
655
+ x = self.patch_embed(x)
656
+
657
+ Wh, Ww = x.size(2), x.size(3)
658
+ if self.ape:
659
+ # interpolate the position embedding to the corresponding size
660
+ absolute_pos_embed = F.interpolate(
661
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
662
+ )
663
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
664
+ else:
665
+ x = x.flatten(2).transpose(1, 2)
666
+ x = self.pos_drop(x)
667
+
668
+ outs = {}
669
+ for i in range(self.num_layers):
670
+ layer = self.layers[i]
671
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
672
+
673
+ if i in self.out_indices:
674
+ norm_layer = getattr(self, f"norm{i}")
675
+ x_out = norm_layer(x_out)
676
+
677
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
678
+ outs["res{}".format(i + 2)] = out
679
+
680
+ return outs
681
+
682
+ def train(self, mode=True):
683
+ """Convert the model into training mode while keep layers freezed."""
684
+ super(SwinTransformer, self).train(mode)
685
+ self._freeze_stages()
686
+
687
+
688
+ @BACKBONE_REGISTRY.register()
689
+ class D2SwinTransformer(SwinTransformer, Backbone):
690
+ def __init__(self, cfg, input_shape):
691
+
692
+ pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
693
+ patch_size = cfg.MODEL.SWIN.PATCH_SIZE
694
+ in_chans = 3
695
+ embed_dim = cfg.MODEL.SWIN.EMBED_DIM
696
+ depths = cfg.MODEL.SWIN.DEPTHS
697
+ num_heads = cfg.MODEL.SWIN.NUM_HEADS
698
+ window_size = cfg.MODEL.SWIN.WINDOW_SIZE
699
+ mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
700
+ qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
701
+ qk_scale = cfg.MODEL.SWIN.QK_SCALE
702
+ drop_rate = cfg.MODEL.SWIN.DROP_RATE
703
+ attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
704
+ drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
705
+ norm_layer = nn.LayerNorm
706
+ ape = cfg.MODEL.SWIN.APE
707
+ patch_norm = cfg.MODEL.SWIN.PATCH_NORM
708
+ frozen_stages = cfg.MODEL.BACKBONE.FREEZE_AT
709
+
710
+ super().__init__(
711
+ pretrain_img_size,
712
+ patch_size,
713
+ in_chans,
714
+ embed_dim,
715
+ depths,
716
+ num_heads,
717
+ window_size,
718
+ mlp_ratio,
719
+ qkv_bias,
720
+ qk_scale,
721
+ drop_rate,
722
+ attn_drop_rate,
723
+ drop_path_rate,
724
+ norm_layer,
725
+ ape,
726
+ patch_norm,
727
+ frozen_stages=frozen_stages,
728
+ )
729
+
730
+ self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
731
+
732
+ self._out_feature_strides = {
733
+ "res2": 4,
734
+ "res3": 8,
735
+ "res4": 16,
736
+ "res5": 32,
737
+ }
738
+ self._out_feature_channels = {
739
+ "res2": self.num_features[0],
740
+ "res3": self.num_features[1],
741
+ "res4": self.num_features[2],
742
+ "res5": self.num_features[3],
743
+ }
744
+
745
+ def forward(self, x):
746
+ """
747
+ Args:
748
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
749
+ Returns:
750
+ dict[str->Tensor]: names and the corresponding features
751
+ """
752
+ assert (
753
+ x.dim() == 4
754
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
755
+ outputs = {}
756
+ y = super().forward(x)
757
+ for k in y.keys():
758
+ if k in self._out_features:
759
+ outputs[k] = y[k]
760
+ return outputs
761
+
762
+ def output_shape(self):
763
+ return {
764
+ name: ShapeSpec(
765
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
766
+ )
767
+ for name in self._out_features
768
+ }
769
+
770
+ @property
771
+ def size_divisibility(self):
772
+ return 32
mask_former/modeling/backbone/vit.py ADDED
@@ -0,0 +1,441 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
10
+ import logging
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import torch.utils.checkpoint as checkpoint
17
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18
+
19
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
20
+ logger = logging.getLogger('gwm')
21
+ import argparse
22
+ import torch
23
+ import torchvision.transforms
24
+ from torch import nn
25
+ from torchvision import transforms
26
+ import torch.nn.modules.utils as nn_utils
27
+ import math
28
+ import timm
29
+ import types
30
+ from pathlib import Path
31
+ from typing import Union, List, Tuple
32
+ from PIL import Image
33
+ import einops
34
+
35
+ class ViTExtractor:
36
+ """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT.
37
+
38
+ We use the following notation in the documentation of the module's methods:
39
+ B - batch size
40
+ h - number of heads. usually takes place of the channel dimension in pytorch's convention BxCxHxW
41
+ p - patch size of the ViT. either 8 or 16.
42
+ t - number of tokens. equals the number of patches + 1, e.g. HW / p**2 + 1. Where H and W are the height and width
43
+ of the input image.
44
+ d - the embedding dimension in the ViT.
45
+ """
46
+
47
+ def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Module = None, device: str = 'cuda'):
48
+ """
49
+ :param model_type: A string specifying the type of model to extract from.
50
+ [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 |
51
+ vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]
52
+ :param stride: stride of first convolution layer. small stride -> higher resolution.
53
+ :param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor.
54
+ should be compatible with model_type.
55
+ """
56
+ self.model_type = model_type
57
+ self.device = device
58
+ if model is not None:
59
+ self.model = model
60
+ else:
61
+ self.model = ViTExtractor.create_model(model_type)
62
+
63
+ self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride)
64
+ # self.model.eval()
65
+ self.model.to(self.device)
66
+ self.p = self.model.patch_embed.patch_size
67
+ self.stride = self.model.patch_embed.proj.stride
68
+
69
+ self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5)
70
+ self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5)
71
+
72
+ self._feats = []
73
+ self.hook_handlers = []
74
+ self.load_size = None
75
+ self.num_patches = None
76
+
77
+ @staticmethod
78
+ def create_model(model_type: str) -> nn.Module:
79
+ """
80
+ :param model_type: a string specifying which model to load. [dino_vits8 | dino_vits16 | dino_vitb8 |
81
+ dino_vitb16 | vit_small_patch8_224 | vit_small_patch16_224 | vit_base_patch8_224 |
82
+ vit_base_patch16_224]
83
+ :return: the model
84
+ """
85
+ if 'dino' in model_type:
86
+ model = torch.hub.load('facebookresearch/dino:main', model_type)
87
+ else: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images).
88
+ temp_model = timm.create_model(model_type, pretrained=True)
89
+ model_type_dict = {
90
+ 'vit_small_patch16_224': 'dino_vits16',
91
+ 'vit_small_patch8_224': 'dino_vits8',
92
+ 'vit_base_patch16_224': 'dino_vitb16',
93
+ 'vit_base_patch8_224': 'dino_vitb8'
94
+ }
95
+ model = torch.hub.load('facebookresearch/dino:main', model_type_dict[model_type])
96
+ temp_state_dict = temp_model.state_dict()
97
+ del temp_state_dict['head.weight']
98
+ del temp_state_dict['head.bias']
99
+ model.load_state_dict(temp_state_dict)
100
+ return model
101
+
102
+ @staticmethod
103
+ def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]):
104
+ """
105
+ Creates a method for position encoding interpolation.
106
+ :param patch_size: patch size of the model.
107
+ :param stride_hw: A tuple containing the new height and width stride respectively.
108
+ :return: the interpolation method
109
+ """
110
+ def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
111
+ npatch = x.shape[1] - 1
112
+ N = self.pos_embed.shape[1] - 1
113
+ if npatch == N and w == h:
114
+ return self.pos_embed
115
+ class_pos_embed = self.pos_embed[:, 0]
116
+ patch_pos_embed = self.pos_embed[:, 1:]
117
+ dim = x.shape[-1]
118
+ # compute number of tokens taking stride into account
119
+ w0 = 1 + (w - patch_size) // stride_hw[1]
120
+ h0 = 1 + (h - patch_size) // stride_hw[0]
121
+ assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and
122
+ stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}"""
123
+ # we add a small number to avoid floating point error in the interpolation
124
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
125
+ w0, h0 = w0 + 0.1, h0 + 0.1
126
+ patch_pos_embed = nn.functional.interpolate(
127
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
128
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
129
+ mode='bicubic',
130
+ align_corners=False, recompute_scale_factor=False
131
+ )
132
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
133
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
134
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
135
+
136
+ return interpolate_pos_encoding
137
+
138
+ @staticmethod
139
+ def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module:
140
+ """
141
+ change resolution of model output by changing the stride of the patch extraction.
142
+ :param model: the model to change resolution for.
143
+ :param stride: the new stride parameter.
144
+ :return: the adjusted model
145
+ """
146
+ patch_size = model.patch_embed.patch_size
147
+ if stride == patch_size: # nothing to do
148
+ return model
149
+
150
+ stride = nn_utils._pair(stride)
151
+ assert all([(patch_size // s_) * s_ == patch_size for s_ in
152
+ stride]), f'stride {stride} should divide patch_size {patch_size}'
153
+
154
+ # fix the stride
155
+ model.patch_embed.proj.stride = stride
156
+ # fix the positional encoding code
157
+ model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model)
158
+ return model
159
+
160
+ def preprocess(self, image_path: Union[str, Path],
161
+ load_size: Union[int, Tuple[int, int]] = None) -> Tuple[torch.Tensor, Image.Image]:
162
+ """
163
+ Preprocesses an image before extraction.
164
+ :param image_path: path to image to be extracted.
165
+ :param load_size: optional. Size to resize image before the rest of preprocessing.
166
+ :return: a tuple containing:
167
+ (1) the preprocessed image as a tensor to insert the model of shape BxCxHxW.
168
+ (2) the pil image in relevant dimensions
169
+ """
170
+ pil_image = Image.open(image_path).convert('RGB')
171
+ if load_size is not None:
172
+ pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image)
173
+ prep = transforms.Compose([
174
+ transforms.ToTensor(),
175
+ transforms.Normalize(mean=self.mean, std=self.std)
176
+ ])
177
+ prep_img = prep(pil_image)[None, ...]
178
+ return prep_img, pil_image
179
+
180
+ def _get_hook(self, facet: str):
181
+ """
182
+ generate a hook method for a specific block and facet.
183
+ """
184
+ if facet in ['attn', 'token']:
185
+ def _hook(model, input, output):
186
+ self._feats.append(output)
187
+ return _hook
188
+
189
+ if facet == 'query':
190
+ facet_idx = 0
191
+ elif facet == 'key':
192
+ facet_idx = 1
193
+ elif facet == 'value':
194
+ facet_idx = 2
195
+ else:
196
+ raise TypeError(f"{facet} is not a supported facet.")
197
+
198
+ def _inner_hook(module, input, output):
199
+ input = input[0]
200
+ B, N, C = input.shape
201
+ qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
202
+ self._feats.append(qkv[facet_idx]) #Bxhxtxd
203
+ return _inner_hook
204
+
205
+ def _register_hooks(self, layers: List[int], facet: str) -> None:
206
+ """
207
+ register hook to extract features.
208
+ :param layers: layers from which to extract features.
209
+ :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
210
+ """
211
+ for block_idx, block in enumerate(self.model.blocks):
212
+ if block_idx in layers:
213
+ if facet == 'token':
214
+ self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet)))
215
+ elif facet == 'attn':
216
+ self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet)))
217
+ elif facet in ['key', 'query', 'value']:
218
+ self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet)))
219
+ else:
220
+ raise TypeError(f"{facet} is not a supported facet.")
221
+
222
+ def _unregister_hooks(self) -> None:
223
+ """
224
+ unregisters the hooks. should be called after feature extraction.
225
+ """
226
+ for handle in self.hook_handlers:
227
+ handle.remove()
228
+ self.hook_handlers = []
229
+
230
+ def _extract_features(self, batch: torch.Tensor, layers: List[int] = 11, facet: str = 'key') -> List[torch.Tensor]:
231
+ """
232
+ extract features from the model
233
+ :param batch: batch to extract features for. Has shape BxCxHxW.
234
+ :param layers: layers to extract from. Each is an index between 0 and 11.
235
+ :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
236
+ :return : tensor of features.
237
+ if facet is 'key' | 'query' | 'value' has shape Bxhxtxd
238
+ if facet is 'attn' has shape Bxhxtxt
239
+ if facet is 'token' has shape Bxtxd
240
+ """
241
+ B, C, H, W = batch.shape
242
+ self._feats = []
243
+ self._register_hooks(layers, facet)
244
+ with torch.no_grad():
245
+ _ = self.model(batch)
246
+ self._unregister_hooks()
247
+ self.load_size = (H, W)
248
+ self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1])
249
+ return self._feats
250
+
251
+ def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor:
252
+ """
253
+ create a log-binned descriptor.
254
+ :param x: tensor of features. Has shape Bxhxtxd.
255
+ :param hierarchy: how many bin hierarchies to use.
256
+ """
257
+ B = x.shape[0]
258
+ num_bins = 1 + 8 * hierarchy
259
+
260
+ bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1) # Bx(t-1)x(dxh)
261
+ bin_x = bin_x.permute(0, 2, 1)
262
+ bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1])
263
+ # Bx(dxh)xnum_patches[0]xnum_patches[1]
264
+ sub_desc_dim = bin_x.shape[1]
265
+
266
+ avg_pools = []
267
+ # compute bins of all sizes for all spatial locations.
268
+ for k in range(0, hierarchy):
269
+ # avg pooling with kernel 3**kx3**k
270
+ win_size = 3 ** k
271
+ avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False)
272
+ avg_pools.append(avg_pool(bin_x))
273
+
274
+ bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device)
275
+ for y in range(self.num_patches[0]):
276
+ for x in range(self.num_patches[1]):
277
+ part_idx = 0
278
+ # fill all bins for a spatial location (y, x)
279
+ for k in range(0, hierarchy):
280
+ kernel_size = 3 ** k
281
+ for i in range(y - kernel_size, y + kernel_size + 1, kernel_size):
282
+ for j in range(x - kernel_size, x + kernel_size + 1, kernel_size):
283
+ if i == y and j == x and k != 0:
284
+ continue
285
+ if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]:
286
+ bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
287
+ :, :, i, j]
288
+ else: # handle padding in a more delicate way than zero padding
289
+ temp_i = max(0, min(i, self.num_patches[0] - 1))
290
+ temp_j = max(0, min(j, self.num_patches[1] - 1))
291
+ bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
292
+ :, :, temp_i,
293
+ temp_j]
294
+ part_idx += 1
295
+ bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1)
296
+ # Bx1x(t-1)x(dxh)
297
+ return bin_x
298
+
299
+ def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key',
300
+ bin: bool = False, include_cls: bool = False) -> torch.Tensor:
301
+ """
302
+ extract descriptors from the model
303
+ :param batch: batch to extract descriptors for. Has shape BxCxHxW.
304
+ :param layer: layer to extract. An index between 0 and 11.
305
+ :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token']
306
+ :param bin: apply log binning to the descriptor. default is False.
307
+ :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors.
308
+ """
309
+ assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors.
310
+ choose from ['key' | 'query' | 'value' | 'token'] """
311
+ self._extract_features(batch, [layer], facet)
312
+ x = self._feats[0]
313
+ if facet == 'token':
314
+ x.unsqueeze_(dim=1) #Bx1xtxd
315
+ if not include_cls:
316
+ x = x[:, :, 1:, :] # remove cls token
317
+ else:
318
+ assert not bin, "bin = True and include_cls = True are not supported together, set one of them False."
319
+ if not bin:
320
+ desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh)
321
+ else:
322
+ desc = self._log_bin(x)
323
+ return desc
324
+
325
+ def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor:
326
+ """
327
+ extract saliency maps. The saliency maps are extracted by averaging several attention heads of the CLS token
328
+ in the last layer. All values are then normalized to the range [0, 1].
329
+ :param batch: batch to extract saliency maps for. Has shape BxCxHxW.
330
+ :return: a tensor of saliency maps. has shape Bxt-1
331
+ """
332
+ assert self.model_type == "dino_vits8", f"saliency maps are supported only for dino_vits model_type."
333
+ self._extract_features(batch, [11], 'attn')
334
+ head_idxs = [0, 2, 4, 5]
335
+ curr_feats = self._feats[0] #Bxhxtxt
336
+ cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1) #Bx(t-1)
337
+ temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0]
338
+ cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins) # normalize to range [0,1]
339
+ return cls_attn_maps
340
+
341
+ @BACKBONE_REGISTRY.register()
342
+ class D2ViTTransformer(Backbone):
343
+ def __init__(self, cfg, input_shape):
344
+
345
+ pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
346
+ patch_size = cfg.MODEL.SWIN.PATCH_SIZE
347
+ in_chans = 3
348
+ embed_dim = cfg.MODEL.SWIN.EMBED_DIM
349
+ depths = cfg.MODEL.SWIN.DEPTHS
350
+ num_heads = cfg.MODEL.SWIN.NUM_HEADS
351
+ window_size = cfg.MODEL.SWIN.WINDOW_SIZE
352
+ mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
353
+ qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
354
+ qk_scale = cfg.MODEL.SWIN.QK_SCALE
355
+ drop_rate = cfg.MODEL.SWIN.DROP_RATE
356
+ attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
357
+ drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
358
+ norm_layer = nn.LayerNorm
359
+ ape = cfg.MODEL.SWIN.APE
360
+ patch_norm = cfg.MODEL.SWIN.PATCH_NORM
361
+ frozen_stages = cfg.MODEL.BACKBONE.FREEZE_AT
362
+
363
+ super().__init__()
364
+ self.num_layers = 12
365
+ num_features = [int(embed_dim) for i in range(self.num_layers)]
366
+ self.num_features = num_features
367
+ self.frozen_stages = frozen_stages
368
+ self.extractor = ViTExtractor( model_type='dino_vitb8', stride = 4, model = None, device = cfg.MODEL.DEVICE)
369
+ if self.frozen_stages >= 0:
370
+ for block_idx, block in enumerate(self.extractor.model.blocks):
371
+ if block_idx <= self.frozen_stages:
372
+ block.eval()
373
+ for p in block.parameters():
374
+ p.requires_grad = False
375
+
376
+ for block_idx, block in enumerate(self.extractor.model.blocks):
377
+ if all(p.requires_grad == False for p in block.parameters()):
378
+ print(f"DINO {block_idx} frozen")
379
+
380
+
381
+ self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
382
+
383
+ self._out_feature_strides = {
384
+ "res2": 4,
385
+ "res3": 8,
386
+ "res4": 16,
387
+ "res5": 32,
388
+ }
389
+ self._out_feature_channels = {
390
+ "res2": self.num_features[0],
391
+ "res3": self.num_features[1],
392
+ "res4": self.num_features[2],
393
+ "res5": self.num_features[3],
394
+ }
395
+
396
+ def forward(self, x):
397
+ facet = 'key'
398
+ self.extractor._extract_features(x, [5, 7, 9, 11], facet=facet)
399
+ res2 = self.extractor._feats[0].unsqueeze_(dim=1) # Bx1xtxd
400
+ res3 = self.extractor._feats[1].unsqueeze_(dim=1) # Bx1xtxd
401
+ res4 = self.extractor._feats[2].unsqueeze_(dim=1) # Bx1xtxd
402
+ res5 = self.extractor._feats[3].unsqueeze_(dim=1) # Bx1xtxd
403
+ if facet == 'key':
404
+ res2 = einops.rearrange(res2, 'b c h t d -> b c t (d h)') # Bx1xtxd
405
+ res3 = einops.rearrange(res3, 'b c h t d -> b c t (d h)') # Bx1xtxd
406
+ res4 = einops.rearrange(res4, 'b c h t d -> b c t (d h)') # Bx1xtxd
407
+ res5 = einops.rearrange(res5, 'b c h t d -> b c t (d h)') # Bx1xtxd
408
+
409
+ res2 = res2.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh)
410
+ res3 = res3.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh)
411
+ res4 = res4.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh)
412
+ res5 = res5.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh)
413
+
414
+ res2 = res2[:, :, 1:, :] # remove cls token
415
+ res3 = res3[:, :, 1:, :] # remove cls token
416
+ res4 = res4[:, :, 1:, :] # remove cls token
417
+ res5 = res5[:, :, 1:, :] # remove cls token
418
+
419
+ res2 = res2.reshape(res2.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
420
+ res3 = res3.reshape(res3.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
421
+ res4 = res4.reshape(res4.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
422
+ res5 = res5.reshape(res5.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
423
+
424
+ return {
425
+ "res2": res2,
426
+ "res3": res3,
427
+ "res4": res4,
428
+ "res5": res5,
429
+ }
430
+
431
+ def output_shape(self):
432
+ return {
433
+ name: ShapeSpec(
434
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
435
+ )
436
+ for name in self._out_features
437
+ }
438
+
439
+ @property
440
+ def size_divisibility(self):
441
+ return 32
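`D2ViTTransformer.forward` above converts hooked 'key' facets of shape B×h×t×d into spatial maps by folding heads into channels, dropping the CLS token, and reshaping the remaining tokens onto the patch grid. The following is a standalone sketch of that sequence on a dummy tensor; batch size, head count, and patch grid are assumed values, with the grid standing in for `extractor.num_patches`.

# Sketch only: reshaping a hooked ViT 'key' facet (B, h, t, d) into a
# (B, d*h, H_p, W_p) feature map, mirroring D2ViTTransformer.forward.
import torch
import einops

B, h, d = 2, 12, 64                 # ViT-B: 12 heads x 64 dims per head (assumed)
H_p, W_p = 28, 28                   # assumed patch grid (stands in for extractor.num_patches)
t = H_p * W_p + 1                   # tokens = patches + CLS

feats = torch.randn(B, h, t, d)     # what the 'key' hook collects
x = feats.unsqueeze(1)              # B x 1 x h x t x d
x = einops.rearrange(x, 'b c h t d -> b c t (d h)')                       # fold heads into channels
x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(1)  # B x 1 x t x (d*h)
x = x[:, :, 1:, :]                  # drop the CLS token
x = x.reshape(B, H_p, W_p, -1).permute(0, 3, 1, 2)                        # B x (d*h) x H_p x W_p
print(x.shape)                      # torch.Size([2, 768, 28, 28])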
mask_former/modeling/criterion.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ """
4
+ MaskFormer criterion.
5
+ """
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from detectron2.utils.comm import get_world_size
11
+
12
+ from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list
13
+
14
+
15
+ def dice_loss(inputs, targets, num_masks):
16
+ """
17
+ Compute the DICE loss, similar to generalized IOU for masks
18
+ Args:
19
+ inputs: A float tensor of arbitrary shape.
20
+ The predictions for each example.
21
+ targets: A float tensor with the same shape as inputs. Stores the binary
22
+ classification label for each element in inputs
23
+ (0 for the negative class and 1 for the positive class).
24
+ """
25
+ inputs = inputs.sigmoid()
26
+ inputs = inputs.flatten(1)
27
+ numerator = 2 * (inputs * targets).sum(-1)
28
+ denominator = inputs.sum(-1) + targets.sum(-1)
29
+ loss = 1 - (numerator + 1) / (denominator + 1)
30
+ return loss.sum() / num_masks
31
+
32
+
33
+ def sigmoid_focal_loss(inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2):
34
+ """
35
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
36
+ Args:
37
+ inputs: A float tensor of arbitrary shape.
38
+ The predictions for each example.
39
+ targets: A float tensor with the same shape as inputs. Stores the binary
40
+ classification label for each element in inputs
41
+ (0 for the negative class and 1 for the positive class).
42
+ alpha: (optional) Weighting factor in range (0,1) to balance
43
+ positive vs negative examples. Default = -1 (no weighting).
44
+ gamma: Exponent of the modulating factor (1 - p_t) to
45
+ balance easy vs hard examples.
46
+ Returns:
47
+ Loss tensor
48
+ """
49
+ prob = inputs.sigmoid()
50
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
51
+ p_t = prob * targets + (1 - prob) * (1 - targets)
52
+ loss = ce_loss * ((1 - p_t) ** gamma)
53
+
54
+ if alpha >= 0:
55
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
56
+ loss = alpha_t * loss
57
+
58
+ return loss.mean(1).sum() / num_masks
59
+
60
+
61
+ class SetCriterion(nn.Module):
62
+ """This class computes the loss for DETR.
63
+ The process happens in two steps:
64
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
65
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
66
+ """
67
+
68
+ def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
69
+ """Create the criterion.
70
+ Parameters:
71
+ num_classes: number of object categories, omitting the special no-object category
72
+ matcher: module able to compute a matching between targets and proposals
73
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
74
+ eos_coef: relative classification weight applied to the no-object category
75
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
76
+ """
77
+ super().__init__()
78
+ self.num_classes = num_classes
79
+ self.matcher = matcher
80
+ self.weight_dict = weight_dict
81
+ self.eos_coef = eos_coef
82
+ self.losses = losses
83
+ empty_weight = torch.ones(self.num_classes + 1)
84
+ empty_weight[-1] = self.eos_coef
85
+ self.register_buffer("empty_weight", empty_weight)
86
+
87
+ def loss_labels(self, outputs, targets, indices, num_masks):
88
+ """Classification loss (NLL)
89
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
90
+ """
91
+ assert "pred_logits" in outputs
92
+ src_logits = outputs["pred_logits"]
93
+
94
+ idx = self._get_src_permutation_idx(indices)
95
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
96
+ target_classes = torch.full(
97
+ src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
98
+ )
99
+ target_classes[idx] = target_classes_o
100
+
101
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
102
+ losses = {"loss_ce": loss_ce}
103
+ return losses
104
+
105
+ def loss_masks(self, outputs, targets, indices, num_masks):
106
+ """Compute the losses related to the masks: the focal loss and the dice loss.
107
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
108
+ """
109
+ assert "pred_masks" in outputs
110
+
111
+ src_idx = self._get_src_permutation_idx(indices)
112
+ tgt_idx = self._get_tgt_permutation_idx(indices)
113
+ src_masks = outputs["pred_masks"]
114
+ src_masks = src_masks[src_idx]
115
+ masks = [t["masks"] for t in targets]
116
+ # TODO use valid to mask invalid areas due to padding in loss
117
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
118
+ target_masks = target_masks.to(src_masks)
119
+ target_masks = target_masks[tgt_idx]
120
+
121
+ # upsample predictions to the target size
122
+ src_masks = F.interpolate(
123
+ src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
124
+ )
125
+ src_masks = src_masks[:, 0].flatten(1)
126
+
127
+ target_masks = target_masks.flatten(1)
128
+ target_masks = target_masks.view(src_masks.shape)
129
+ losses = {
130
+ "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
131
+ "loss_dice": dice_loss(src_masks, target_masks, num_masks),
132
+ }
133
+ return losses
134
+
135
+ def _get_src_permutation_idx(self, indices):
136
+ # permute predictions following indices
137
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
138
+ src_idx = torch.cat([src for (src, _) in indices])
139
+ return batch_idx, src_idx
140
+
141
+ def _get_tgt_permutation_idx(self, indices):
142
+ # permute targets following indices
143
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
144
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
145
+ return batch_idx, tgt_idx
146
+
147
+ def get_loss(self, loss, outputs, targets, indices, num_masks):
148
+ loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
149
+ assert loss in loss_map, f"do you really want to compute {loss} loss?"
150
+ return loss_map[loss](outputs, targets, indices, num_masks)
151
+
152
+ def forward(self, outputs, targets):
153
+ """This performs the loss computation.
154
+ Parameters:
155
+ outputs: dict of tensors, see the output specification of the model for the format
156
+ targets: list of dicts, such that len(targets) == batch_size.
157
+ The expected keys in each dict depends on the losses applied, see each loss' doc
158
+ """
159
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
160
+
161
+ # Retrieve the matching between the outputs of the last layer and the targets
162
+ indices = self.matcher(outputs_without_aux, targets)
163
+
164
+ # Compute the average number of target boxes accross all nodes, for normalization purposes
165
+ num_masks = sum(len(t["labels"]) for t in targets)
166
+ num_masks = torch.as_tensor(
167
+ [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
168
+ )
169
+ if is_dist_avail_and_initialized():
170
+ torch.distributed.all_reduce(num_masks)
171
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
172
+
173
+ # Compute all the requested losses
174
+ losses = {}
175
+ for loss in self.losses:
176
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
177
+
178
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
179
+ if "aux_outputs" in outputs:
180
+ for i, aux_outputs in enumerate(outputs["aux_outputs"]):
181
+ indices = self.matcher(aux_outputs, targets)
182
+ for loss in self.losses:
183
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
184
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
185
+ losses.update(l_dict)
186
+
187
+ return losses
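
The two mask losses above reduce to a handful of tensor operations. A standalone sketch with dummy tensors (no Hungarian matching, `num_masks` fixed by hand) shows how both terms are normalised by the number of matched masks, mirroring `dice_loss` and `sigmoid_focal_loss`:

```python
import torch
import torch.nn.functional as F

# Two predicted masks (already matched to two targets), flattened to B x num_pixels.
logits = torch.randn(2, 64 * 64)
targets = (torch.rand(2, 64 * 64) > 0.5).float()
num_masks = 2  # in SetCriterion this is averaged across GPUs and clamped to >= 1

# Dice term, matching dice_loss above (with the +1 smoothing in numerator/denominator).
probs = logits.sigmoid()
numerator = 2 * (probs * targets).sum(-1)
denominator = probs.sum(-1) + targets.sum(-1)
loss_dice = (1 - (numerator + 1) / (denominator + 1)).sum() / num_masks

# Focal term, matching sigmoid_focal_loss above with alpha=0.25, gamma=2.
ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
p_t = probs * targets + (1 - probs) * (1 - targets)
loss_mask = (0.25 * targets + 0.75 * (1 - targets)) * ce * (1 - p_t) ** 2
loss_mask = loss_mask.mean(1).sum() / num_masks

print(loss_dice.item(), loss_mask.item())
```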
mask_former/modeling/heads/__init__.py ADDED
@@ -0,0 +1 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
mask_former/modeling/heads/big_pixel_decoder.py ADDED
@@ -0,0 +1,228 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import logging
+ from typing import Callable, Dict, Optional, Union
+
+ import fvcore.nn.weight_init as weight_init
+ from detectron2.config import configurable
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+ from torch import nn
+ from torch.nn import functional as F
+
+ from ..transformer.position_encoding import PositionEmbeddingSine
+ from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer
+
+ @SEM_SEG_HEADS_REGISTRY.register()
+ class BigPixelDecoder(nn.Module):
+     @configurable
+     def __init__(
+         self,
+         input_shape: Dict[str, ShapeSpec],
+         *,
+         conv_dim: int,
+         mask_dim: int,
+         norm: Optional[Union[str, Callable]] = None,
+     ):
+         """
+         NOTE: this interface is experimental.
+         Args:
+             input_shape: shapes (channels and stride) of the input features
+             conv_dims: number of output channels for the intermediate conv layers.
+             mask_dim: number of output channels for the final conv layer.
+             norm (str or callable): normalization for all conv layers
+         """
+         super().__init__()
+
+         input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+         self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
+         feature_channels = [v.channels for k, v in input_shape]
+
+         lateral_convs = []
+         output_convs = []
+
+         use_bias = norm == ""
+         for idx, in_channels in enumerate(feature_channels):
+             if idx == len(self.in_features) - 1:
+                 output_norm = get_norm(norm, conv_dim)
+                 output_conv = Conv2d(
+                     in_channels,
+                     conv_dim,
+                     kernel_size=3,
+                     stride=1,
+                     padding=1,
+                     bias=use_bias,
+                     norm=output_norm,
+                     activation=F.relu,
+                 )
+                 weight_init.c2_xavier_fill(output_conv)
+                 self.add_module("layer_{}".format(idx + 1), output_conv)
+
+                 lateral_convs.append(None)
+                 output_convs.append(output_conv)
+             else:
+                 lateral_norm = get_norm(norm, conv_dim)
+                 output_norm = get_norm(norm, conv_dim)
+
+                 lateral_conv = Conv2d(
+                     in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
+                 )
+                 output_conv = Conv2d(
+                     conv_dim,
+                     conv_dim,
+                     kernel_size=3,
+                     stride=1,
+                     padding=1,
+                     bias=use_bias,
+                     norm=output_norm,
+                     activation=F.relu,
+                 )
+                 weight_init.c2_xavier_fill(lateral_conv)
+                 weight_init.c2_xavier_fill(output_conv)
+                 self.add_module("adapter_{}".format(idx + 1), lateral_conv)
+                 self.add_module("layer_{}".format(idx + 1), output_conv)
+
+                 lateral_convs.append(lateral_conv)
+                 output_convs.append(output_conv)
+         # Place convs into top-down order (from low to high resolution)
+         # to make the top-down computation in forward clearer.
+         self.lateral_convs = lateral_convs[::-1]
+         self.output_convs = output_convs[::-1]
+
+         self.mask_dim = mask_dim
+         # self.mask_features = Conv2d(
+         #     conv_dim,
+         #     mask_dim,
+         #     kernel_size=3,
+         #     stride=1,
+         #     padding=1,
+         # )
+
+         # weight_init.c2_xavier_fill(self.mask_features)
+
+         self.mask_features = nn.Sequential(
+             Conv2d(
+                 conv_dim,
+                 conv_dim,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 bias=use_bias,
+                 norm=output_norm,
+                 activation=F.relu,
+             ),
+             nn.UpsamplingNearest2d(scale_factor=2),
+             Conv2d(
+                 conv_dim,
+                 conv_dim,
+                 kernel_size=1,
+                 stride=1,
+                 padding=1,
+                 bias=use_bias,
+                 norm=output_norm,
+                 activation=F.relu,
+             ),
+             Conv2d(
+                 conv_dim,
+                 conv_dim,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 bias=use_bias,
+                 norm=output_norm,
+                 activation=F.relu,
+             ),
+             nn.UpsamplingNearest2d(scale_factor=2),
+             Conv2d(
+                 conv_dim,
+                 conv_dim,
+                 kernel_size=1,
+                 stride=1,
+                 padding=1,
+                 bias=use_bias,
+                 norm=output_norm,
+                 activation=F.relu,
+             ),
+             Conv2d(
+                 conv_dim,
+                 mask_dim,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+             )
+         )
+
+         for name, module in self.mask_features.named_modules():
+             if 'Conv2d' in name:
+                 weight_init.c2_xavier_fill(module)
+
+     @classmethod
+     def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+         ret = {}
+         ret["input_shape"] = {
+             k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+         }
+         ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+         ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+         ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
+         return ret
+
+     def forward_features(self, features):
+         # Reverse feature maps into top-down order (from low to high resolution)
+         for idx, f in enumerate(self.in_features[::-1]):
+             x = features[f]
+             lateral_conv = self.lateral_convs[idx]
+             output_conv = self.output_convs[idx]
+             if lateral_conv is None:
+                 y = output_conv(x)
+             else:
+                 cur_fpn = lateral_conv(x)
+                 # Following FPN implementation, we use nearest upsampling here
+                 y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
+                 y = output_conv(y)
+         return self.mask_features(y), None
+
+     def forward(self, features, targets=None):
+         logger = logging.getLogger(__name__)
+         logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
+         return self.forward_features(features)
+
+
+ class TransformerEncoderOnly(nn.Module):
+     def __init__(
+         self,
+         d_model=512,
+         nhead=8,
+         num_encoder_layers=6,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation="relu",
+         normalize_before=False,
+     ):
+         super().__init__()
+
+         encoder_layer = TransformerEncoderLayer(
+             d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+         )
+         encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+         self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+         self._reset_parameters()
+
+         self.d_model = d_model
+         self.nhead = nhead
+
+     def _reset_parameters(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def forward(self, src, mask, pos_embed):
+         # flatten NxCxHxW to HWxNxC
+         bs, c, h, w = src.shape
+         src = src.flatten(2).permute(2, 0, 1)
+         pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+         if mask is not None:
+             mask = mask.flatten(1)
+
+         memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+         return memory.permute(1, 2, 0).view(bs, c, h, w)
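
`forward_features` is an FPN-style top-down pass: the coarsest map is projected once, and each finer level contributes a 1x1 lateral that is summed with a nearest-neighbour upsampling of the running map before a 3x3 output conv. A minimal sketch with dummy tensors and hypothetical channel counts (not the configured ones):

```python
import torch
import torch.nn.functional as F
from torch import nn

conv_dim = 256
# Hypothetical multi-scale features (strides 4, 8, 16, 32 for a 224x224 input).
feats = {
    "res2": torch.randn(1, 96, 56, 56),
    "res3": torch.randn(1, 192, 28, 28),
    "res4": torch.randn(1, 384, 14, 14),
    "res5": torch.randn(1, 768, 7, 7),
}

# Stand-ins for the decoder's convs: the coarsest level gets a 3x3 conv directly,
# every other level gets a 1x1 lateral plus a 3x3 output conv.
top_conv = nn.Conv2d(feats["res5"].shape[1], conv_dim, 3, padding=1)
laterals = {k: nn.Conv2d(feats[k].shape[1], conv_dim, 1) for k in ("res4", "res3", "res2")}
outputs = {k: nn.Conv2d(conv_dim, conv_dim, 3, padding=1) for k in ("res4", "res3", "res2")}

y = top_conv(feats["res5"])
for name in ("res4", "res3", "res2"):  # top-down order, as in forward_features
    lateral = laterals[name](feats[name])
    y = outputs[name](lateral + F.interpolate(y, size=lateral.shape[-2:], mode="nearest"))

print(y.shape)  # stride-4 map; mask_features then refines and upsamples it further
```

The `mask_features` stack appears to be what makes this decoder "big": its two extra `UpsamplingNearest2d` stages push the mask features beyond the single stride-4 projection of the plain pixel decoder.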
mask_former/modeling/heads/mask_former_head.py ADDED
@@ -0,0 +1,120 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import logging
+ from copy import deepcopy
+ from typing import Callable, Dict, List, Optional, Tuple, Union
+
+ import fvcore.nn.weight_init as weight_init
+ from torch import nn
+ from torch.nn import functional as F
+
+ from detectron2.config import configurable
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+ from ..transformer.transformer_predictor import TransformerPredictor
+ from .pixel_decoder import build_pixel_decoder
+
+
+ @SEM_SEG_HEADS_REGISTRY.register()
+ class MaskFormerHead(nn.Module):
+
+     _version = 2
+
+     def _load_from_state_dict(
+         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+     ):
+         version = local_metadata.get("version", None)
+         if version is None or version < 2:
+             # Do not warn if train from scratch
+             scratch = True
+             logger = logging.getLogger(__name__)
+             for k in list(state_dict.keys()):
+                 newk = k
+                 if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
+                     newk = k.replace(prefix, prefix + "pixel_decoder.")
+                     # logger.debug(f"{k} ==> {newk}")
+                 if newk != k:
+                     state_dict[newk] = state_dict[k]
+                     del state_dict[k]
+                     scratch = False
+
+             if not scratch:
+                 logger.warning(
+                     f"Weight format of {self.__class__.__name__} has changed! "
+                     "Please upgrade your models. Applying automatic conversion now ..."
+                 )
+
+     @configurable
+     def __init__(
+         self,
+         input_shape: Dict[str, ShapeSpec],
+         *,
+         num_classes: int,
+         pixel_decoder: nn.Module,
+         loss_weight: float = 1.0,
+         ignore_value: int = -1,
+         # extra parameters
+         transformer_predictor: nn.Module,
+         transformer_in_feature: str,
+     ):
+         """
+         NOTE: this interface is experimental.
+         Args:
+             input_shape: shapes (channels and stride) of the input features
+             num_classes: number of classes to predict
+             pixel_decoder: the pixel decoder module
+             loss_weight: loss weight
+             ignore_value: category id to be ignored during training.
+             transformer_predictor: the transformer decoder that makes prediction
+             transformer_in_feature: input feature name to the transformer_predictor
+         """
+         super().__init__()
+         input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+         self.in_features = [k for k, v in input_shape]
+         feature_strides = [v.stride for k, v in input_shape]
+         feature_channels = [v.channels for k, v in input_shape]
+
+         self.ignore_value = ignore_value
+         self.common_stride = 4
+         self.loss_weight = loss_weight
+
+         self.pixel_decoder = pixel_decoder
+         self.predictor = transformer_predictor
+         self.transformer_in_feature = transformer_in_feature
+
+         self.num_classes = num_classes
+
+     @classmethod
+     def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+         return {
+             "input_shape": {
+                 k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+             },
+             "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+             "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
+             "pixel_decoder": build_pixel_decoder(cfg, input_shape),
+             "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
+             "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
+             "transformer_predictor": TransformerPredictor(
+                 cfg,
+                 cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+                 if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
+                 else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
+                 mask_classification=True,
+             ),
+         }
+
+     def forward(self, features):
+         return self.layers(features)
+
+     def layers(self, features):
+         mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
+         if self.transformer_in_feature == "transformer_encoder":
+             assert (
+                 transformer_encoder_features is not None
+             ), "Please use the TransformerEncoderPixelDecoder."
+             predictions = self.predictor(transformer_encoder_features, mask_features)
+         else:
+             predictions = self.predictor(features[self.transformer_in_feature], mask_features)
+         # predictions['features'] = mask_features
+         return predictions
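
The head itself only returns the predictor's `pred_logits`/`pred_masks` dictionary; combining them into a dense map happens in the model (see `mask_former_model.py` in this commit). As a rough illustration of the usual MaskFormer-style combination, with illustrative sizes rather than the configured ones:

```python
import torch

num_queries, num_classes, H, W = 100, 2, 64, 64

# Dummy head outputs: per-query class logits (incl. the no-object class) and mask logits.
pred_logits = torch.randn(1, num_queries, num_classes + 1)
pred_masks = torch.randn(1, num_queries, H, W)

# Drop the no-object column, weight each query's mask by its class probability,
# and sum over queries to get a per-class map.
mask_cls = pred_logits.softmax(-1)[..., :-1]   # B x q x c
mask_prob = pred_masks.sigmoid()               # B x q x H x W
semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_prob)

print(semseg.shape)  # torch.Size([1, 2, 64, 64])
```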
mask_former/modeling/heads/mask_former_head_baseline.py ADDED
@@ -0,0 +1,123 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import logging
+ from copy import deepcopy
+ from typing import Callable, Dict, List, Optional, Tuple, Union
+
+ import fvcore.nn.weight_init as weight_init
+ from torch import nn
+ import torch
+ from torch.nn import functional as F
+
+ from detectron2.config import configurable
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+ from ..transformer.transformer_predictor import TransformerPredictor
+ from .pixel_decoder import build_pixel_decoder
+
+
+ @SEM_SEG_HEADS_REGISTRY.register()
+ class MaskFormerBaselineHead(nn.Module):
+
+     _version = 2
+
+     def _load_from_state_dict(
+         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+     ):
+         version = local_metadata.get("version", None)
+         if version is None or version < 2:
+             # Do not warn if train from scratch
+             scratch = True
+             logger = logging.getLogger(__name__)
+             for k in list(state_dict.keys()):
+                 newk = k
+                 if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
+                     newk = k.replace(prefix, prefix + "pixel_decoder.")
+                     # logger.debug(f"{k} ==> {newk}")
+                 if newk != k:
+                     state_dict[newk] = state_dict[k]
+                     del state_dict[k]
+                     scratch = False
+
+             if not scratch:
+                 logger.warning(
+                     f"Weight format of {self.__class__.__name__} has changed! "
+                     "Please upgrade your models. Applying automatic conversion now ..."
+                 )
+
+     @configurable
+     def __init__(
+         self,
+         input_shape: Dict[str, ShapeSpec],
+         *,
+         num_classes: int,
+         pixel_decoder: nn.Module,
+         loss_weight: float = 1.0,
+         ignore_value: int = -1,
+         # extra parameters
+         transformer_predictor: nn.Module,
+         transformer_in_feature: str,
+     ):
+         """
+         NOTE: this interface is experimental.
+         Args:
+             input_shape: shapes (channels and stride) of the input features
+             num_classes: number of classes to predict
+             pixel_decoder: the pixel decoder module
+             loss_weight: loss weight
+             ignore_value: category id to be ignored during training.
+             transformer_predictor: the transformer decoder that makes prediction
+             transformer_in_feature: input feature name to the transformer_predictor
+         """
+         super().__init__()
+         input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+         self.in_features = [k for k, v in input_shape]
+         feature_strides = [v.stride for k, v in input_shape]
+         feature_channels = [v.channels for k, v in input_shape]
+
+         self.ignore_value = ignore_value
+         self.common_stride = 4
+         self.loss_weight = loss_weight
+
+         self.pixel_decoder = pixel_decoder
+         self.predictor = transformer_predictor
+         self.transformer_in_feature = transformer_in_feature
+         inc = 256
+         self.out_layers = nn.Sequential(nn.Conv2d(inc, inc, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d((1,1)), nn.Flatten(), nn.Linear(inc, 1))
+         self.num_classes = num_classes
+
+     @classmethod
+     def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+         return {
+             "input_shape": {
+                 k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+             },
+             "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+             "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
+             "pixel_decoder": build_pixel_decoder(cfg, input_shape),
+             "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
+             "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
+             "transformer_predictor": TransformerPredictor(
+                 cfg,
+                 cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+                 if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
+                 else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
+                 mask_classification=True,
+             ),
+         }
+
+     def forward(self, features):
+         f = self.layers(features)
+
+         return self.out_layers(f).squeeze(-1)
+
+     def layers(self, features):
+         mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
+         # if self.transformer_in_feature == "transformer_encoder":
+         #     assert (
+         #         transformer_encoder_features is not None
+         #     ), "Please use the TransformerEncoderPixelDecoder."
+         #     predictions = self.predictor(transformer_encoder_features, mask_features)
+         # else:
+         #     predictions = self.predictor(features[self.transformer_in_feature], mask_features)
+         return mask_features
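
Unlike `MaskFormerHead`, this baseline head bypasses the transformer predictor entirely: `layers` returns only the pixel-decoder features, and `out_layers` pools them into a single scalar per image. A sketch of the same stack on dummy inputs:

```python
import torch
from torch import nn

inc = 256
out_layers = nn.Sequential(
    nn.Conv2d(inc, inc, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),  # global average pool -> B x inc x 1 x 1
    nn.Flatten(),                  # -> B x inc
    nn.Linear(inc, 1),             # -> B x 1
)

mask_features = torch.randn(4, inc, 96, 96)    # dummy pixel-decoder output
score = out_layers(mask_features).squeeze(-1)  # one scalar per image
print(score.shape)  # torch.Size([4])
```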