Code Commit

Changed files (this view is limited to 50 files because the commit contains too many changes):
- .gitattributes +2 -0
- README.md +5 -5
- app.py +158 -0
- checkpoints/checkpoint_best.pth +3 -0
- config.py +377 -0
- configs/README.md +16 -0
- configs/maskformer/Base-ADE20K-150.yaml +60 -0
- configs/maskformer/Base-unsup-vidseg.yaml +3 -0
- configs/maskformer/cityscapes-19/Base-Cityscapes-19.yaml +59 -0
- configs/maskformer/cityscapes-19/maskformer_R101_bs16_90k.yaml +36 -0
- configs/maskformer/cityscapes-19/maskformer_R101c_bs16_90k.yaml +16 -0
- configs/maskformer/maskformer_R50_bs16_160k.yaml +27 -0
- configs/maskformer/maskformer_R50_bs16_160k_dino.yaml +31 -0
- configs/maskformer/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml +45 -0
- configs/maskformer/swin/maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml +45 -0
- configs/maskformer/swin/maskformer_swin_small_bs16_160k.yaml +23 -0
- configs/maskformer/swin/maskformer_swin_tiny_bs16_160k.yaml +23 -0
- configs/maskformer/swin/maskformer_swin_tiny_bs16_160k_moby.yaml +23 -0
- datasets/__init__.py +2 -0
- datasets/flow_eval_detectron.py +209 -0
- datasets/flow_pair_detectron.py +275 -0
- determinism.py +24 -0
- dist.py +34 -0
- eval_utils.py +282 -0
- flow_reconstruction.py +54 -0
- losses/__init__.py +28 -0
- losses/reconstruction_loss.py +85 -0
- main.py +270 -0
- mask_former/__init__.py +19 -0
- mask_former/config.py +85 -0
- mask_former/data/__init__.py +2 -0
- mask_former/data/dataset_mappers/__init__.py +1 -0
- mask_former/data/dataset_mappers/detr_panoptic_dataset_mapper.py +180 -0
- mask_former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py +165 -0
- mask_former/data/dataset_mappers/mask_former_semantic_dataset_mapper.py +184 -0
- mask_former/data/datasets/__init__.py +7 -0
- mask_former/data/datasets/register_ade20k_full.py +964 -0
- mask_former/data/datasets/register_ade20k_panoptic.py +387 -0
- mask_former/data/datasets/register_coco_stuff_10k.py +223 -0
- mask_former/data/datasets/register_mapillary_vistas.py +507 -0
- mask_former/mask_former_model.py +355 -0
- mask_former/modeling/__init__.py +9 -0
- mask_former/modeling/backbone/__init__.py +1 -0
- mask_former/modeling/backbone/swin.py +772 -0
- mask_former/modeling/backbone/vit.py +441 -0
- mask_former/modeling/criterion.py +187 -0
- mask_former/modeling/heads/__init__.py +1 -0
- mask_former/modeling/heads/big_pixel_decoder.py +228 -0
- mask_former/modeling/heads/mask_former_head.py +120 -0
- mask_former/modeling/heads/mask_former_head_baseline.py +123 -0
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/checkpoint_best.pth filter=lfs diff=lfs merge=lfs -text
+samples/1920px-Woman_at_work,_Gujarat.jpg filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: GWM
+emoji: 🏄
+colorFrom: purple
+colorTo: red
 sdk: gradio
 sdk_version: 3.17.0
 app_file: app.py
@@ -10,4 +10,4 @@ pinned: false
 license: mit
 ---
-
+This is a demo for https://www.robots.ox.ac.uk/~vgg/research/gwm/.
app.py
ADDED
@@ -0,0 +1,158 @@
#!/usr/bin/env python

import os

try:
    import detectron2
except ImportError:
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')

import logging
logging.disable(logging.CRITICAL)  # comment out to enable verbose logging

#########################################################
import pathlib
from types import SimpleNamespace

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer

import utils as ut
from eval_utils import MaskMerger
from mask_former_trainer import setup, Trainer


def load_model_cfg(dataset=None):
    args = SimpleNamespace(config_file='configs/maskformer/maskformer_R50_bs16_160k_dino.yaml',
                           opts=["GWM.DATASET", dataset],
                           wandb_sweep_mode=False,
                           resume_path='checkpoints/checkpoint_best.pth',
                           eval_only=True)
    cfg = setup(args)
    cfg.defrost()
    cfg.MODEL.DEVICE = 'cpu'  # the demo runs on CPU
    cfg.freeze()
    random_state = ut.random_state.PytorchRNGState(seed=cfg.SEED).to(torch.device(cfg.MODEL.DEVICE))

    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model,
                                         random_state=random_state,
                                         save_dir=None)

    checkpoint_path = 'checkpoints/checkpoint_best.pth'
    checkpointer.resume_or_load(checkpoint_path, resume=False)
    model.eval()

    return model, cfg


def edgeness(masks):
    # Score each mask by its overlap with 2-pixel strips along the four image
    # borders; the highest-scoring mask is treated as the background.
    em = torch.zeros(1, masks.shape[-2], masks.shape[-1], device=masks.device)
    lm = em.clone()
    lm[..., :2] = 1.
    rm = em.clone()
    rm[..., -2:] = 1.
    tm = em.clone()
    tm[..., :2, :] = 1.
    bm = em.clone()
    bm[..., -2:, :] = 1.

    one = torch.tensor(1., dtype=masks.dtype, device=masks.device)

    # Per-side border coverage, saturated at 0.3 so no single side dominates.
    l = (masks * lm).flatten(-2).sum(-1) / lm.sum()
    l = torch.where(l > 0.3, one, l)
    r = (masks * rm).flatten(-2).sum(-1) / rm.sum()
    r = torch.where(r > 0.3, one, r)
    t = (masks * tm).flatten(-2).sum(-1) / tm.sum()
    t = torch.where(t > 0.3, one, t)
    b = (masks * bm).flatten(-2).sum(-1) / bm.sum()
    b = torch.where(b > 0.3, one, b)
    return l + r + t + b


def expand2sizedivisible(pil_img, background_color, size_divisibility):
    # Pad the image so both sides are multiples of size_divisibility, centring the original.
    width, height = pil_img.size
    if width % size_divisibility == 0 and height % size_divisibility == 0:
        return pil_img
    pad_w = (size_divisibility - width % size_divisibility) % size_divisibility
    pad_h = (size_divisibility - height % size_divisibility) % size_divisibility
    result = Image.new(pil_img.mode, (width + pad_w, height + pad_h), background_color)
    result.paste(pil_img, (pad_w // 2, pad_h // 2))
    return result


def cropfromsizedivisible(img, size_divisibility, orig_size):
    # Undo the centred padding applied by expand2sizedivisible.
    height, width = img.shape[:2]
    owidth, oheight = orig_size
    result = img[(height - oheight) // 2:oheight + (height - oheight) // 2,
                 (width - owidth) // 2:owidth + (width - owidth) // 2]
    return result


def evaluate_image(image_path):
    binary_threshold = 0.5
    metadata = MetadataCatalog.get("__unused")

    model, cfg = load_model_cfg("DAVIS")
    merger = MaskMerger(cfg, model, merger_model="dino_vitb8")

    image_pil = Image.open(image_path).convert('RGB')
    image_pil.thumbnail((384, 384))

    osize = image_pil.size
    if cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY > 0:
        image_pil = expand2sizedivisible(image_pil, 0, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY)

    image = np.asarray(image_pil)
    image_pt = torch.from_numpy(image).permute(2, 0, 1)

    with torch.no_grad():
        sample = [{'rgb': image_pt}]
        preds = model.forward_base(sample, keys=['rgb'], get_eval=True)
        masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)

        K = masks_raw.shape[1]
        if K > 2:
            # More than two queries: merge them down to foreground/background with DINO features.
            masks_softmaxed = torch.softmax(masks_raw, dim=1)
            masks_dict = merger(sample, masks_softmaxed)
            K = 2
            masks = masks_dict['cos']
        else:
            masks = masks_raw.softmax(1)
        masks_raw = F.interpolate(masks, size=(image_pt.shape[-2], image_pt.shape[-1]), mode='bilinear')  # t s 1 h w
        bg = edgeness(masks_raw)[0].argmax().item()  # index of the background mask

        masks = masks_raw[0] > binary_threshold
        frame_visualizer = Visualizer(image, metadata)
        out = frame_visualizer.overlay_instances(
            masks=masks[[int(bg == 0)]],  # overlay the non-background mask only
            alpha=0.3,
            assigned_colors=[(1, 0, 1)]
        ).get_image()

    return cropfromsizedivisible(out, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, osize)


paths = sorted(pathlib.Path('samples').glob('*.jpg'))
css = ".component-1 {height: 256px !important;}"
demo = gr.Interface(
    fn=evaluate_image,
    inputs=gr.Image(label='Image', type='filepath'),
    outputs=gr.Image(label='Annotated Image', type='numpy'),
    examples=[[path.as_posix()] for path in paths],  # one example value per input
    title="Guess What Moves",
    description="#### Unsupervised image segmentation mode of [Guess What Moves](https://www.robots.ox.ac.uk/~vgg/research/gwm/)",
    css=css)
demo.queue().launch()
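As a side note on how the demo picks which mask to draw: `edgeness` above scores each mask by its overlap with 2-pixel strips along the four image borders, and the highest-scoring mask is taken to be the background. A minimal, self-contained sketch of that idea (the toy masks are invented for illustration; the real function also saturates each side's score at 0.3 before summing):

```python
# Toy illustration of the border-contact heuristic used by edgeness():
# of two soft masks, the one covering more of the image border is called background.
import torch

h = w = 8
masks = torch.zeros(1, 2, h, w)          # (batch, K=2, H, W)
masks[0, 0, 2:6, 2:6] = 1.               # mask 0: centred blob, never touches the border
masks[0, 1] = 1. - masks[0, 0]           # mask 1: everything else, hugs the border

border = torch.zeros(h, w)
border[:2, :] = border[-2:, :] = 1.      # 2-pixel top/bottom strips, as in edgeness()
border[:, :2] = border[:, -2:] = 1.      # 2-pixel left/right strips

scores = (masks * border).flatten(-2).sum(-1) / border.sum()
bg = scores[0].argmax().item()
print(scores, "-> background mask index:", bg)   # mask 1 wins
```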
checkpoints/checkpoint_best.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a573e43ffc78e84dbf8b4f2c9e31195bed3c44d8e7be942daba92c7437b8b17d
size 63672283
config.py
ADDED
@@ -0,0 +1,377 @@
import copy
import itertools
import logging
import os
from pathlib import Path

import numpy as np
import torch.utils.data
from detectron2.config import CfgNode as CN

import utils
from datasets import FlowPairDetectron, FlowEvalDetectron

logger = logging.getLogger('gwm')


def scan_train_flow(folders, res, pairs, basepath):
    pair_list = [p for p in itertools.combinations(pairs, 2)]

    flow_dir = {}
    for pair in pair_list:
        p1, p2 = pair
        flowpairs = []
        for f in folders:
            path1 = basepath / f'Flows_gap{p1}' / res / f
            path2 = basepath / f'Flows_gap{p2}' / res / f

            flows1 = [p.name for p in path1.glob('*.flo')]
            flows2 = [p.name for p in path2.glob('*.flo')]

            flows1 = sorted(flows1)
            flows2 = sorted(flows2)

            # Keep only frames for which both gap directories have a flow file.
            intersect = list(set(flows1).intersection(flows2))
            intersect.sort()

            flowpair = np.array([[path1 / i, path2 / i] for i in intersect])
            flowpairs += [flowpair]
        flow_dir['gap_{}_{}'.format(p1, p2)] = flowpairs

    # flow_dir is a dictionary whose keys indicate the flow gap; each value is a list
    # (one entry per sequence) of Nx2 arrays, N being the number of available pairs.
    return flow_dir


def setup_dataset(cfg=None, multi_val=False):
    dataset_str = cfg.GWM.DATASET
    if '+' in dataset_str:
        # A '+'-joined dataset string trains on the concatenation of all parts
        # but validates only on the first one.
        datasets = dataset_str.split('+')
        logger.info(f'Multiple datasets detected: {datasets}')
        train_datasets = []
        val_datasets = []
        for ds in datasets:
            proxy_cfg = copy.deepcopy(cfg)
            proxy_cfg.merge_from_list(['GWM.DATASET', ds])
            train_ds, val_ds = setup_dataset(proxy_cfg, multi_val=multi_val)
            train_datasets.append(train_ds)
            val_datasets.append(val_ds)
        logger.info(f'Multiple datasets detected: {datasets}')
        logger.info(f'Validation is still: {datasets[0]}')
        return torch.utils.data.ConcatDataset(train_datasets), val_datasets[0]

    resolution = cfg.GWM.RESOLUTION  # h, w
    res = ""
    with_gt = True
    pairs = [1, 2, -1, -2]
    trainval_data_dir = None

    if cfg.GWM.DATASET == 'DAVIS':
        basepath = '/DAVIS2016'
        img_dir = '/DAVIS2016/JPEGImages/480p'
        gt_dir = '/DAVIS2016/Annotations/480p'

        val_flow_dir = '/DAVIS2016/Flows_gap1/1080p'
        val_seq = ['dog', 'cows', 'goat', 'camel', 'libby', 'parkour', 'soapbox', 'blackswan', 'bmx-trees',
                   'kite-surf', 'car-shadow', 'breakdance', 'dance-twirl', 'scooter-black', 'drift-chicane',
                   'motocross-jump', 'horsejump-high', 'drift-straight', 'car-roundabout', 'paragliding-launch']
        val_data_dir = [val_flow_dir, img_dir, gt_dir]
        res = "1080p"

    elif cfg.GWM.DATASET in ['FBMS']:
        basepath = '/FBMS_clean'
        img_dir = '/FBMS_clean/JPEGImages/'
        gt_dir = '/FBMS_clean/Annotations/'

        val_flow_dir = '/FBMS_val/Flows_gap1/'
        val_seq = ['camel01', 'cars1', 'cars10', 'cars4', 'cars5', 'cats01', 'cats03', 'cats06',
                   'dogs01', 'dogs02', 'farm01', 'giraffes01', 'goats01', 'horses02', 'horses04',
                   'horses05', 'lion01', 'marple12', 'marple2', 'marple4', 'marple6', 'marple7', 'marple9',
                   'people03', 'people1', 'people2', 'rabbits02', 'rabbits03', 'rabbits04', 'tennis']
        val_img_dir = '/FBMS_val/JPEGImages/'
        val_gt_dir = '/FBMS_val/Annotations/'
        val_data_dir = [val_flow_dir, val_img_dir, val_gt_dir]
        with_gt = False
        pairs = [3, 6, -3, -6]

    elif cfg.GWM.DATASET in ['STv2']:
        basepath = '/SegTrackv2'
        img_dir = '/SegTrackv2/JPEGImages'
        gt_dir = '/SegTrackv2/Annotations'

        val_flow_dir = '/SegTrackv2/Flows_gap1/'
        val_seq = ['drift', 'birdfall', 'girl', 'cheetah', 'worm', 'parachute', 'monkeydog',
                   'hummingbird', 'soldier', 'bmx', 'frog', 'penguin', 'monkey', 'bird_of_paradise']
        val_data_dir = [val_flow_dir, img_dir, gt_dir]

    else:
        raise ValueError('Unknown Setting/Dataset.')

    # Switching this section to pathlib, which should prevent double-slash errors in paths and dict keys.

    root_path_str = cfg.GWM.DATA_ROOT
    logger.info(f"Found DATA_ROOT in config: {root_path_str}")
    root_path_str = '../data'  # NOTE: hardcoded for the demo; overrides the configured DATA_ROOT

    if root_path_str.startswith('/'):
        root_path = Path(f"/{root_path_str.lstrip('/').rstrip('/')}")
    else:
        root_path = Path(f"{root_path_str.lstrip('/').rstrip('/')}")

    logger.info(f"Loading dataset from: {root_path}")

    basepath = root_path / basepath.lstrip('/').rstrip('/')
    img_dir = root_path / img_dir.lstrip('/').rstrip('/')
    gt_dir = root_path / gt_dir.lstrip('/').rstrip('/')
    val_data_dir = [root_path / path.lstrip('/').rstrip('/') for path in val_data_dir]

    folders = [p.name for p in (basepath / f'Flows_gap{pairs[0]}' / res).iterdir() if p.is_dir()]
    folders = sorted(folders)

    # flow_dir is a dictionary whose keys indicate the flow gap; each value is a list
    # (one entry per sequence) of Nx2 arrays, N being the number of available pairs.
    flow_dir = scan_train_flow(folders, res, pairs, basepath)
    data_dir = [flow_dir, img_dir, gt_dir]

    force1080p = ('DAVIS' not in cfg.GWM.DATASET) and 'RGB_BIG' in cfg.GWM.SAMPLE_KEYS

    enable_photometric_augmentations = cfg.FLAGS.INF_TPS

    train_dataset = FlowPairDetectron(data_dir=data_dir,
                                      resolution=resolution,
                                      to_rgb=cfg.GWM.FLOW2RGB,
                                      size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1,
                                      enable_photo_aug=enable_photometric_augmentations,
                                      flow_clip=cfg.GWM.FLOW_CLIP,
                                      norm=cfg.GWM.FLOW_NORM,
                                      force1080p=force1080p,
                                      flow_res=cfg.GWM.FLOW_RES)
    if multi_val:
        print(f"Using multiple validation datasets from {val_data_dir}")
        val_dataset = [FlowEvalDetectron(data_dir=val_data_dir,
                                         resolution=resolution,
                                         pair_list=pairs,
                                         val_seq=[vs],
                                         to_rgb=cfg.GWM.FLOW2RGB,
                                         with_rgb=False,
                                         size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1,
                                         flow_clip=cfg.GWM.FLOW_CLIP,
                                         norm=cfg.GWM.FLOW_NORM,
                                         force1080p=force1080p) for vs in val_seq]
        for vs, vds in zip(val_seq, val_dataset):
            print(f"Validation dataset for {vs}: {len(vds)}")
            if len(vds) == 0:
                raise ValueError(f"Empty validation dataset for {vs}")

        if cfg.GWM.TTA_AS_TRAIN:
            if trainval_data_dir is None:
                trainval_data_dir = val_data_dir
            else:
                trainval_data_dir = [root_path / path.lstrip('/').rstrip('/') for path in trainval_data_dir]
            trainval_dataset = []
            tvd_basepath = root_path / str(trainval_data_dir[0].relative_to(root_path)).split('/')[0]
            print("TVD BASE DIR", tvd_basepath)
            for vs in val_seq:
                tvd_data_dir = [scan_train_flow([vs], res, pairs, tvd_basepath), *trainval_data_dir[1:]]
                tvd = FlowPairDetectron(data_dir=tvd_data_dir,
                                        resolution=resolution,
                                        to_rgb=cfg.GWM.FLOW2RGB,
                                        size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1,
                                        enable_photo_aug=cfg.GWM.LOSS_MULT.EQV is not None,
                                        flow_clip=cfg.GWM.FLOW_CLIP,
                                        norm=cfg.GWM.FLOW_NORM,
                                        force1080p=force1080p,
                                        flow_res=cfg.GWM.FLOW_RES)
                trainval_dataset.append(tvd)
                print(f'Seq {trainval_data_dir[0]}/{vs} dataset: {len(tvd)}')
        else:
            if trainval_data_dir is None:
                trainval_dataset = val_dataset
            else:
                trainval_data_dir = [root_path / path.lstrip('/').rstrip('/') for path in trainval_data_dir]
                trainval_dataset = []
                for vs in val_seq:
                    tvd = FlowEvalDetectron(data_dir=trainval_data_dir,
                                            resolution=resolution,
                                            pair_list=pairs,
                                            val_seq=[vs],
                                            to_rgb=cfg.GWM.FLOW2RGB,
                                            with_rgb=False,
                                            size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1,
                                            flow_clip=cfg.GWM.FLOW_CLIP,
                                            norm=cfg.GWM.FLOW_NORM,
                                            force1080p=force1080p)
                    trainval_dataset.append(tvd)
                    print(f'Seq {trainval_data_dir[0]}/{vs} dataset: {len(tvd)}')
        return train_dataset, val_dataset, trainval_dataset

    val_dataset = FlowEvalDetectron(data_dir=val_data_dir,
                                    resolution=resolution,
                                    pair_list=pairs,
                                    val_seq=val_seq,
                                    to_rgb=cfg.GWM.FLOW2RGB,
                                    with_rgb=False,
                                    size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1,
                                    flow_clip=cfg.GWM.FLOW_CLIP,
                                    norm=cfg.GWM.FLOW_NORM,
                                    force1080p=force1080p)

    return train_dataset, val_dataset


def loaders(cfg):
    train_dataset, val_dataset = setup_dataset(cfg)
    logger.info(f"Sourcing data from {val_dataset.data_dir[0]}")

    if cfg.FLAGS.DEV_DATA:
        subset = cfg.SOLVER.IMS_PER_BATCH * 3
        train_dataset = torch.utils.data.Subset(train_dataset, list(range(subset)))
        val_dataset = torch.utils.data.Subset(val_dataset, list(range(subset)))

    g = torch.Generator()
    data_generator_seed = int(torch.randint(int(1e6), (1,)).item())
    logger.info(f"Dataloaders generator seed {data_generator_seed}")
    g.manual_seed(data_generator_seed)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               num_workers=cfg.DATALOADER.NUM_WORKERS,
                                               batch_size=cfg.SOLVER.IMS_PER_BATCH,
                                               collate_fn=lambda x: x,
                                               shuffle=True,
                                               pin_memory=True,
                                               drop_last=True,
                                               persistent_workers=cfg.DATALOADER.NUM_WORKERS > 0,
                                               worker_init_fn=utils.random_state.worker_init_function,
                                               generator=g)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             num_workers=cfg.DATALOADER.NUM_WORKERS,
                                             batch_size=1,
                                             shuffle=False,
                                             pin_memory=True,
                                             collate_fn=lambda x: x,
                                             drop_last=False,
                                             persistent_workers=cfg.DATALOADER.NUM_WORKERS > 0,
                                             worker_init_fn=utils.random_state.worker_init_function,
                                             generator=g)
    return train_loader, val_loader


def multi_loaders(cfg):
    train_dataset, val_datasets, train_val_datasets = setup_dataset(cfg, multi_val=True)
    logger.info(f"Sourcing multiple loaders from {len(val_datasets)}")
    logger.info(f"Sourcing data from {val_datasets[0].data_dir[0]}")

    g = torch.Generator()
    data_generator_seed = int(torch.randint(int(1e6), (1,)).item())
    logger.info(f"Dataloaders generator seed {data_generator_seed}")
    g.manual_seed(data_generator_seed)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               num_workers=cfg.DATALOADER.NUM_WORKERS,
                                               batch_size=cfg.SOLVER.IMS_PER_BATCH,
                                               collate_fn=lambda x: x,
                                               shuffle=True,
                                               pin_memory=True,
                                               drop_last=True,
                                               persistent_workers=cfg.DATALOADER.NUM_WORKERS > 0,
                                               worker_init_fn=utils.random_state.worker_init_function,
                                               generator=g)

    val_loaders = [(torch.utils.data.DataLoader(val_dataset,
                                                num_workers=0,
                                                batch_size=1,
                                                shuffle=False,
                                                pin_memory=True,
                                                collate_fn=lambda x: x,
                                                drop_last=False,
                                                persistent_workers=False,
                                                worker_init_fn=utils.random_state.worker_init_function,
                                                generator=g),
                    torch.utils.data.DataLoader(tv_dataset,
                                                num_workers=0,
                                                batch_size=cfg.SOLVER.IMS_PER_BATCH,
                                                shuffle=True,
                                                pin_memory=False,
                                                collate_fn=lambda x: x,
                                                drop_last=False,
                                                persistent_workers=False,
                                                worker_init_fn=utils.random_state.worker_init_function,
                                                generator=g))
                   for val_dataset, tv_dataset in zip(val_datasets, train_val_datasets)]

    return train_loader, val_loaders


def add_gwm_config(cfg):
    cfg.GWM = CN()
    cfg.GWM.MODEL = "MASKFORMER"
    cfg.GWM.RESOLUTION = (128, 224)
    cfg.GWM.FLOW_RES = (480, 854)
    cfg.GWM.SAMPLE_KEYS = ["rgb"]
    cfg.GWM.ADD_POS_EMB = False
    cfg.GWM.CRITERION = "L2"
    cfg.GWM.L1_OPTIMIZE = False
    cfg.GWM.HOMOGRAPHY = 'quad'  # False
    cfg.GWM.HOMOGRAPHY_SUBSAMPLE = 8
    cfg.GWM.HOMOGRAPHY_SKIP = 0.4
    cfg.GWM.DATASET = 'DAVIS'
    cfg.GWM.DATA_ROOT = None
    cfg.GWM.FLOW2RGB = False
    cfg.GWM.SIMPLE_REC = False
    cfg.GWM.DAVIS_SINGLE_VID = None
    cfg.GWM.USE_MULT_FLOW = False
    cfg.GWM.FLOW_COLORSPACE_REC = None

    cfg.GWM.FLOW_CLIP_U_LOW = float('-inf')
    cfg.GWM.FLOW_CLIP_U_HIGH = float('inf')
    cfg.GWM.FLOW_CLIP_V_LOW = float('-inf')
    cfg.GWM.FLOW_CLIP_V_HIGH = float('inf')

    cfg.GWM.FLOW_CLIP = float('inf')
    cfg.GWM.FLOW_NORM = False

    cfg.GWM.LOSS_MULT = CN()
    cfg.GWM.LOSS_MULT.REC = 0.03
    cfg.GWM.LOSS_MULT.HEIR_W = [0.1, 0.3, 0.6]

    cfg.GWM.TTA = 100  # test-time adaptation
    cfg.GWM.TTA_AS_TRAIN = False  # use train-like data logic for test-time adaptation

    cfg.GWM.LOSS = 'OG'

    cfg.FLAGS = CN()
    cfg.FLAGS.MAKE_VIS_VIDEOS = False  # making videos is kinda slow
    cfg.FLAGS.EXTENDED_FLOW_RECON_VIS = False  # does not cost much
    cfg.FLAGS.COMP_NLL_FOR_GT = False  # should we log loss against ground truth?
    cfg.FLAGS.DEV_DATA = False
    cfg.FLAGS.KEEP_ALL = True  # keep all checkpoints
    cfg.FLAGS.ORACLE_CHECK = False  # use oracle check to estimate max performance when grouping multiple components

    cfg.FLAGS.INF_TPS = False

    # cfg.FLAGS.UNFREEZE_AT = [(1, 10000), (0, 20000), (-1, 30000)]
    cfg.FLAGS.UNFREEZE_AT = [(4, 0), (2, 500), (1, 1000), (-1, 10000)]

    cfg.FLAGS.IGNORE_SIZE_DIV = False

    cfg.FLAGS.IGNORE_TMP = True

    cfg.WANDB = CN()
    cfg.WANDB.ENABLE = False
    cfg.WANDB.BASEDIR = '../'

    cfg.DEBUG = False

    cfg.LOG_ID = 'exp'
    cfg.LOG_FREQ = 250
    cfg.OUTPUT_BASEDIR = '../outputs'
    cfg.SLURM = False
    cfg.SKIP_TB = False
    cfg.TOTAL_ITER = 20000
    cfg.CONFIG_FILE = None

    if os.environ.get('SLURM_JOB_ID', None):
        cfg.LOG_ID = os.environ.get('SLURM_JOB_NAME', cfg.LOG_ID)
        logger.info(f"Setting name {cfg.LOG_ID} based on SLURM job name")
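To make the `flow_dir` structure that `scan_train_flow` builds concrete, here is a toy sketch of the enumeration it performs (the file names are made up for illustration; the real code stores Nx2 arrays of `Path` pairs per sequence rather than flat name lists):

```python
# Toy version of the pair enumeration in scan_train_flow(): for every pair of flow
# gaps, keep only the frames for which both gap directories have a .flo file.
import itertools

pairs = [1, 2, -1, -2]                        # flow gaps, as in setup_dataset()
gap_files = {                                  # hypothetical directory listings
    1:  {'00000.flo', '00001.flo', '00002.flo'},
    2:  {'00000.flo', '00001.flo'},
    -1: {'00001.flo', '00002.flo'},
    -2: {'00002.flo'},
}

flow_dir = {}
for p1, p2 in itertools.combinations(pairs, 2):
    intersect = sorted(gap_files[p1] & gap_files[p2])
    flow_dir[f'gap_{p1}_{p2}'] = intersect     # real code: np.array of [path1/i, path2/i]

for k, v in flow_dir.items():
    print(k, v)   # e.g. gap_1_2 ['00000.flo', '00001.flo']
```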
configs/README.md
ADDED
@@ -0,0 +1,16 @@
## Available configs:

Use `main.py --config-file=<config-file>`.

There is no need to specify `GWM.MODEL`; it is already defined inside the config files.

### Available configs:
```
maskformer/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml
maskformer/swin/maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml
maskformer/swin/maskformer_swin_small_bs16_160k.yaml
maskformer/swin/maskformer_swin_tiny_bs16_160k.yaml
maskformer/maskformer_R50_bs16_160k.yaml
```
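For context, a rough sketch of how one of these files could be merged into a detectron2 config outside of `main.py`. This mirrors what `setup` in `mask_former_trainer` presumably does; `add_mask_former_config` is assumed to be exported by `mask_former`, as in the upstream MaskFormer repository, while `add_gwm_config` is defined in `config.py` of this commit:

```python
# Hypothetical standalone config loading; main.py normally does this via setup().
from detectron2.config import get_cfg

from config import add_gwm_config                 # config.py in this commit
from mask_former import add_mask_former_config    # assumed export, as in upstream MaskFormer

cfg = get_cfg()
add_mask_former_config(cfg)   # registers MODEL.MASK_FORMER, MODEL.SWIN, ...
add_gwm_config(cfg)           # registers GWM, FLAGS, WANDB, ...
# allow_unsafe=True is required for the !!python/object/apply:eval entries in these yamls.
cfg.merge_from_file('configs/maskformer/maskformer_R50_bs16_160k.yaml', allow_unsafe=True)
cfg.merge_from_list(['GWM.DATASET', 'DAVIS'])
cfg.freeze()
```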
configs/maskformer/Base-ADE20K-150.yaml
ADDED
@@ -0,0 +1,60 @@
_BASE_: Base-unsup-vidseg.yaml
MODEL:
  BACKBONE:
    FREEZE_AT: 0
    NAME: "build_resnet_backbone"
  WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  RESNETS:
    DEPTH: 50
    STEM_TYPE: "basic"  # not used
    STEM_OUT_CHANNELS: 64
    STRIDE_IN_1X1: False
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
    # NORM: "SyncBN"
    RES5_MULTI_GRID: [1, 1, 1]  # not used
DATASETS:
  TRAIN: ("ade20k_sem_seg_train",)
  TEST: ("ade20k_sem_seg_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.0001
  MAX_ITER: 160000
  WARMUP_FACTOR: 1.0
  WARMUP_ITERS: 0
  WEIGHT_DECAY: 0.0001
  OPTIMIZER: "ADAMW"
  LR_SCHEDULER_NAME: "WarmupPolyLR"
  BACKBONE_MULTIPLIER: 0.1
  CLIP_GRADIENTS:
    ENABLED: True
    CLIP_TYPE: "full_model"
    CLIP_VALUE: 0.01
    NORM_TYPE: 2.0
INPUT:
  MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"]
  MIN_SIZE_TRAIN_SAMPLING: "choice"
  MIN_SIZE_TEST: 512
  MAX_SIZE_TRAIN: 2048
  MAX_SIZE_TEST: 2048
  CROP:
    ENABLED: True
    TYPE: "absolute"
    SIZE: (512, 512)
    SINGLE_CATEGORY_MAX_AREA: 1.0
  COLOR_AUG_SSD: True
  SIZE_DIVISIBILITY: 512  # used in dataset mapper
  FORMAT: "RGB"
  DATASET_MAPPER_NAME: "mask_former_semantic"
TEST:
  EVAL_PERIOD: 5000
  AUG:
    ENABLED: False
    MIN_SIZES: [256, 384, 512, 640, 768, 896]
    MAX_SIZE: 3584
    FLIP: True
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: True
  NUM_WORKERS: 10
VERSION: 2
configs/maskformer/Base-unsup-vidseg.yaml
ADDED
@@ -0,0 +1,3 @@
SEED: 42
GWM:
  MODEL: "MASKFORMER"
configs/maskformer/cityscapes-19/Base-Cityscapes-19.yaml
ADDED
@@ -0,0 +1,59 @@
MODEL:
  BACKBONE:
    FREEZE_AT: 0
    NAME: "build_resnet_backbone"
  WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  RESNETS:
    DEPTH: 50
    STEM_TYPE: "basic"  # not used
    STEM_OUT_CHANNELS: 64
    STRIDE_IN_1X1: False
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
    # NORM: "SyncBN"
    RES5_MULTI_GRID: [1, 1, 1]  # not used
DATASETS:
  TRAIN: ("cityscapes_fine_sem_seg_train",)
  TEST: ("cityscapes_fine_sem_seg_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.0001
  MAX_ITER: 90000
  WARMUP_FACTOR: 1.0
  WARMUP_ITERS: 0
  WEIGHT_DECAY: 0.0001
  OPTIMIZER: "ADAMW"
  LR_SCHEDULER_NAME: "WarmupPolyLR"
  BACKBONE_MULTIPLIER: 0.1
  CLIP_GRADIENTS:
    ENABLED: True
    CLIP_TYPE: "full_model"
    CLIP_VALUE: 0.01
    NORM_TYPE: 2.0
INPUT:
  MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 1024) for x in range(5, 21)]"]
  MIN_SIZE_TRAIN_SAMPLING: "choice"
  MIN_SIZE_TEST: 1024
  MAX_SIZE_TRAIN: 4096
  MAX_SIZE_TEST: 2048
  CROP:
    ENABLED: True
    TYPE: "absolute"
    SIZE: (512, 1024)
    SINGLE_CATEGORY_MAX_AREA: 1.0
  COLOR_AUG_SSD: True
  SIZE_DIVISIBILITY: -1
  FORMAT: "RGB"
  DATASET_MAPPER_NAME: "mask_former_semantic"
TEST:
  EVAL_PERIOD: 5000
  AUG:
    ENABLED: False
    MIN_SIZES: [512, 768, 1024, 1280, 1536, 1792]
    MAX_SIZE: 4096
    FLIP: True
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: True
  NUM_WORKERS: 4
VERSION: 2
configs/maskformer/cityscapes-19/maskformer_R101_bs16_90k.yaml
ADDED
@@ -0,0 +1,36 @@
_BASE_: Base-Cityscapes-19.yaml
MODEL:
  WEIGHTS: "R-101.pkl"
  RESNETS:
    DEPTH: 101
    STEM_TYPE: "basic"  # not used
    STEM_OUT_CHANNELS: 64
    STRIDE_IN_1X1: False
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
    # NORM: "SyncBN"
    RES5_MULTI_GRID: [1, 1, 1]  # not used
  META_ARCHITECTURE: "MaskFormer"
  SEM_SEG_HEAD:
    NAME: "MaskFormerHead"
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
    IGNORE_VALUE: 255
    NUM_CLASSES: 19
    COMMON_STRIDE: 4  # not used, hard-coded
    LOSS_WEIGHT: 1.0
    CONVS_DIM: 256
    MASK_DIM: 256
    NORM: "GN"
  MASK_FORMER:
    TRANSFORMER_IN_FEATURE: "res5"
    DEEP_SUPERVISION: True
    NO_OBJECT_WEIGHT: 0.1
    DICE_WEIGHT: 1.0
    MASK_WEIGHT: 20.0
    HIDDEN_DIM: 256
    NUM_OBJECT_QUERIES: 100
    NHEADS: 8
    DROPOUT: 0.1
    DIM_FEEDFORWARD: 2048
    ENC_LAYERS: 0
    DEC_LAYERS: 6
    PRE_NORM: False
configs/maskformer/cityscapes-19/maskformer_R101c_bs16_90k.yaml
ADDED
@@ -0,0 +1,16 @@
_BASE_: maskformer_R101_bs16_90k.yaml
MODEL:
  BACKBONE:
    FREEZE_AT: 0
    NAME: "build_resnet_deeplab_backbone"
  WEIGHTS: "detectron2://DeepLab/R-103.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  RESNETS:
    DEPTH: 101
    STEM_TYPE: "deeplab"
    STEM_OUT_CHANNELS: 128
    STRIDE_IN_1X1: False
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
    # NORM: "SyncBN"
    RES5_MULTI_GRID: [1, 2, 4]
configs/maskformer/maskformer_R50_bs16_160k.yaml
ADDED
@@ -0,0 +1,27 @@
_BASE_: Base-ADE20K-150.yaml
MODEL:
  META_ARCHITECTURE: "MaskFormer"
  SEM_SEG_HEAD:
    NAME: "MaskFormerHead"
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
    IGNORE_VALUE: 255
    NUM_CLASSES: 2
    COMMON_STRIDE: 4  # not used, hard-coded
    LOSS_WEIGHT: 1.0
    CONVS_DIM: 256
    MASK_DIM: 256
    NORM: "GN"
  MASK_FORMER:
    TRANSFORMER_IN_FEATURE: "res5"
    DEEP_SUPERVISION: False
    NO_OBJECT_WEIGHT: 0.1
    DICE_WEIGHT: 1.0
    MASK_WEIGHT: 20.0
    HIDDEN_DIM: 256
    NUM_OBJECT_QUERIES: 2
    NHEADS: 8
    DROPOUT: 0.1
    DIM_FEEDFORWARD: 2048
    ENC_LAYERS: 0
    DEC_LAYERS: 6
    PRE_NORM: False
configs/maskformer/maskformer_R50_bs16_160k_dino.yaml
ADDED
@@ -0,0 +1,31 @@
_BASE_: ./maskformer_R50_bs16_160k.yaml
MODEL:
  BACKBONE:
    NAME: "D2ViTTransformer"
    FREEZE_AT: -1
  SWIN:
    EMBED_DIM: 768
    DEPTHS: [2, 2, 6, 2]
    NUM_HEADS: [3, 6, 12, 24]
    WINDOW_SIZE: 7
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  WEIGHTS: None
  MASK_FORMER:
    NUM_OBJECT_QUERIES: 4
  SEM_SEG_HEAD:
    PIXEL_DECODER_NAME: BigPixelDecoder
SOLVER:
  BASE_LR: 0.00015
  IMS_PER_BATCH: 8
  WARMUP_FACTOR: 1e-6
  WARMUP_ITERS: 1500
  WEIGHT_DECAY: 0.01
  WEIGHT_DECAY_NORM: 0.0
  WEIGHT_DECAY_EMBED: 0.0
  BACKBONE_MULTIPLIER: 1.0
FLAGS:
  UNFREEZE_AT: []
configs/maskformer/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml
ADDED
@@ -0,0 +1,45 @@
_BASE_: ../maskformer_R50_bs16_160k.yaml
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 128
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [4, 8, 16, 32]
    WINDOW_SIZE: 12
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
    PRETRAIN_IMG_SIZE: 384
  WEIGHTS: "pretrained_weights/swin_base_patch4_window12_384_22k.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
SOLVER:
  BASE_LR: 0.00006
  WARMUP_FACTOR: 1e-6
  WARMUP_ITERS: 1500
  WEIGHT_DECAY: 0.01
  WEIGHT_DECAY_NORM: 0.0
  WEIGHT_DECAY_EMBED: 0.0
  BACKBONE_MULTIPLIER: 1.0
INPUT:
  MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"]
  MIN_SIZE_TRAIN_SAMPLING: "choice"
  MIN_SIZE_TEST: 640
  MAX_SIZE_TRAIN: 2560
  MAX_SIZE_TEST: 2560
  CROP:
    ENABLED: True
    TYPE: "absolute"
    SIZE: (640, 640)
    SINGLE_CATEGORY_MAX_AREA: 1.0
  COLOR_AUG_SSD: True
  SIZE_DIVISIBILITY: 640  # used in dataset mapper
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 5000
  AUG:
    ENABLED: False
    MIN_SIZES: [320, 480, 640, 800, 960, 1120]
    MAX_SIZE: 4480
    FLIP: True
configs/maskformer/swin/maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml
ADDED
@@ -0,0 +1,45 @@
_BASE_: ../maskformer_R50_bs16_160k.yaml
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 192
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [6, 12, 24, 48]
    WINDOW_SIZE: 12
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
    PRETRAIN_IMG_SIZE: 384
  WEIGHTS: "pretrained_weights/swin_large_patch4_window12_384_22k.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
SOLVER:
  BASE_LR: 0.00006
  WARMUP_FACTOR: 1e-6
  WARMUP_ITERS: 1500
  WEIGHT_DECAY: 0.01
  WEIGHT_DECAY_NORM: 0.0
  WEIGHT_DECAY_EMBED: 0.0
  BACKBONE_MULTIPLIER: 1.0
INPUT:
  MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"]
  MIN_SIZE_TRAIN_SAMPLING: "choice"
  MIN_SIZE_TEST: 640
  MAX_SIZE_TRAIN: 2560
  MAX_SIZE_TEST: 2560
  CROP:
    ENABLED: True
    TYPE: "absolute"
    SIZE: (640, 640)
    SINGLE_CATEGORY_MAX_AREA: 1.0
  COLOR_AUG_SSD: True
  SIZE_DIVISIBILITY: 640  # used in dataset mapper
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 5000
  AUG:
    ENABLED: False
    MIN_SIZES: [320, 480, 640, 800, 960, 1120]
    MAX_SIZE: 4480
    FLIP: True
configs/maskformer/swin/maskformer_swin_small_bs16_160k.yaml
ADDED
@@ -0,0 +1,23 @@
_BASE_: ../maskformer_R50_bs16_160k.yaml
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 96
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [3, 6, 12, 24]
    WINDOW_SIZE: 7
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
  WEIGHTS: "pretrained_weights/swin_small_patch4_window7_224.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
SOLVER:
  BASE_LR: 0.00006
  WARMUP_FACTOR: 1e-6
  WARMUP_ITERS: 1500
  WEIGHT_DECAY: 0.01
  WEIGHT_DECAY_NORM: 0.0
  WEIGHT_DECAY_EMBED: 0.0
  BACKBONE_MULTIPLIER: 1.0
configs/maskformer/swin/maskformer_swin_tiny_bs16_160k.yaml
ADDED
@@ -0,0 +1,23 @@
_BASE_: ../maskformer_R50_bs16_160k.yaml
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 96
    DEPTHS: [2, 2, 6, 2]
    NUM_HEADS: [3, 6, 12, 24]
    WINDOW_SIZE: 7
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
  WEIGHTS: "pretrained_weights/swin_tiny_patch4_window7_224.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
SOLVER:
  BASE_LR: 0.00006
  WARMUP_FACTOR: 1e-6
  WARMUP_ITERS: 1500
  WEIGHT_DECAY: 0.01
  WEIGHT_DECAY_NORM: 0.0
  WEIGHT_DECAY_EMBED: 0.0
  BACKBONE_MULTIPLIER: 1.0
configs/maskformer/swin/maskformer_swin_tiny_bs16_160k_moby.yaml
ADDED
@@ -0,0 +1,23 @@
_BASE_: ../maskformer_R50_bs16_160k.yaml
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 96
    DEPTHS: [2, 2, 6, 2]
    NUM_HEADS: [3, 6, 12, 24]
    WINDOW_SIZE: 7
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
  WEIGHTS: "pretrained_weights/moby_swin_t_300ep_pretrained.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
SOLVER:
  BASE_LR: 0.00006
  WARMUP_FACTOR: 1e-6
  WARMUP_ITERS: 1500
  WEIGHT_DECAY: 0.01
  WEIGHT_DECAY_NORM: 0.0
  WEIGHT_DECAY_EMBED: 0.0
  BACKBONE_MULTIPLIER: 1.0
datasets/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .flow_eval_detectron import FlowEvalDetectron
from .flow_pair_detectron import FlowPairDetectron
datasets/flow_eval_detectron.py
ADDED
@@ -0,0 +1,209 @@
import math
import os
from pathlib import Path

import detectron2.data.transforms as DT
import einops
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from detectron2.data import detection_utils as d2_utils
from detectron2.structures import Instances, BitMasks
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

from utils.data import read_flow


class FlowEvalDetectron(Dataset):
    def __init__(self, data_dir, resolution, pair_list, val_seq, to_rgb=False, with_rgb=False, size_divisibility=None,
                 small_val=0, flow_clip=1., norm=True, read_big=True, eval_size=True, force1080p=False):
        self.val_seq = val_seq
        self.to_rgb = to_rgb
        self.with_rgb = with_rgb
        self.data_dir = data_dir
        self.pair_list = pair_list
        self.resolution = resolution

        self.eval_size = eval_size

        self.samples = []
        self.samples_fid = {}
        for v in self.val_seq:
            seq_dir = Path(self.data_dir[0]) / v
            frames_paths = sorted(seq_dir.glob('*.flo'))
            self.samples_fid[str(seq_dir)] = {fp: i for i, fp in enumerate(frames_paths)}
            self.samples.extend(frames_paths)
        self.samples = [os.path.join(x.parent.name, x.name) for x in self.samples]
        if small_val > 0:
            _, self.samples = train_test_split(self.samples, test_size=small_val, random_state=42)
        self.gaps = ['gap{}'.format(i) for i in pair_list]
        self.neg_gaps = ['gap{}'.format(-i) for i in pair_list]
        self.size_divisibility = size_divisibility
        self.ignore_label = -1
        self.transforms = DT.AugmentationList([
            DT.Resize(self.resolution, interp=Image.BICUBIC),
        ])
        self.flow_clip = flow_clip
        self.norm_flow = norm
        self.read_big = read_big
        self.force1080p_transforms = None
        if force1080p:
            self.force1080p_transforms = DT.AugmentationList([
                DT.Resize((1088, 1920), interp=Image.BICUBIC),
            ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        dataset_dicts = []

        dataset_dict = {}
        flow_dir = Path(self.data_dir[0]) / self.samples[idx]
        fid = self.samples_fid[str(flow_dir.parent)][flow_dir]
        flo = einops.rearrange(read_flow(str(flow_dir), self.resolution, self.to_rgb), 'c h w -> h w c')
        dataset_dict["gap"] = 'gap1'

        suffix = '.png' if 'CLEVR' in self.samples[idx] else '.jpg'
        rgb_dir = (self.data_dir[1] / self.samples[idx]).with_suffix(suffix)
        gt_dir = (self.data_dir[2] / self.samples[idx]).with_suffix('.png')

        rgb = d2_utils.read_image(str(rgb_dir)).astype(np.float32)
        original_rgb = torch.as_tensor(np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)).clip(0., 255.))).float()
        if self.read_big:
            rgb_big = d2_utils.read_image(str(rgb_dir).replace('480p', '1080p')).astype(np.float32)
            rgb_big = (torch.as_tensor(np.ascontiguousarray(rgb_big))[:, :, :3]).permute(2, 0, 1).clamp(0., 255.)
            if self.force1080p_transforms is not None:
                rgb_big = F.interpolate(rgb_big[None], size=(1080, 1920), mode='bicubic').clamp(0., 255.)[0]

        input = DT.AugInput(rgb)

        # Apply the augmentation:
        preprocessing_transforms = self.transforms(input)  # type: DT.Transform
        rgb = input.image
        rgb = np.transpose(rgb, (2, 0, 1))
        rgb = rgb.clip(0., 255.)
        d2_utils.check_image_size(dataset_dict, flo)

        if gt_dir.exists():
            sem_seg_gt_ori = d2_utils.read_image(gt_dir)
            sem_seg_gt = preprocessing_transforms.apply_segmentation(sem_seg_gt_ori)
            if sem_seg_gt.ndim == 3:
                sem_seg_gt = sem_seg_gt[:, :, 0]
                sem_seg_gt_ori = sem_seg_gt_ori[:, :, 0]
            if sem_seg_gt.max() == 255:
                sem_seg_gt = (sem_seg_gt > 128).astype(int)
                sem_seg_gt_ori = (sem_seg_gt_ori > 128).astype(int)
        else:
            sem_seg_gt = np.zeros((self.resolution[0], self.resolution[1]))
            sem_seg_gt_ori = np.zeros((original_rgb.shape[-2], original_rgb.shape[-1]))

        gwm_dir = (Path(str(self.data_dir[2]).replace('Annotations', 'gwm')) / self.samples[idx]).with_suffix('.png')
        if gwm_dir.exists():
            gwm_seg_gt = preprocessing_transforms.apply_segmentation(d2_utils.read_image(str(gwm_dir)))
            gwm_seg_gt = np.array(gwm_seg_gt)
            if gwm_seg_gt.ndim == 3:
                gwm_seg_gt = gwm_seg_gt[:, :, 0]
            if gwm_seg_gt.max() == 255:
                gwm_seg_gt[gwm_seg_gt == 255] = 1
        else:
            gwm_seg_gt = None

        if sem_seg_gt is None:
            raise ValueError(
                "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
                    dataset_dict["file_name"]
                )
            )

        # Pad image and segmentation label here!
        if self.to_rgb:
            flo = torch.as_tensor(np.ascontiguousarray(flo.transpose(2, 0, 1))) / 2 + .5
            flo = flo * 255
        else:
            flo = torch.as_tensor(np.ascontiguousarray(flo.transpose(2, 0, 1)))
            if self.norm_flow:
                flo = flo / (flo ** 2).sum(0).max().sqrt()
            flo = flo.clip(-self.flow_clip, self.flow_clip)

        rgb = torch.as_tensor(np.ascontiguousarray(rgb)).float()
        if sem_seg_gt is not None:
            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
            sem_seg_gt_ori = torch.as_tensor(sem_seg_gt_ori.astype("long"))
        if gwm_seg_gt is not None:
            gwm_seg_gt = torch.as_tensor(gwm_seg_gt.astype("long"))

        if self.size_divisibility > 0:
            image_size = (flo.shape[-2], flo.shape[-1])
            # Round H and W up to the next multiple of size_divisibility and pad
            # right/bottom accordingly (true division inside ceil, so non-multiples
            # are padded up rather than left short).
            padding_size = [
                0,
                int(self.size_divisibility * math.ceil(image_size[1] / self.size_divisibility)) - image_size[1],
                0,
                int(self.size_divisibility * math.ceil(image_size[0] / self.size_divisibility)) - image_size[0],
            ]
            flo = F.pad(flo, padding_size, value=0).contiguous()
            rgb = F.pad(rgb, padding_size, value=128).contiguous()
            if sem_seg_gt is not None:
                sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
            if gwm_seg_gt is not None:
                gwm_seg_gt = F.pad(gwm_seg_gt, padding_size, value=self.ignore_label).contiguous()

        image_shape = (flo.shape[-2], flo.shape[-1])  # h, w
        if self.eval_size:
            image_shape = (sem_seg_gt_ori.shape[-2], sem_seg_gt_ori.shape[-1])

        # PyTorch's dataloader is efficient on torch.Tensor due to shared memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["flow"] = flo
        dataset_dict["rgb"] = rgb

        dataset_dict["original_rgb"] = F.interpolate(original_rgb[None], mode='bicubic',
                                                     size=sem_seg_gt_ori.shape[-2:],
                                                     align_corners=False).clip(0., 255.)[0]
        if self.read_big:
            dataset_dict["RGB_BIG"] = rgb_big

        dataset_dict["category"] = str(gt_dir).split('/')[-2:]
        dataset_dict['frame_id'] = fid

        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = sem_seg_gt.long()
            dataset_dict["sem_seg_ori"] = sem_seg_gt_ori.long()

        if gwm_seg_gt is not None:
            dataset_dict["gwm_seg"] = gwm_seg_gt.long()

        if "annotations" in dataset_dict:
            raise ValueError("Semantic segmentation dataset should not have 'annotations'.")

        # Prepare per-category binary masks
        if sem_seg_gt is not None:
            sem_seg_gt = sem_seg_gt.numpy()
            instances = Instances(image_shape)
            classes = np.unique(sem_seg_gt)
            # remove ignored region
            classes = classes[classes != self.ignore_label]
            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)

            masks = []
            for class_id in classes:
                masks.append(sem_seg_gt == class_id)

            if len(masks) == 0:
                # Some images have no annotation (all ignored)
                instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
            else:
                masks = BitMasks(
                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
                )
                instances.gt_masks = masks.tensor

            dataset_dict["instances"] = instances
        dataset_dicts.append(dataset_dict)

        return dataset_dicts
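The size-divisibility padding in `__getitem__` above reduces to rounding H and W up to the nearest multiple and padding right/bottom. A minimal sketch with toy sizes (the divisor and tensor shape are chosen for illustration):

```python
# Minimal sketch of the size-divisibility padding used in __getitem__.
import math
import torch
import torch.nn.functional as F

size_div = 32
flo = torch.randn(2, 100, 222)                   # (C, H, W) flow with awkward dimensions

h, w = flo.shape[-2:]
pad_w = size_div * math.ceil(w / size_div) - w   # pad width up to the next multiple
pad_h = size_div * math.ceil(h / size_div) - h   # pad height up to the next multiple
# F.pad takes (left, right, top, bottom) for the last two dims; pad right/bottom only.
flo_padded = F.pad(flo, [0, pad_w, 0, pad_h], value=0)

print(flo_padded.shape)                          # torch.Size([2, 128, 224])
```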
datasets/flow_pair_detectron.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
from pathlib import Path
import random

import detectron2.data.transforms as DT
import einops
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from detectron2.data import detection_utils as d2_utils
from detectron2.structures import Instances, BitMasks
from torch.utils.data import Dataset

from utils.data import read_flow, read_flo


def load_flow_tensor(path, resize=None, normalize=True, align_corners=True):
    """
    Load flow, scale the pixel values according to the resized scale.
    If normalize is true, return rescaled in normalized pixel coordinates
    where pixel coordinates are in range [-1, 1].
    NOTE: RAFT USES ALIGN_CORNERS=TRUE SO WE NEED TO ACCOUNT FOR THIS
    Returns (2, H, W) float32
    """
    flow = read_flo(path).astype(np.float32)
    H, W, _ = flow.shape
    h, w = (H, W) if resize is None else resize
    u, v = flow[..., 0], flow[..., 1]
    if normalize:
        if align_corners:
            u = 2.0 * u / (W - 1)
            v = 2.0 * v / (H - 1)
        else:
            u = 2.0 * u / W
            v = 2.0 * v / H
    else:
        h, w = resize
        u = w * u / W
        v = h * v / H

    if h != H or w != W:
        u = Image.fromarray(u).resize((w, h), Image.ANTIALIAS)
        v = Image.fromarray(v).resize((w, h), Image.ANTIALIAS)
        u, v = np.array(u), np.array(v)
    return torch.from_numpy(np.stack([u, v], axis=0))


class FlowPairDetectron(Dataset):
    def __init__(self, data_dir, resolution, to_rgb=False, size_divisibility=None, enable_photo_aug=False,
                 flow_clip=1., norm=True, read_big=True, force1080p=False, flow_res=None):
        self.eval = eval
        self.to_rgb = to_rgb
        self.data_dir = data_dir
        self.flow_dir = {k: [e for e in v if e.shape[0] > 0] for k, v in data_dir[0].items()}
        self.flow_dir = {k: v for k, v in self.flow_dir.items() if len(v) > 0}
        self.resolution = resolution
        self.size_divisibility = size_divisibility
        self.ignore_label = -1
        self.transforms = DT.AugmentationList([
            DT.Resize(self.resolution, interp=Image.BICUBIC),
        ])
        self.photometric_aug = T.Compose([
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)]),
                          p=0.8),
            T.RandomGrayscale(p=0.2),
        ]) if enable_photo_aug else None
        self.flow_clip = flow_clip
        self.norm_flow = norm
        self.read_big = read_big
        self.force1080p_transforms = None
        if force1080p:
            self.force1080p_transforms = DT.AugmentationList([
                DT.Resize((1088, 1920), interp=Image.BICUBIC),
            ])
        self.big_flow_resolution = flow_res

    def __len__(self):
        return sum([cat.shape[0] for cat in next(iter(self.flow_dir.values()))]) if len(
            self.flow_dir.values()) > 0 else 0

    def __getitem__(self, idx):

        dataset_dicts = []

        random_gap = random.choice(list(self.flow_dir.keys()))
        flowgaps = self.flow_dir[random_gap]
        vid = random.choice(flowgaps)
        flos = random.choice(vid)
        dataset_dict = {}

        fname = Path(flos[0]).stem
        dname = Path(flos[0]).parent.name
        suffix = '.png' if 'CLEVR' in fname else '.jpg'
        rgb_dir = (self.data_dir[1] / dname / fname).with_suffix(suffix)
        gt_dir = (self.data_dir[2] / dname / fname).with_suffix('.png')

        flo0 = einops.rearrange(read_flow(str(flos[0]), self.resolution, self.to_rgb), 'c h w -> h w c')
        flo1 = einops.rearrange(read_flow(str(flos[1]), self.resolution, self.to_rgb), 'c h w -> h w c')
        if self.big_flow_resolution is not None:
            flo0_big = einops.rearrange(read_flow(str(flos[0]), self.big_flow_resolution, self.to_rgb), 'c h w -> h w c')
            flo1_big = einops.rearrange(read_flow(str(flos[1]), self.big_flow_resolution, self.to_rgb), 'c h w -> h w c')
        rgb = d2_utils.read_image(rgb_dir).astype(np.float32)
        original_rgb = torch.as_tensor(np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)).clip(0., 255.))).float()
        if self.read_big:
            rgb_big = d2_utils.read_image(str(rgb_dir).replace('480p', '1080p')).astype(np.float32)
            rgb_big = (torch.as_tensor(np.ascontiguousarray(rgb_big))[:, :, :3]).permute(2, 0, 1).clamp(0., 255.)
            if self.force1080p_transforms is not None:
                rgb_big = F.interpolate(rgb_big[None], size=(1080, 1920), mode='bicubic').clamp(0., 255.)[0]

        # print('not here', rgb.min(), rgb.max())
        input = DT.AugInput(rgb)

        # Apply the augmentation:
        preprocessing_transforms = self.transforms(input)  # type: DT.Transform
        rgb = input.image
        if self.photometric_aug:
            rgb_aug = Image.fromarray(rgb.astype(np.uint8))
            rgb_aug = self.photometric_aug(rgb_aug)
            rgb_aug = d2_utils.convert_PIL_to_numpy(rgb_aug, 'RGB')
            rgb_aug = np.transpose(rgb_aug, (2, 0, 1)).astype(np.float32)
        rgb = np.transpose(rgb, (2, 0, 1))
        rgb = rgb.clip(0., 255.)
        # print('here', rgb.min(), rgb.max())
        d2_utils.check_image_size(dataset_dict, flo0)
        if gt_dir.exists():
            sem_seg_gt = d2_utils.read_image(str(gt_dir))
            sem_seg_gt = preprocessing_transforms.apply_segmentation(sem_seg_gt)
            # sem_seg_gt = cv2.resize(sem_seg_gt, (self.resolution[1], self.resolution[0]), interpolation=cv2.INTER_NEAREST)
            if sem_seg_gt.ndim == 3:
                sem_seg_gt = sem_seg_gt[:, :, 0]
            if sem_seg_gt.max() == 255:
                sem_seg_gt = (sem_seg_gt > 128).astype(int)
        else:
            sem_seg_gt = np.zeros((self.resolution[0], self.resolution[1]))

        gwm_dir = (Path(str(self.data_dir[2]).replace('Annotations', 'gwm')) / dname / fname).with_suffix('.png')
        if gwm_dir.exists():
            gwm_seg_gt = d2_utils.read_image(str(gwm_dir))
            gwm_seg_gt = preprocessing_transforms.apply_segmentation(gwm_seg_gt)
            gwm_seg_gt = np.array(gwm_seg_gt)
            # gwm_seg_gt = cv2.resize(gwm_seg_gt, (self.resolution[1], self.resolution[0]), interpolation=cv2.INTER_NEAREST)
            if gwm_seg_gt.ndim == 3:
                gwm_seg_gt = gwm_seg_gt[:, :, 0]
            if gwm_seg_gt.max() == 255:
                gwm_seg_gt[gwm_seg_gt == 255] = 1
        else:
            gwm_seg_gt = None

        if sem_seg_gt is None:
            raise ValueError(
                "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
                    dataset_dict["file_name"]
                )
            )

        # Pad image and segmentation label here!
        if self.to_rgb:
            flo0 = torch.as_tensor(np.ascontiguousarray(flo0.transpose(2, 0, 1))) / 2 + .5
            flo0 = flo0 * 255
            flo1 = torch.as_tensor(np.ascontiguousarray(flo1.transpose(2, 0, 1))) / 2 + .5
            flo1 = flo1 * 255
            if self.big_flow_resolution is not None:
                flo0_big = torch.as_tensor(np.ascontiguousarray(flo0_big.transpose(2, 0, 1))) / 2 + .5
                flo0_big = flo0_big * 255
                flo1_big = torch.as_tensor(np.ascontiguousarray(flo1_big.transpose(2, 0, 1))) / 2 + .5
                flo1_big = flo1_big * 255
        else:
            flo0 = torch.as_tensor(np.ascontiguousarray(flo0.transpose(2, 0, 1)))
            flo1 = torch.as_tensor(np.ascontiguousarray(flo1.transpose(2, 0, 1)))

            if self.norm_flow:
                flo0 = flo0 / (flo0 ** 2).sum(0).max().sqrt()
                flo1 = flo1 / (flo1 ** 2).sum(0).max().sqrt()

            flo0 = flo0.clip(-self.flow_clip, self.flow_clip)
            flo1 = flo1.clip(-self.flow_clip, self.flow_clip)

            if self.big_flow_resolution is not None:
                flo0_big = torch.as_tensor(np.ascontiguousarray(flo0_big.transpose(2, 0, 1)))
                flo1_big = torch.as_tensor(np.ascontiguousarray(flo1_big.transpose(2, 0, 1)))
                if self.norm_flow:
                    flo0_big = flo0_big / (flo0_big ** 2).sum(0).max().sqrt()
                    flo1_big = flo1_big / (flo1_big ** 2).sum(0).max().sqrt()
                flo0_big = flo0_big.clip(-self.flow_clip, self.flow_clip)
                flo1_big = flo1_big.clip(-self.flow_clip, self.flow_clip)

        rgb = torch.as_tensor(np.ascontiguousarray(rgb))
        if self.photometric_aug:
            rgb_aug = torch.as_tensor(np.ascontiguousarray(rgb_aug))

        if sem_seg_gt is not None:
            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
        if gwm_seg_gt is not None:
            gwm_seg_gt = torch.as_tensor(gwm_seg_gt.astype("long"))

        if self.size_divisibility > 0:
            image_size = (flo0.shape[-2], flo0.shape[-1])
            padding_size = [
                0,
                int(self.size_divisibility * math.ceil(image_size[1] // self.size_divisibility)) - image_size[1],
                0,
                int(self.size_divisibility * math.ceil(image_size[0] // self.size_divisibility)) - image_size[0],
            ]
            flo0 = F.pad(flo0, padding_size, value=0).contiguous()
            flo1 = F.pad(flo1, padding_size, value=0).contiguous()
            rgb = F.pad(rgb, padding_size, value=128).contiguous()
            if self.photometric_aug:
                rgb_aug = F.pad(rgb_aug, padding_size, value=128).contiguous()
            if sem_seg_gt is not None:
                sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
            if gwm_seg_gt is not None:
                gwm_seg_gt = F.pad(gwm_seg_gt, padding_size, value=self.ignore_label).contiguous()

        image_shape = (rgb.shape[-2], rgb.shape[-1])  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["flow"] = flo0
        dataset_dict["flow_2"] = flo1

        # dataset_dict["flow_fwd"] = flo_norm_fwd
        # dataset_dict["flow_bwd"] = flo_norm_bwd
        # dataset_dict["flow_rgb"] = rgb_flo0
        # dataset_dict["flow_gap"] = gap

        dataset_dict["rgb"] = rgb
        dataset_dict["original_rgb"] = original_rgb
        if self.read_big:
            dataset_dict["RGB_BIG"] = rgb_big
        if self.photometric_aug:
            dataset_dict["rgb_aug"] = rgb_aug

        if self.big_flow_resolution is not None:
            dataset_dict["flow_big"] = flo0_big
            dataset_dict["flow_big_2"] = flo1_big

        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = sem_seg_gt.long()

        if gwm_seg_gt is not None:
            dataset_dict["gwm_seg"] = gwm_seg_gt.long()

        if "annotations" in dataset_dict:
            raise ValueError("Semantic segmentation dataset should not have 'annotations'.")

        # Prepare per-category binary masks
        if sem_seg_gt is not None:
            sem_seg_gt = sem_seg_gt.numpy()
            instances = Instances(image_shape)
            classes = np.unique(sem_seg_gt)
            # remove ignored region
            classes = classes[classes != self.ignore_label]
            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)

            masks = []
            for class_id in classes:
                masks.append(sem_seg_gt == class_id)

            if len(masks) == 0:
                # Some image does not have annotation (all ignored)
                instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
            else:
                masks = BitMasks(
                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
                )
                instances.gt_masks = masks.tensor

            dataset_dict["instances"] = instances
        dataset_dicts.append(dataset_dict)

        return dataset_dicts
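A rough sketch of consuming this dataset; building `data_dir` is dataset-specific, and its (flow index dict, RGB root, annotation root) layout is an assumption inferred from the constructor and `__getitem__` above:

# Hypothetical consumption sketch; `data_dir` construction is left out.
from torch.utils.data import DataLoader

dataset = FlowPairDetectron(data_dir, resolution=(128, 224), size_divisibility=32)
loader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: batch)
sample = next(iter(loader))
sample = [e for s in sample for e in s]    # each item is a list of dicts; flatten as main.py does
print(sample[0]['flow'].shape, sample[0]['rgb'].shape)
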
determinism.py
ADDED
@@ -0,0 +1,24 @@
import os

lvl = int(os.environ.get('TRY_DETERMISM_LVL', '0'))
if lvl > 0:
    print(f'Attempting to enable deterministic cuDNN and cuBLAS operations to lvl {lvl}')
    if lvl >= 2:
        # turn on deterministic operations
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"  # Need to set before torch gets loaded
        import torch
        # Since using unstable torch version, it looks like 1.12.0.devXXXXXXX
        if torch.version.__version__ >= '1.12.0':
            torch.use_deterministic_algorithms(True, warn_only=(lvl < 3))
        elif lvl >= 3:
            torch.use_deterministic_algorithms(True)  # This will throw errors if implementations are missing
        else:
            print(f"Torch version is only {torch.version.__version__}, which will cause errors on lvl {lvl}")
    if lvl >= 1:
        import torch
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = False


def i_do_nothing_but_dont_remove_me_otherwise_things_break():
    """This exists to prevent formatters from treating this file as dead code"""
    pass
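The module works purely through import-time side effects, which is why main.py in this commit imports it first. A minimal driver sketch (the level value is an example, not part of the commit):

# Hypothetical driver: the level must be in the environment before torch is first imported.
import os
os.environ['TRY_DETERMISM_LVL'] = '2'    # 1: cudnn.benchmark off; 2+: deterministic algorithms

import determinism  # noqa: F401  (importing applies the settings)
determinism.i_do_nothing_but_dont_remove_me_otherwise_things_break()
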
dist.py
ADDED
@@ -0,0 +1,34 @@
import functools

import torch
import torch.distributions

import utils

LOGGER = utils.log.getLogger(__name__)

__defined_kl = False

EPS = 1e-5


def clamp_probs(probs):
    probs = probs.clamp(EPS, 1. - EPS)  # Will no longer sum to 1
    return probs / probs.sum(-1, keepdim=True)  # to simplex


def grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
    hr = torch.arange(h + 2 * pad, device=device) - pad
    wr = torch.arange(w + 2 * pad, device=device) - pad
    if norm:
        hr = hr / (h + 2 * pad - 1)
        wr = wr / (w + 2 * pad - 1)
    ig, jg = torch.meshgrid(hr, wr)
    g = torch.stack([jg, ig]).to(dtype)[None]
    return g


@functools.lru_cache(2)
def cached_grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
    return grid(h, w, pad, device, dtype, norm)
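A quick sanity check of the convention used by grid(): it returns a (1, 2, h+2*pad, w+2*pad) tensor with the x (column) coordinate in channel 0 and the y (row) coordinate in channel 1. A hypothetical snippet, not part of the commit:

# Check of grid()'s coordinate convention.
from dist import grid   # assuming the repo root is on sys.path

g = grid(3, 4)
assert g.shape == (1, 2, 3, 4)
assert g[0, 0, 0, 2] == 2.0   # channel 0 (x) grows along the width axis
assert g[0, 1, 2, 0] == 2.0   # channel 1 (y) grows along the height axis
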
eval_utils.py
ADDED
@@ -0,0 +1,282 @@
import functools
import random
from collections import defaultdict

import einops
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
from sklearn.cluster import SpectralClustering
from tqdm import tqdm

import flow_reconstruction
from utils import visualisation, log, grid
from utils.vit_extractor import ViTExtractor

label_colors = visualisation.create_label_colormap()
logger = log.getLogger('gwm')


def __default_font(fontsize):
    try:
        FNT = ImageFont.truetype("dejavu/DejaVuSansMono.ttf", fontsize)
    except OSError:
        FNT = ImageFont.truetype("dejavu/DejaVuSans.ttf", fontsize)
    return FNT


@functools.lru_cache(None)  # cache the result
def autosized_default_font(size_limit: float) -> ImageFont.ImageFont:
    fontsize = 1  # starting font size
    font = __default_font(fontsize)
    while font.getsize('test123')[1] < size_limit:
        fontsize += 1
        font = __default_font(fontsize)
    fontsize -= 1
    font = __default_font(fontsize)
    return font


def iou(masks, gt, thres=0.5):
    masks = (masks > thres).float()
    intersect = torch.tensordot(masks, gt, dims=([-2, -1], [0, 1]))
    union = masks.sum(dim=[-2, -1]) + gt.sum(dim=[-2, -1]) - intersect
    return intersect / union.clip(min=1e-12)


def get_unsup_image_viz(model, cfg, sample, criterion):
    if model.training:
        model.eval()
        preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
        model.train()
    else:
        preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
    return get_image_vis(model, cfg, sample, preds, criterion)


def get_vis_header(header_size, image_size, header_texts, header_height=20):
    W, H = (image_size, header_height)
    header_labels = []
    font = autosized_default_font(0.8 * H)

    for text in header_texts:
        im = Image.new("RGB", (W, H), "white")
        draw = ImageDraw.Draw(im)
        w, h = draw.textsize(text, font=font)
        draw.text(((W - w) / 2, (H - h) / 2), text, fill="black", font=font)
        header_labels.append(torch.from_numpy(np.array(im)))
    header_labels = torch.cat(header_labels, dim=1)
    ret = (torch.ones((header_height, header_size, 3)) * 255)
    ret[:, :header_labels.size(1)] = header_labels

    return ret.permute(2, 0, 1).clip(0, 255).to(torch.uint8)


def get_image_vis(model, cfg, sample, preds, criterion):
    masks_pred = torch.stack([x['sem_seg'] for x in preds], 0)

    with torch.no_grad():
        flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20)

    masks_softmaxed = torch.softmax(masks_pred, dim=1)
    masks_pred = masks_softmaxed
    rec_flows = criterion.flow_reconstruction(sample, criterion.process_flow(sample, flow), masks_softmaxed)
    rec_headers = ['rec_flow']
    if len(rec_flows) > 1:
        rec_headers.append('rec_bwd_flow')

    rgb = torch.stack([x['rgb'] for x in sample])
    flow = criterion.viz_flow(criterion.process_flow(sample, flow).cpu()) * 255
    rec_flows = [
        (criterion.viz_flow(rec_flow_.detach().cpu()) * 255).clip(0, 255).to(torch.uint8) for rec_flow_ in rec_flows
    ]

    gt_labels = torch.stack([x['sem_seg'] for x in sample])
    gt = F.one_hot(gt_labels, gt_labels.max().item() + 1).permute(0, 3, 1, 2)
    target_K = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
    masks = F.one_hot(masks_pred.argmax(1).cpu(), target_K).permute(0, 3, 1, 2)
    masks_each = torch.stack([masks_softmaxed, masks_softmaxed, masks_softmaxed], 2) * 255
    masks_each = einops.rearrange(F.pad(masks_each.cpu(), pad=[0, 1], value=255), 'b n c h w -> b c h (n w)')

    gt_seg = torch.einsum('b k h w, k c -> b c h w', gt, label_colors[:gt_labels.max().item() + 1])
    pred_seg = torch.einsum('b k h w, k c -> b c h w', masks, label_colors[:target_K])
    if all('gwm_seg' in d for d in sample):
        gwm_labels = torch.stack([x['gwm_seg'] for x in sample])
        mg = F.one_hot(gwm_labels, gwm_labels.max().item() + 1).permute(0, 3, 1, 2)
        gwm_seg = torch.einsum('b k h w, k c -> b c h w', mg, label_colors[:gwm_labels.max().item() + 1])
        image_viz = torch.cat(
            [rgb, flow, F.pad(gt_seg.cpu(), pad=[0, 1], value=255), F.pad(gwm_seg, pad=[0, 1], value=255),
             pred_seg.cpu(), *rec_flows], -1)
        header_text = ['rgb', 'gt_flow', 'gt_seg', 'GWM', 'pred_seg', *rec_headers]
    else:
        image_viz = torch.cat([rgb, flow, gt_seg.cpu(), pred_seg.cpu(), *rec_flows], -1)
        header_text = ['rgb', 'gt_flow', 'gt_seg', 'pred_seg', *rec_headers]

    image_viz = torch.cat([image_viz, masks_each], -1)
    header_text.extend(['slot'] * masks_softmaxed.shape[1])
    if 'flow_edges' in sample[0]:
        flow_edges = torch.stack([x['flow_edges'].to(image_viz.device) for x in sample])
        if len(flow_edges.shape) >= 4:
            flow_edges = flow_edges.sum(1, keepdim=len(flow_edges.shape) == 4)
        flow_edges = flow_edges.expand(-1, 3, -1, -1)
        flow_edges = flow_edges * 255
        image_viz = torch.cat([image_viz, flow_edges], -1)
        header_text.append('flow_edges')
    image_viz = einops.rearrange(image_viz[:8], 'b c h w -> c (b h) w').detach().clip(0, 255).to(torch.uint8)

    return image_viz, header_text


def get_frame_vis(model, cfg, sample, preds):
    masks_pred = torch.stack([x['sem_seg'] for x in preds], 0)
    flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20)

    masks_softmaxed = torch.softmax(masks_pred, dim=1)
    if cfg.GWM.SIMPLE_REC:
        mask_denom = einops.reduce(masks_softmaxed, 'b k h w -> b k 1', 'sum') + 1e-7
        means = torch.einsum('brhw, bchw -> brc', masks_softmaxed, flow) / mask_denom
        rec_flow = torch.einsum('bkhw, bkc-> bchw', masks_softmaxed, means)
    elif cfg.GWM.HOMOGRAPHY:
        rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow)
    else:
        grid_x, grid_y = grid.get_meshgrid(cfg.GWM.RESOLUTION, model.device)
        rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow, grid_x, grid_y)

    rgb = torch.stack([x['rgb'] for x in sample])
    flow = torch.stack([visualisation.flow2rgb_torch(x) for x in flow.cpu()]) * 255
    rec_flow = torch.stack([visualisation.flow2rgb_torch(x) for x in rec_flow.detach().cpu()]) * 255

    gt_labels = torch.stack([x['sem_seg'] for x in sample])
    gt = F.one_hot(gt_labels, gt_labels.max().item() + 1).permute(0, 3, 1, 2)

    masks = F.one_hot(masks_pred.argmax(1).cpu(), cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES).permute(0, 3, 1, 2)

    gt_seg = torch.einsum('b k h w, k c -> b c h w', gt, label_colors[:gt_labels.max().item() + 1])
    pred_seg = torch.einsum('b k h w, k c -> b c h w', masks, label_colors[:cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES])
    frame_vis = torch.cat([rgb, flow, gt_seg.cpu(), pred_seg.cpu(), rec_flow.clip(0, 255).to(torch.uint8)], -1)
    frame_vis = einops.rearrange(frame_vis, 'b c h w -> b c h w').detach().clip(0, 255).to(torch.uint8)
    return frame_vis


def is_2comp_dataset(dataset):
    if '+' in dataset:
        d = dataset.split('+')[0].strip()
    else:
        d = dataset.strip()
    logger.info_once(f"Is 2comp dataset? {d}")
    for s in ['DAVIS', 'FBMS', 'STv2']:
        if s in d:
            return True
    return d in ['DAVIS', 'FBMS', 'STv2']


def eval_unsupmf(cfg, val_loader, model, criterion, writer=None, writer_iteration=0, use_wandb=False):
    logger.info(f'Running Evaluation: {cfg.LOG_ID} {"Simple" if cfg.GWM.SIMPLE_REC else "Gradient"}:')
    logger.info(f'Model mode: {"train" if model.training else "eval"}, wandb: {use_wandb}')
    logger.info(f'Dataset: {cfg.GWM.DATASET} # components: {cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES}')

    merger = None
    if cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES > 2:
        merger = MaskMerger(cfg, model)

    print_idxs = random.sample(range(len(val_loader)), k=10)

    images_viz = []
    ious_davis_eval = defaultdict(list)
    ious = defaultdict(list)

    for idx, sample in enumerate(tqdm(val_loader)):
        t = 1
        sample = [e for s in sample for e in s]
        category = [s['category'] for s in sample]
        preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
        masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)

        masks_softmaxed = torch.softmax(masks_raw, dim=1)
        masks_dict = merger(sample, masks_softmaxed)

        if writer and idx in print_idxs:
            flow = torch.stack([x['flow'] for x in sample]).to(model.device)
            img_viz, header_text = get_image_vis(model, cfg, sample, preds, criterion)
            images_viz.append(img_viz)

        masks = masks_dict['cos']
        gt_seg = torch.stack([x['sem_seg_ori'] for x in sample]).cpu()
        HW = gt_seg.shape[-2:]
        if HW != masks.shape[-2:]:
            logger.info_once(f"Upsampling predicted masks to {HW} for evaluation")
            masks_softmaxed_sel = F.interpolate(masks.detach().cpu(), size=HW, mode='bilinear', align_corners=False)
        else:
            masks_softmaxed_sel = masks.detach().cpu()
        masks_ = einops.rearrange(masks_softmaxed_sel, '(b t) s h w -> b t s 1 h w', t=t).detach()
        gt_seg = einops.rearrange(gt_seg, 'b h w -> b 1 h w').float()
        for i in range(masks_.size(0)):
            masks_k = F.interpolate(masks_[i], size=(1, gt_seg.shape[-2], gt_seg.shape[-1]))  # t s 1 h w
            mask_iou = iou(masks_k[:, :, 0], gt_seg[i, 0], thres=0.5)  # t s
            iou_max, slot_max = mask_iou.max(dim=1)

            ious[category[i][0]].append(iou_max)
            frame_id = category[i][1]
            ious_davis_eval[category[i][0]].append((frame_id.strip().replace('.png', ''), iou_max))

    frameious = sum(ious.values(), [])
    frame_mean_iou = torch.cat(frameious).sum().item() * 100 / len(frameious)
    if 'DAVIS' in cfg.GWM.DATASET.split('+')[0]:
        logger.info_once("Using DAVIS evaluator methods for evaluating IoU -- mean of mean of sequences without first frame")
        seq_scores = dict()
        for c in ious_davis_eval:
            seq_scores[c] = np.nanmean([v.item() for n, v in ious_davis_eval[c] if int(n) > 1])

        frame_mean_iou = np.nanmean(list(seq_scores.values())) * 100

    if writer:
        header = get_vis_header(images_viz[0].size(2), flow.size(3), header_text)
        images_viz = torch.cat(images_viz, dim=1)
        images_viz = torch.cat([header, images_viz], dim=1)
        writer.add_image('val/images', images_viz, writer_iteration)  # C H W
        writer.add_scalar('eval/mIoU', frame_mean_iou, writer_iteration)

    logger.info(f"mIoU: {frame_mean_iou:.3f}\n")
    return frame_mean_iou


class MaskMerger:
    def __init__(self, cfg, model, merger_model="dino_vits8"):
        self.extractor = ViTExtractor(model_type=merger_model, device=model.device)
        self.out_dim = 384

        self.mu = torch.tensor(self.extractor.mean).to(model.device).view(1, -1, 1, 1)
        self.sigma = torch.tensor(self.extractor.std).to(model.device).view(1, -1, 1, 1)
        self.start_idx = 0

    def get_feats(self, batch):
        with torch.no_grad():
            feat = self.extractor.extract_descriptors(batch, facet='key', layer=11, bin=False)
            feat = feat.reshape(feat.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
            return F.interpolate(feat, batch.shape[-2:], mode='bilinear')

    def spectral(self, A):
        clustering = SpectralClustering(n_clusters=2,
                                        affinity='precomputed',
                                        random_state=0).fit(A.detach().cpu().numpy())
        return np.arange(A.shape[-1])[clustering.labels_ == 0], np.arange(A.shape[-1])[clustering.labels_ == 1]

    def cos_merge(self, basis, masks):
        basis = basis / torch.linalg.vector_norm(basis, dim=-1, keepdim=True).clamp(min=1e-6)
        A = torch.einsum('brc, blc -> brl', basis, basis)[0].clamp(min=1e-6)
        inda, indb = self.spectral(A)
        return torch.stack([masks[:, inda].sum(1),
                            masks[:, indb].sum(1)], 1)

    def __call__(self, sample, masks_softmaxed):
        with torch.no_grad():
            masks_softmaxed = masks_softmaxed[:, self.start_idx:]
            batch = torch.stack([x['rgb'].to(masks_softmaxed.device) for x in sample], 0) / 255.0
            features = self.get_feats((batch - self.mu) / self.sigma)
            basis = torch.einsum('brhw, bchw -> brc', masks_softmaxed, features)
            basis /= einops.reduce(masks_softmaxed, 'b r h w -> b r 1', 'sum').clamp_min(1e-12)

            return {
                'cos': self.cos_merge(basis, masks_softmaxed),
            }
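The iou() helper scores a whole stack of thresholded masks against one ground-truth map in a single tensordot. A toy check with hypothetical masks, not part of the commit:

# Two candidate masks scored against one GT map.
import torch
from eval_utils import iou   # assuming the repo root is on sys.path

masks = torch.zeros(2, 4, 4)
masks[0, :2, :] = 1.          # top half  == GT           -> IoU 1.0
masks[1, :, :2] = 1.          # left half: inter 4/union 12 -> IoU 1/3
gt = torch.zeros(4, 4)
gt[:2, :] = 1.
print(iou(masks, gt))         # tensor([1.0000, 0.3333])
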
flow_reconstruction.py
ADDED
@@ -0,0 +1,54 @@
import sys

import torch

from dist import LOGGER


def lstq(A, F_u, F_v, lamda=0.01):
    # cols = A.shape[2]
    # assert all(cols == torch.linalg.matrix_rank(A))  # something better?
    try:
        Q, R = torch.linalg.qr(A)
        theta_x = torch.bmm(torch.bmm(torch.linalg.inv(R), Q.transpose(1, 2)), F_u)
        theta_y = torch.bmm(torch.bmm(torch.linalg.inv(R), Q.transpose(1, 2)), F_v)
    except Exception:
        LOGGER.exception("Least Squares failed")
        sys.exit(-1)
    return theta_x, theta_y


def get_quad_flow(masks_softmaxed, flow, grid_x, grid_y):
    rec_flow = 0
    for k in range(masks_softmaxed.size(1)):
        mask = masks_softmaxed[:, k].unsqueeze(1)
        _F = flow * mask
        M = mask.flatten(1)
        bs = _F.shape[0]
        x = grid_x.unsqueeze(0).flatten(1)
        y = grid_y.unsqueeze(0).flatten(1)

        F_u = _F[:, 0].flatten(1).unsqueeze(2)  # B x L x 1
        F_v = _F[:, 1].flatten(1).unsqueeze(2)  # B x L x 1
        A = torch.stack([x * M, y * M, x * x * M, y * y * M, x * y * M, torch.ones_like(y) * M], 2)  # B x L x 6

        theta_x, theta_y = lstq(A, F_u, F_v, lamda=.01)
        rec_flow_m = torch.stack([torch.einsum('bln,bnk->blk', A, theta_x).view(bs, *grid_x.shape),
                                  torch.einsum('bln,bnk->blk', A, theta_y).view(bs, *grid_y.shape)], 1)

        rec_flow += rec_flow_m
    return rec_flow


SUBSAMPLE = 8
SKIP = 0.4
SIZE = 0.3
NITER = 50
METHOD = 'inv_score'


def set_subsample_skip(sub=None, skip=None, size=None, niter=None, method=None):
    global SUBSAMPLE, SKIP, SIZE, NITER, METHOD
    if sub is not None: SUBSAMPLE = sub
    if skip is not None: SKIP = skip
    if size is not None: SIZE = size
    if niter is not None: NITER = niter
    if method is not None: METHOD = method
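get_quad_flow() fits, independently per mask, the quadratic model u(x, y) = a*x + b*y + c*x^2 + d*y^2 + e*xy + f (and likewise for v) by mask-weighted least squares, then sums the per-mask reconstructions. A minimal sketch with a single full-frame mask, where a flow drawn from that family should be recovered almost exactly (shapes here are assumptions, not part of the commit):

# Recovering a known quadratic flow with one all-ones mask.
import torch
from flow_reconstruction import get_quad_flow   # assuming the repo root is on sys.path

H, W = 8, 8
grid_y, grid_x = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
mask = torch.ones(1, 1, H, W)                   # a single component covering the frame
u = 0.5 * grid_x + 2.0                          # linear u, quadratic v: both in the model family
v = 0.1 * grid_y ** 2 - 1.0
flow = torch.stack([u, v])[None]                # 1 x 2 x H x W
rec = get_quad_flow(mask, flow, grid_x, grid_y)
print((rec - flow).abs().max())                 # ~0 up to float precision
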
losses/__init__.py
ADDED
@@ -0,0 +1,28 @@
from .reconstruction_loss import ReconstructionLoss
import torch


class CriterionDict:
    def __init__(self, dict):
        self.criterions = dict

    def __call__(self, sample, flow, masks_softmaxed, iteration, train=True, prefix=''):
        loss = torch.tensor(0., device=masks_softmaxed.device, dtype=masks_softmaxed.dtype)
        log_dict = {}
        for name_i, (criterion_i, loss_multiplier_i, anneal_fn_i) in self.criterions.items():
            loss_i = loss_multiplier_i * anneal_fn_i(iteration) * criterion_i(sample, flow, masks_softmaxed, iteration, train=train)
            loss += loss_i
            log_dict[f'loss_{name_i}'] = loss_i.item()

        log_dict['loss_total'] = loss.item()
        return loss, log_dict

    def flow_reconstruction(self, sample, flow, masks_softmaxed):
        return self.criterions['reconstruction'][0].rec_flow(sample, flow, masks_softmaxed)

    def process_flow(self, sample, flow):
        return self.criterions['reconstruction'][0].process_flow(sample, flow)

    def viz_flow(self, flow):
        return self.criterions['reconstruction'][0].viz_flow(flow)
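For reference, this is how main.py in this commit wires a ReconstructionLoss into a CriterionDict; each entry maps a name to a (criterion, loss multiplier, annealing function of the iteration) triple. The snippet assumes a built cfg/model and a training batch are already in scope:

# Wiring as used by main.py: name -> (criterion, multiplier, anneal fn).
criterions = {
    'reconstruction': (losses.ReconstructionLoss(cfg, model), cfg.GWM.LOSS_MULT.REC, lambda it: 1)}
criterion = losses.CriterionDict(criterions)
loss, log_dict = criterion(sample, flow, masks_softmaxed, iteration)
# log_dict carries 'loss_reconstruction' and 'loss_total'
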
losses/reconstruction_loss.py
ADDED
@@ -0,0 +1,85 @@
import functools

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

import flow_reconstruction
import utils
from utils.visualisation import flow2rgb_torch

logger = utils.log.getLogger(__name__)


class ReconstructionLoss:
    def __init__(self, cfg, model):
        self.criterion = nn.MSELoss() if cfg.GWM.CRITERION == 'L2' else nn.L1Loss()
        self.l1_optimize = cfg.GWM.L1_OPTIMIZE
        self.homography = cfg.GWM.HOMOGRAPHY
        self.device = model.device
        self.cfg = cfg
        self.grid_x, self.grid_y = utils.grid.get_meshgrid(cfg.GWM.RESOLUTION, model.device)
        # self.mult_flow = cfg.GWM.USE_MULT_FLOW
        self.flow_colorspace_rec = cfg.GWM.FLOW_COLORSPACE_REC
        flow_reconstruction.set_subsample_skip(cfg.GWM.HOMOGRAPHY_SUBSAMPLE, cfg.GWM.HOMOGRAPHY_SKIP)
        self.flow_u_low = cfg.GWM.FLOW_CLIP_U_LOW
        self.flow_u_high = cfg.GWM.FLOW_CLIP_U_HIGH
        self.flow_v_low = cfg.GWM.FLOW_CLIP_V_LOW
        self.flow_v_high = cfg.GWM.FLOW_CLIP_V_HIGH

        self._recon_fn = self.flow_quad
        logger.info(f'Using reconstruction method {self._recon_fn.__name__}')
        self.it = 0
        self._extra_losses = []

    def __call__(self, sample, flow, masks_softmaxed, it, train=True):
        return self.loss(sample, flow, masks_softmaxed, it, train=train)

    def loss(self, sample, flow, mask_softmaxed, it, train=True):
        self.training = train
        flow = self.process_flow(sample, flow)
        self.it = it
        self._extra_losses = []

        if self.cfg.GWM.FLOW_RES is not None:
            if flow.shape[-2:] != mask_softmaxed.shape[-2:]:
                logger.debug_once(f'Resizing predicted masks to {self.cfg.GWM.FLOW_RES}')
                mask_softmaxed = F.interpolate(mask_softmaxed, flow.shape[-2:], mode='bilinear', align_corners=False)

        rec_flows = self.rec_flow(sample, flow, mask_softmaxed)
        if not isinstance(rec_flows, (list, tuple)):
            rec_flows = (rec_flows,)
        k = len(rec_flows)
        loss = sum(self.criterion(flow, rec_flow) / k for rec_flow in rec_flows)
        if len(self._extra_losses):
            loss = loss + sum(self._extra_losses, 0.) / len(self._extra_losses)
        self._extra_losses = []
        return loss

    def flow_quad(self, sample, flow, masks_softmaxed, it, **_):
        logger.debug_once(f'Reconstruction using quadratic. Masks shape {masks_softmaxed.shape} | '
                          f'Flow shape {flow.shape} | '
                          f'Grid shape {self.grid_x.shape, self.grid_y.shape}')
        return flow_reconstruction.get_quad_flow(masks_softmaxed, flow, self.grid_x, self.grid_y)

    def _clipped_recon_fn(self, *args, **kwargs):
        flow = self._recon_fn(*args, **kwargs)
        flow_o = flow[:, :-2]
        flow_u = flow[:, -2:-1].clip(self.flow_u_low, self.flow_u_high)
        flow_v = flow[:, -1:].clip(self.flow_v_low, self.flow_v_high)
        return torch.cat([flow_o, flow_u, flow_v], dim=1)

    def rec_flow(self, sample, flow, masks_softmaxed):
        it = self.it
        if self.cfg.GWM.FLOW_RES is not None and flow.shape[-2:] != self.grid_x.shape[-2:]:
            logger.debug_once(f'Generating new grid predicted masks of {flow.shape[-2:]}')
            self.grid_x, self.grid_y = utils.grid.get_meshgrid(flow.shape[-2:], self.device)
        return [self._clipped_recon_fn(sample, flow, masks_softmaxed, it)]

    def process_flow(self, sample, flow_cuda):
        return flow_cuda

    def viz_flow(self, flow):
        return torch.stack([flow2rgb_torch(x) for x in flow])
main.py
ADDED
@@ -0,0 +1,270 @@
import determinism  # noqa

determinism.i_do_nothing_but_dont_remove_me_otherwise_things_break()  # noqa

import argparse
import bisect
import copy
import os
import sys
import time
from argparse import ArgumentParser

import torch
import wandb
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import PeriodicCheckpointer
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import config
import losses
import utils
from eval_utils import eval_unsupmf, get_unsup_image_viz, get_vis_header
from mask_former_trainer import setup, Trainer


logger = utils.log.getLogger('gwm')


def freeze(module, set=False):
    for param in module.parameters():
        param.requires_grad = set


def main(args):
    cfg = setup(args)
    logger.info(f"Called as {' '.join(sys.argv)}")
    logger.info(f'Output dir {cfg.OUTPUT_DIR}')

    random_state = utils.random_state.PytorchRNGState(seed=cfg.SEED).to(torch.device(cfg.MODEL.DEVICE))
    random_state.seed_everything()
    utils.log.checkpoint_code(cfg.OUTPUT_DIR)

    if not cfg.SKIP_TB:
        writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
    else:
        writer = None

    # initialize model
    model = Trainer.build_model(cfg)
    optimizer = Trainer.build_optimizer(cfg, model)
    scheduler = Trainer.build_lr_scheduler(cfg, optimizer)

    logger.info(f'Optimiser is {type(optimizer)}')

    checkpointer = DetectionCheckpointer(model,
                                         save_dir=os.path.join(cfg.OUTPUT_DIR, 'checkpoints'),
                                         random_state=random_state,
                                         optimizer=optimizer,
                                         scheduler=scheduler)
    periodic_checkpointer = PeriodicCheckpointer(checkpointer=checkpointer,
                                                 period=cfg.SOLVER.CHECKPOINT_PERIOD,
                                                 max_iter=cfg.SOLVER.MAX_ITER,
                                                 max_to_keep=None if cfg.FLAGS.KEEP_ALL else 5,
                                                 file_prefix='checkpoint')
    checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume_path is not None)
    iteration = 0 if args.resume_path is None else checkpoint['iteration']

    train_loader, val_loader = config.loaders(cfg)

    # overfit single batch for debug
    # sample = next(iter(loader))

    criterions = {
        'reconstruction': (losses.ReconstructionLoss(cfg, model), cfg.GWM.LOSS_MULT.REC, lambda x: 1)}

    criterion = losses.CriterionDict(criterions)

    if args.eval_only:
        if len(val_loader.dataset) == 0:
            logger.error("Validation dataset: empty")
            sys.exit(0)
        model.eval()
        iou = eval_unsupmf(cfg=cfg, val_loader=val_loader, model=model, criterion=criterion, writer=writer,
                           writer_iteration=iteration)
        logger.info(f"Results: iteration: {iteration} IOU = {iou}")
        return
    if len(train_loader.dataset) == 0:
        logger.error("Training dataset: empty")
        sys.exit(0)

    logger.info(
        f'Start of training: dataset {cfg.GWM.DATASET},'
        f' train {len(train_loader.dataset)}, val {len(val_loader.dataset)},'
        f' device {model.device}, keys {cfg.GWM.SAMPLE_KEYS}, '
        f'multiple flows {cfg.GWM.USE_MULT_FLOW}')

    iou_best = 0
    timestart = time.time()
    dilate_kernel = torch.ones((2, 2), device=model.device)

    total_iter = cfg.TOTAL_ITER if cfg.TOTAL_ITER else cfg.SOLVER.MAX_ITER  # early stop
    with torch.autograd.set_detect_anomaly(cfg.DEBUG), \
            tqdm(initial=iteration, total=total_iter, disable=utils.environment.is_slurm()) as pbar:
        while iteration < total_iter:
            for sample in train_loader:

                if cfg.MODEL.META_ARCHITECTURE != 'UNET' and cfg.FLAGS.UNFREEZE_AT:
                    if hasattr(model.backbone, 'frozen_stages'):
                        assert cfg.MODEL.BACKBONE.FREEZE_AT == -1, "MODEL initial parameters forced frozen"
                        stages = [s for s, m in cfg.FLAGS.UNFREEZE_AT]
                        milest = [m for s, m in cfg.FLAGS.UNFREEZE_AT]
                        pos = bisect.bisect_right(milest, iteration) - 1
                        if pos >= 0:
                            curr_setting = model.backbone.frozen_stages
                            if curr_setting != stages[pos]:
                                logger.info(f"Updating backbone freezing stages from {curr_setting} to {stages[pos]}")
                                model.backbone.frozen_stages = stages[pos]
                                model.train()
                    else:
                        assert cfg.MODEL.BACKBONE.FREEZE_AT == -1, "MODEL initial parameters forced frozen"
                        stages = [s for s, m in cfg.FLAGS.UNFREEZE_AT]
                        milest = [m for s, m in cfg.FLAGS.UNFREEZE_AT]
                        pos = bisect.bisect_right(milest, iteration) - 1
                        freeze(model, set=False)
                        freeze(model.sem_seg_head.predictor, set=True)
                        if pos >= 0:
                            stage = stages[pos]
                            if stage <= 2:
                                freeze(model.sem_seg_head, set=True)
                            if stage <= 1:
                                freeze(model.backbone, set=True)
                        model.train()

                else:
                    logger.debug_once(f'Unfreezing disabled schedule: {cfg.FLAGS.UNFREEZE_AT}')

                sample = [e for s in sample for e in s]
                flow_key = 'flow'
                raw_sem_seg = False
                if cfg.GWM.FLOW_RES is not None:
                    flow_key = 'flow_big'
                    raw_sem_seg = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME == 'MegaBigPixelDecoder'

                flow = torch.stack([x[flow_key].to(model.device) for x in sample]).clip(-20, 20)
                logger.debug_once(f'flow shape: {flow.shape}')
                preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True, raw_sem_seg=raw_sem_seg)
                masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)
                logger.debug_once(f'mask shape: {masks_raw.shape}')
                masks_softmaxed_list = [torch.softmax(masks_raw, dim=1)]

                total_losses = []
                log_dicts = []
                for mask_idx, masks_softmaxed in enumerate(masks_softmaxed_list):

                    loss, log_dict = criterion(sample, flow, masks_softmaxed, iteration)

                    if cfg.GWM.USE_MULT_FLOW:
                        flow2 = torch.stack([x[flow_key + '_2'].to(model.device) for x in sample]).clip(-20, 20)
                        other_loss, other_log_dict = criterion(sample, flow2, masks_softmaxed, iteration)
                        loss = loss / 2 + other_loss / 2
                        for k, v in other_log_dict.items():
                            log_dict[k] = other_log_dict[k] / 2 + v / 2
                    total_losses.append(loss)
                    log_dicts.append(log_dict)

                loss_ws = cfg.GWM.LOSS_MULT.HEIR_W
                total_w = float(sum(loss_ws[:len(total_losses)]))
                log_dict = {}
                if len(total_losses) == 1:
                    log_dict = log_dicts[0]
                    loss = total_losses[0]
                else:
                    loss = 0
                    for i, (tl, w, ld) in enumerate(zip(total_losses, loss_ws, log_dicts)):
                        for k, v in ld.items():
                            log_dict[f'{k}_{i}'] = v * w / total_w
                        loss += tl * w / total_w

                train_log_dict = {f'train/{k}': v for k, v in log_dict.items()}
                del log_dict
                train_log_dict['train/learning_rate'] = optimizer.param_groups[-1]['lr']
                train_log_dict['train/loss_total'] = loss.item()

                optimizer.zero_grad()

                loss.backward()
                optimizer.step()
                scheduler.step()

                pbar.set_postfix(loss=loss.item())
                pbar.update()

                # Sanity check for RNG state
                if (iteration + 1) % 1000 == 0 or iteration + 1 in {1, 50}:
                    logger.info(
                        f'Iteration {iteration + 1}. RNG outputs {utils.random_state.get_randstate_magic_numbers(model.device)}')

                if cfg.DEBUG or (iteration + 1) % 100 == 0:
                    logger.info(
                        f'Iteration: {iteration + 1}, time: {time.time() - timestart:.01f}s, loss: {loss.item():.02f}.')

                    for k, v in train_log_dict.items():
                        if writer:
                            writer.add_scalar(k, v, iteration + 1)

                    if cfg.WANDB.ENABLE:
                        wandb.log(train_log_dict, step=iteration + 1)

                if (iteration + 1) % cfg.LOG_FREQ == 0 or (iteration + 1) in [1, 50, 500]:
                    model.eval()
                    if writer:
                        flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20)
                        image_viz, header_text = get_unsup_image_viz(model, cfg, sample, criterion)
                        header = get_vis_header(image_viz.size(2), flow.size(3), header_text)
                        image_viz = torch.cat([header, image_viz], dim=1)
                        writer.add_image('train/images', image_viz, iteration + 1)
                    if cfg.WANDB.ENABLE and (iteration + 1) % 2500 == 0:
                        image_viz, _ = get_unsup_image_viz(model, cfg, sample, criterion)
                        wandb.log({'train/viz': wandb.Image(image_viz.float())}, step=iteration + 1)

                    if iou := eval_unsupmf(cfg=cfg, val_loader=val_loader, model=model, criterion=criterion,
                                           writer=writer, writer_iteration=iteration + 1, use_wandb=cfg.WANDB.ENABLE):
                        if cfg.SOLVER.CHECKPOINT_PERIOD and iou > iou_best:
                            iou_best = iou
                            if not args.wandb_sweep_mode:
                                checkpointer.save(name='checkpoint_best', iteration=iteration + 1, loss=loss,
                                                  iou=iou_best)
                            logger.info(f'New best IoU {iou_best:.02f} after iteration {iteration + 1}')
                            if cfg.WANDB.ENABLE:
                                wandb.log({'eval/IoU_best': iou_best}, step=iteration + 1)
                            if writer:
                                writer.add_scalar('eval/IoU_best', iou_best, iteration + 1)

                    model.train()

                periodic_checkpointer.step(iteration=iteration + 1, loss=loss)

                iteration += 1
                timestart = time.time()


def get_argparse_args():
    parser = ArgumentParser()
    parser.add_argument('--resume_path', type=str, default=None)
    parser.add_argument('--use_wandb', dest='wandb_sweep_mode', action='store_true')  # for sweep
    parser.add_argument('--config-file', type=str,
                        default='configs/maskformer/maskformer_R50_bs16_160k_dino.yaml')
    parser.add_argument('--eval_only', action='store_true')
    parser.add_argument(
        "opts",
        help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. "
             "See config references at "
             "https://detectron2.readthedocs.io/modules/config.html#config-references",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser


if __name__ == "__main__":
    args = get_argparse_args().parse_args()
    if args.resume_path:
        args.config_file = "/".join(args.resume_path.split('/')[:-2]) + '/config.yaml'
        print(args.config_file)
    main(args)
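A sketch of invoking the entry point programmatically; the trailing KEY VALUE pairs land in `opts` and are applied as config overrides, as the argument help above describes. The particular override shown is an example, not part of the commit:

# Hypothetical invocation, mirroring the __main__ block above.
args = get_argparse_args().parse_args(
    ['--eval_only',
     '--config-file', 'configs/maskformer/maskformer_R50_bs16_160k_dino.yaml',
     'GWM.DATASET', 'DAVIS'])   # trailing KEY VALUE pairs become `opts`
main(args)
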
mask_former/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from . import data  # register all new datasets
from . import modeling

# config
from .config import add_mask_former_config

# dataset loading
from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper
from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import (
    MaskFormerPanopticDatasetMapper,
)
from .data.dataset_mappers.mask_former_semantic_dataset_mapper import (
    MaskFormerSemanticDatasetMapper,
)

# models
from .mask_former_model import MaskFormer
from .test_time_augmentation import SemanticSegmentorWithTTA
mask_former/config.py
ADDED
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
from detectron2.config import CfgNode as CN


def add_mask_former_config(cfg):
    """
    Add config for MASK_FORMER.
    """
    # data config
    # select the dataset mapper
    cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
    # Color augmentation
    cfg.INPUT.COLOR_AUG_SSD = False
    # We retry random cropping until no single category in semantic segmentation GT occupies more
    # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
    cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
    # Pad image and segmentation GT in dataset mapper.
    cfg.INPUT.SIZE_DIVISIBILITY = -1

    # solver config
    # weight decay on embedding
    cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
    # optimizer
    cfg.SOLVER.OPTIMIZER = "ADAMW"
    cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1

    # mask_former model config
    cfg.MODEL.MASK_FORMER = CN()

    # loss
    cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
    cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1
    cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0
    cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0

    # transformer config
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.1
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
    cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6
    cfg.MODEL.MASK_FORMER.PRE_NORM = False

    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100

    cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5"
    cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False

    # mask_former inference config
    cfg.MODEL.MASK_FORMER.TEST = CN()
    cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False
    cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0
    cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0
    cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False

    # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
    # you can use this config to override
    cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32

    # pixel decoder config
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    # adding transformer in pixel decoder
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
    # pixel decoder
    cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"

    # swin transformer backbone
    cfg.MODEL.SWIN = CN()
    cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
    cfg.MODEL.SWIN.PATCH_SIZE = 4
    cfg.MODEL.SWIN.EMBED_DIM = 96
    cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
    cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
    cfg.MODEL.SWIN.WINDOW_SIZE = 7
    cfg.MODEL.SWIN.MLP_RATIO = 4.0
    cfg.MODEL.SWIN.QKV_BIAS = True
    cfg.MODEL.SWIN.QK_SCALE = None
    cfg.MODEL.SWIN.DROP_RATE = 0.0
    cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
    cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
    cfg.MODEL.SWIN.APE = False
    cfg.MODEL.SWIN.PATCH_NORM = True
    cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
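A sketch of applying add_mask_former_config() on top of detectron2's defaults. Upstream MaskFormer applies add_deeplab_config first (which provides the MODEL.SEM_SEG_HEAD node extended above); this snippet assumes the same and is not part of the commit:

# Stacking the MaskFormer keys onto a default detectron2 config.
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config

cfg = get_cfg()
add_deeplab_config(cfg)        # assumption: deeplab defaults, as in MaskFormer upstream
add_mask_former_config(cfg)
print(cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES)   # -> 100
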
mask_former/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from . import datasets
mask_former/data/dataset_mappers/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates.
mask_former/data/dataset_mappers/detr_panoptic_dataset_mapper.py
ADDED
@@ -0,0 +1,180 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py
import copy
import logging

import numpy as np
import torch

from detectron2.config import configurable
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.data.transforms import TransformGen
from detectron2.structures import BitMasks, Instances

__all__ = ["DETRPanopticDatasetMapper"]


def build_transform_gen(cfg, is_train):
    """
    Create a list of :class:`TransformGen` from config.
    Returns:
        list[TransformGen]
    """
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
        sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
    else:
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST
        sample_style = "choice"
    if sample_style == "range":
        assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(
            len(min_size)
        )

    logger = logging.getLogger(__name__)
    tfm_gens = []
    if is_train:
        tfm_gens.append(T.RandomFlip())
    tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
    if is_train:
        logger.info("TransformGens used in training: " + str(tfm_gens))
    return tfm_gens


# This is specifically designed for the COCO dataset.
class DETRPanopticDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by MaskFormer.

    This dataset mapper applies the same transformations as DETR for COCO panoptic segmentation.

    The callable currently does the following:

    1. Reads the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Finds and applies suitable cropping to the image and annotation
    4. Prepares the image and annotation as Tensors
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        crop_gen,
        tfm_gens,
        image_format,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            is_train: for training or inference
            crop_gen: crop augmentation
            tfm_gens: data augmentation
            image_format: an image format supported by :func:`detection_utils.read_image`.
        """
        self.crop_gen = crop_gen
        self.tfm_gens = tfm_gens
        logging.getLogger(__name__).info(
            "[DETRPanopticDatasetMapper] Full TransformGens used in training: {}, crop: {}".format(
                str(self.tfm_gens), str(self.crop_gen)
            )
        )

        self.img_format = image_format
        self.is_train = is_train

    @classmethod
    def from_config(cls, cfg, is_train=True):
        # Build augmentation
        if cfg.INPUT.CROP.ENABLED and is_train:
            crop_gen = [
                T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
                T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
            ]
        else:
            crop_gen = None

        tfm_gens = build_transform_gen(cfg, is_train)

        ret = {
            "is_train": is_train,
            "crop_gen": crop_gen,
            "tfm_gens": tfm_gens,
            "image_format": cfg.INPUT.FORMAT,
        }
        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        if self.crop_gen is None:
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        else:
            if np.random.rand() > 0.5:
                image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            else:
                image, transforms = T.apply_transform_gens(
                    self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
                )

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "pan_seg_file_name" in dataset_dict:
            pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
            segments_info = dataset_dict["segments_info"]

            # apply the same transformation to panoptic segmentation
            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

            from panopticapi.utils import rgb2id

            pan_seg_gt = rgb2id(pan_seg_gt)

            instances = Instances(image_shape)
            classes = []
            masks = []
            for segment_info in segments_info:
                class_id = segment_info["category_id"]
                if not segment_info["iscrowd"]:
                    classes.append(class_id)
                    masks.append(pan_seg_gt == segment_info["id"])

            classes = np.array(classes)
            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
            if len(masks) == 0:
                # Some images do not have any annotation (all ignored)
                instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
            else:
                masks = BitMasks(
                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
                )
                instances.gt_masks = masks.tensor

            dataset_dict["instances"] = instances

        return dataset_dict
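
For context, here is a minimal sketch of how a mapper like this is typically plugged into detectron2's training dataloader. `get_cfg` and `build_detection_train_loader` are standard detectron2 API; the dataset name and crop settings below are illustrative assumptions, not values taken from this commit.

    # Minimal sketch (not part of this commit); config values are assumptions.
    from detectron2.config import get_cfg
    from detectron2.data import build_detection_train_loader

    from mask_former.data.dataset_mappers.detr_panoptic_dataset_mapper import (
        DETRPanopticDatasetMapper,
    )

    cfg = get_cfg()
    cfg.DATASETS.TRAIN = ("coco_2017_train_panoptic",)  # assumed registered dataset
    cfg.INPUT.CROP.ENABLED = True
    cfg.INPUT.CROP.TYPE = "absolute"
    cfg.INPUT.CROP.SIZE = (384, 600)

    # Because __init__ is decorated with @configurable, the class can be
    # constructed directly from a config; from_config fills in crop_gen/tfm_gens.
    mapper = DETRPanopticDatasetMapper(cfg, is_train=True)
    train_loader = build_detection_train_loader(cfg, mapper=mapper)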
mask_former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py
ADDED
@@ -0,0 +1,165 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import logging

import numpy as np
import torch
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.structures import BitMasks, Instances

from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper

__all__ = ["MaskFormerPanopticDatasetMapper"]


class MaskFormerPanopticDatasetMapper(MaskFormerSemanticDatasetMapper):
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by MaskFormer for panoptic segmentation.

    The callable currently does the following:

    1. Reads the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Finds and applies suitable cropping to the image and annotation
    4. Prepares the image and annotation as Tensors
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        augmentations,
        image_format,
        ignore_label,
        size_divisibility,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            is_train: for training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            ignore_label: the label that is ignored during evaluation
            size_divisibility: pad image size to be divisible by this value
        """
        super().__init__(
            is_train,
            augmentations=augmentations,
            image_format=image_format,
            ignore_label=ignore_label,
            size_divisibility=size_divisibility,
        )

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        assert self.is_train, "MaskFormerPanopticDatasetMapper should only be used for training!"

        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        # semantic segmentation
        if "sem_seg_file_name" in dataset_dict:
            # PyTorch transformation not implemented for uint16, so converting it to double first
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
        else:
            sem_seg_gt = None

        # panoptic segmentation
        if "pan_seg_file_name" in dataset_dict:
            pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
            segments_info = dataset_dict["segments_info"]
        else:
            pan_seg_gt = None
            segments_info = None

        if pan_seg_gt is None:
            raise ValueError(
                "Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format(
                    dataset_dict["file_name"]
                )
            )

        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
        image = aug_input.image
        if sem_seg_gt is not None:
            sem_seg_gt = aug_input.sem_seg

        # apply the same transformation to panoptic segmentation
        pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

        from panopticapi.utils import rgb2id

        pan_seg_gt = rgb2id(pan_seg_gt)

        # Pad image and segmentation label here!
        image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
        pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long"))

        if self.size_divisibility > 0:
            image_size = (image.shape[-2], image.shape[-1])
            padding_size = [
                0,
                self.size_divisibility - image_size[1],
                0,
                self.size_divisibility - image_size[0],
            ]
            image = F.pad(image, padding_size, value=128).contiguous()
            if sem_seg_gt is not None:
                sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
            pan_seg_gt = F.pad(
                pan_seg_gt, padding_size, value=0
            ).contiguous()  # 0 is the VOID panoptic label

        image_shape = (image.shape[-2], image.shape[-1])  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = image
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = sem_seg_gt.long()

        if "annotations" in dataset_dict:
            raise ValueError("Panoptic segmentation dataset should not have 'annotations'.")

        # Prepare per-category binary masks
        pan_seg_gt = pan_seg_gt.numpy()
        instances = Instances(image_shape)
        classes = []
        masks = []
        for segment_info in segments_info:
            class_id = segment_info["category_id"]
            if not segment_info["iscrowd"]:
                classes.append(class_id)
                masks.append(pan_seg_gt == segment_info["id"])

        classes = np.array(classes)
        instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
        if len(masks) == 0:
            # Some images do not have any annotation (all ignored)
            instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
        else:
            masks = BitMasks(
                torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
            )
            instances.gt_masks = masks.tensor

        dataset_dict["instances"] = instances

        return dataset_dict
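
A standalone sketch of the padding step used above, since the argument order of `F.pad` is easy to misread: the list gives (left, right) padding for the last dimension first, then (top, bottom) for the dimension before it, so `[0, pad_w, 0, pad_h]` pads the right and bottom edges. The tensor shape and divisibility value are made-up for illustration; the mapper itself assumes the augmented image is no larger than `size_divisibility` in either dimension (in this setup the crop size matches it), so the pad amounts stay non-negative.

    # Toy illustration of the size_divisibility padding; values are assumptions.
    import torch
    import torch.nn.functional as F

    size_divisibility = 32
    image = torch.zeros(3, 7, 5)  # (C, H, W), toy values

    pad_w = size_divisibility - image.shape[-1]  # pads the right edge
    pad_h = size_divisibility - image.shape[-2]  # pads the bottom edge
    padded = F.pad(image, [0, pad_w, 0, pad_h], value=128)
    print(padded.shape)  # torch.Size([3, 32, 32])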
mask_former/data/dataset_mappers/mask_former_semantic_dataset_mapper.py
ADDED
@@ -0,0 +1,184 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import logging

import numpy as np
import torch
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.projects.point_rend import ColorAugSSDTransform
from detectron2.structures import BitMasks, Instances

__all__ = ["MaskFormerSemanticDatasetMapper"]


class MaskFormerSemanticDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by MaskFormer for semantic segmentation.

    The callable currently does the following:

    1. Reads the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Finds and applies suitable cropping to the image and annotation
    4. Prepares the image and annotation as Tensors
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        augmentations,
        image_format,
        ignore_label,
        size_divisibility,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            is_train: for training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            ignore_label: the label that is ignored during evaluation
            size_divisibility: pad image size to be divisible by this value
        """
        self.is_train = is_train
        self.tfm_gens = augmentations
        self.img_format = image_format
        self.ignore_label = ignore_label
        self.size_divisibility = size_divisibility

        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}")

    @classmethod
    def from_config(cls, cfg, is_train=True):
        # Build augmentation
        augs = [
            T.ResizeShortestEdge(
                cfg.INPUT.MIN_SIZE_TRAIN,
                cfg.INPUT.MAX_SIZE_TRAIN,
                cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
            )
        ]
        if cfg.INPUT.CROP.ENABLED:
            augs.append(
                T.RandomCrop_CategoryAreaConstraint(
                    cfg.INPUT.CROP.TYPE,
                    cfg.INPUT.CROP.SIZE,
                    cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
                    cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
                )
            )
        if cfg.INPUT.COLOR_AUG_SSD:
            augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
        augs.append(T.RandomFlip())

        # Assume always applies to the training set.
        dataset_names = cfg.DATASETS.TRAIN
        meta = MetadataCatalog.get(dataset_names[0])
        ignore_label = meta.ignore_label

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "ignore_label": ignore_label,
            "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY,
        }
        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!"

        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        if "sem_seg_file_name" in dataset_dict:
            # PyTorch transformation not implemented for uint16, so converting it to double first
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
        else:
            sem_seg_gt = None

        if sem_seg_gt is None:
            raise ValueError(
                "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
                    dataset_dict["file_name"]
                )
            )

        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
        image = aug_input.image
        sem_seg_gt = aug_input.sem_seg

        # Pad image and segmentation label here!
        image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))

        if self.size_divisibility > 0:
            image_size = (image.shape[-2], image.shape[-1])
            padding_size = [
                0,
                self.size_divisibility - image_size[1],
                0,
                self.size_divisibility - image_size[0],
            ]
            image = F.pad(image, padding_size, value=128).contiguous()
            if sem_seg_gt is not None:
                sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()

        image_shape = (image.shape[-2], image.shape[-1])  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = image

        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = sem_seg_gt.long()

        if "annotations" in dataset_dict:
            raise ValueError("Semantic segmentation dataset should not have 'annotations'.")

        # Prepare per-category binary masks
        if sem_seg_gt is not None:
            sem_seg_gt = sem_seg_gt.numpy()
            instances = Instances(image_shape)
            classes = np.unique(sem_seg_gt)
            # remove ignored region
            classes = classes[classes != self.ignore_label]
            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)

            masks = []
            for class_id in classes:
                masks.append(sem_seg_gt == class_id)

            if len(masks) == 0:
                # Some images do not have any annotation (all ignored)
                instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
            else:
                masks = BitMasks(
                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
                )
                instances.gt_masks = masks.tensor

            dataset_dict["instances"] = instances

        return dataset_dict
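
A toy sketch of the "Prepare per-category binary masks" step above: every class id present in the semantic map, except `ignore_label`, becomes one binary mask. The 2x3 label map and ignore value are invented for illustration.

    # Toy illustration of per-category mask extraction; values are assumptions.
    import numpy as np
    import torch

    ignore_label = 255
    sem_seg_gt = np.array([[0, 0, 7], [7, 255, 7]])  # toy 2x3 label map

    classes = np.unique(sem_seg_gt)
    classes = classes[classes != ignore_label]  # -> [0, 7]
    masks = torch.stack([torch.from_numpy(sem_seg_gt == c) for c in classes])
    print(classes, masks.shape)  # [0 7] torch.Size([2, 2, 3])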
mask_former/data/datasets/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from . import (
    register_ade20k_full,
    register_ade20k_panoptic,
    register_coco_stuff_10k,
    register_mapillary_vistas,
)
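
These imports work by side effect: each `register_*` module calls into detectron2's catalogs at import time, so importing the package is enough to make the datasets visible to configs. A minimal sketch of that pattern follows; `my_toy_dataset` and its loader are hypothetical, not part of this repo.

    # Minimal sketch of registration-by-import; names are hypothetical.
    from detectron2.data import DatasetCatalog, MetadataCatalog

    def _load_my_toy_dataset():
        # Return a list of dataset dicts in Detectron2 format.
        return [{"file_name": "img0.jpg", "height": 32, "width": 32}]

    DatasetCatalog.register("my_toy_dataset", _load_my_toy_dataset)
    MetadataCatalog.get("my_toy_dataset").set(ignore_label=255)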
mask_former/data/datasets/register_ade20k_full.py
ADDED
@@ -0,0 +1,964 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import os

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import load_sem_seg

ADE20K_SEM_SEG_FULL_CATEGORIES = [
    {"name": "wall", "id": 2978, "trainId": 0},
    {"name": "building, edifice", "id": 312, "trainId": 1},
    {"name": "sky", "id": 2420, "trainId": 2},
    {"name": "tree", "id": 2855, "trainId": 3},
    {"name": "road, route", "id": 2131, "trainId": 4},
    {"name": "floor, flooring", "id": 976, "trainId": 5},
    {"name": "ceiling", "id": 447, "trainId": 6},
    {"name": "bed", "id": 165, "trainId": 7},
    {"name": "sidewalk, pavement", "id": 2377, "trainId": 8},
    {"name": "earth, ground", "id": 838, "trainId": 9},
    {"name": "cabinet", "id": 350, "trainId": 10},
    {"name": "person, individual, someone, somebody, mortal, soul", "id": 1831, "trainId": 11},
    {"name": "grass", "id": 1125, "trainId": 12},
    {"name": "windowpane, window", "id": 3055, "trainId": 13},
    {"name": "car, auto, automobile, machine, motorcar", "id": 401, "trainId": 14},
    {"name": "mountain, mount", "id": 1610, "trainId": 15},
    {"name": "plant, flora, plant life", "id": 1910, "trainId": 16},
    {"name": "table", "id": 2684, "trainId": 17},
    {"name": "chair", "id": 471, "trainId": 18},
    {"name": "curtain, drape, drapery, mantle, pall", "id": 687, "trainId": 19},
    {"name": "door", "id": 774, "trainId": 20},
    {"name": "sofa, couch, lounge", "id": 2473, "trainId": 21},
    {"name": "sea", "id": 2264, "trainId": 22},
    {"name": "painting, picture", "id": 1735, "trainId": 23},
    {"name": "water", "id": 2994, "trainId": 24},
    {"name": "mirror", "id": 1564, "trainId": 25},
    {"name": "house", "id": 1276, "trainId": 26},
    {"name": "rug, carpet, carpeting", "id": 2178, "trainId": 27},
    {"name": "shelf", "id": 2329, "trainId": 28},
    {"name": "armchair", "id": 57, "trainId": 29},
    {"name": "fence, fencing", "id": 907, "trainId": 30},
    {"name": "field", "id": 913, "trainId": 31},
    {"name": "lamp", "id": 1395, "trainId": 32},
    {"name": "rock, stone", "id": 2138, "trainId": 33},
    {"name": "seat", "id": 2272, "trainId": 34},
    {"name": "river", "id": 2128, "trainId": 35},
    {"name": "desk", "id": 724, "trainId": 36},
    {"name": "bathtub, bathing tub, bath, tub", "id": 155, "trainId": 37},
    {"name": "railing, rail", "id": 2053, "trainId": 38},
    {"name": "signboard, sign", "id": 2380, "trainId": 39},
    {"name": "cushion", "id": 689, "trainId": 40},
    {"name": "path", "id": 1788, "trainId": 41},
    {"name": "work surface", "id": 3087, "trainId": 42},
    {"name": "stairs, steps", "id": 2530, "trainId": 43},
    {"name": "column, pillar", "id": 581, "trainId": 44},
    {"name": "sink", "id": 2388, "trainId": 45},
    {"name": "wardrobe, closet, press", "id": 2985, "trainId": 46},
    {"name": "snow", "id": 2454, "trainId": 47},
    {"name": "refrigerator, icebox", "id": 2096, "trainId": 48},
    {"name": "base, pedestal, stand", "id": 137, "trainId": 49},
    {"name": "bridge, span", "id": 294, "trainId": 50},
    {"name": "blind, screen", "id": 212, "trainId": 51},
    {"name": "runway", "id": 2185, "trainId": 52},
    {"name": "cliff, drop, drop-off", "id": 524, "trainId": 53},
    {"name": "sand", "id": 2212, "trainId": 54},
    {"name": "fireplace, hearth, open fireplace", "id": 943, "trainId": 55},
    {"name": "pillow", "id": 1869, "trainId": 56},
    {"name": "screen door, screen", "id": 2251, "trainId": 57},
    {"name": "toilet, can, commode, crapper, pot, potty, stool, throne", "id": 2793, "trainId": 58},
    {"name": "skyscraper", "id": 2423, "trainId": 59},
    {"name": "grandstand, covered stand", "id": 1121, "trainId": 60},
    {"name": "box", "id": 266, "trainId": 61},
    {"name": "pool table, billiard table, snooker table", "id": 1948, "trainId": 62},
    {"name": "palm, palm tree", "id": 1744, "trainId": 63},
    {"name": "double door", "id": 783, "trainId": 64},
    {"name": "coffee table, cocktail table", "id": 571, "trainId": 65},
    {"name": "counter", "id": 627, "trainId": 66},
    {"name": "countertop", "id": 629, "trainId": 67},
    {"name": "chest of drawers, chest, bureau, dresser", "id": 491, "trainId": 68},
    {"name": "kitchen island", "id": 1374, "trainId": 69},
    {"name": "boat", "id": 223, "trainId": 70},
    {"name": "waterfall, falls", "id": 3016, "trainId": 71},
    {
        "name": "stove, kitchen stove, range, kitchen range, cooking stove",
        "id": 2598,
        "trainId": 72,
    },
    {"name": "flower", "id": 978, "trainId": 73},
    {"name": "bookcase", "id": 239, "trainId": 74},
    {"name": "controls", "id": 608, "trainId": 75},
    {"name": "book", "id": 236, "trainId": 76},
    {"name": "stairway, staircase", "id": 2531, "trainId": 77},
    {"name": "streetlight, street lamp", "id": 2616, "trainId": 78},
    {
        "name": "computer, computing machine, computing device, data processor, electronic computer, information processing system",
        "id": 591,
        "trainId": 79,
    },
    {
        "name": "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle",
        "id": 327,
        "trainId": 80,
    },
    {"name": "swivel chair", "id": 2679, "trainId": 81},
    {"name": "light, light source", "id": 1451, "trainId": 82},
    {"name": "bench", "id": 181, "trainId": 83},
    {"name": "case, display case, showcase, vitrine", "id": 420, "trainId": 84},
    {"name": "towel", "id": 2821, "trainId": 85},
    {"name": "fountain", "id": 1023, "trainId": 86},
    {"name": "embankment", "id": 855, "trainId": 87},
    {
        "name": "television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box",
        "id": 2733,
        "trainId": 88,
    },
    {"name": "van", "id": 2928, "trainId": 89},
    {"name": "hill", "id": 1240, "trainId": 90},
    {"name": "awning, sunshade, sunblind", "id": 77, "trainId": 91},
    {"name": "poster, posting, placard, notice, bill, card", "id": 1969, "trainId": 92},
    {"name": "truck, motortruck", "id": 2880, "trainId": 93},
    {"name": "airplane, aeroplane, plane", "id": 14, "trainId": 94},
    {"name": "pole", "id": 1936, "trainId": 95},
    {"name": "tower", "id": 2828, "trainId": 96},
    {"name": "court", "id": 631, "trainId": 97},
    {"name": "ball", "id": 103, "trainId": 98},
    {
        "name": "aircraft carrier, carrier, flattop, attack aircraft carrier",
        "id": 3144,
        "trainId": 99,
    },
    {"name": "buffet, counter, sideboard", "id": 308, "trainId": 100},
    {"name": "hovel, hut, hutch, shack, shanty", "id": 1282, "trainId": 101},
    {"name": "apparel, wearing apparel, dress, clothes", "id": 38, "trainId": 102},
    {"name": "minibike, motorbike", "id": 1563, "trainId": 103},
    {"name": "animal, animate being, beast, brute, creature, fauna", "id": 29, "trainId": 104},
    {"name": "chandelier, pendant, pendent", "id": 480, "trainId": 105},
    {"name": "step, stair", "id": 2569, "trainId": 106},
    {"name": "booth, cubicle, stall, kiosk", "id": 247, "trainId": 107},
    {"name": "bicycle, bike, wheel, cycle", "id": 187, "trainId": 108},
    {"name": "doorframe, doorcase", "id": 778, "trainId": 109},
    {"name": "sconce", "id": 2243, "trainId": 110},
    {"name": "pond", "id": 1941, "trainId": 111},
    {"name": "trade name, brand name, brand, marque", "id": 2833, "trainId": 112},
    {"name": "bannister, banister, balustrade, balusters, handrail", "id": 120, "trainId": 113},
    {"name": "bag", "id": 95, "trainId": 114},
    {"name": "traffic light, traffic signal, stoplight", "id": 2836, "trainId": 115},
    {"name": "gazebo", "id": 1087, "trainId": 116},
    {"name": "escalator, moving staircase, moving stairway", "id": 868, "trainId": 117},
    {"name": "land, ground, soil", "id": 1401, "trainId": 118},
    {"name": "board, plank", "id": 220, "trainId": 119},
    {"name": "arcade machine", "id": 47, "trainId": 120},
    {"name": "eiderdown, duvet, continental quilt", "id": 843, "trainId": 121},
    {"name": "bar", "id": 123, "trainId": 122},
    {"name": "stall, stand, sales booth", "id": 2537, "trainId": 123},
    {"name": "playground", "id": 1927, "trainId": 124},
    {"name": "ship", "id": 2337, "trainId": 125},
    {"name": "ottoman, pouf, pouffe, puff, hassock", "id": 1702, "trainId": 126},
    {
        "name": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
        "id": 64,
        "trainId": 127,
    },
    {"name": "bottle", "id": 249, "trainId": 128},
    {"name": "cradle", "id": 642, "trainId": 129},
    {"name": "pot, flowerpot", "id": 1981, "trainId": 130},
    {
        "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
        "id": 609,
        "trainId": 131,
    },
    {"name": "train, railroad train", "id": 2840, "trainId": 132},
    {"name": "stool", "id": 2586, "trainId": 133},
    {"name": "lake", "id": 1393, "trainId": 134},
    {"name": "tank, storage tank", "id": 2704, "trainId": 135},
    {"name": "ice, water ice", "id": 1304, "trainId": 136},
    {"name": "basket, handbasket", "id": 146, "trainId": 137},
    {"name": "manhole", "id": 1494, "trainId": 138},
    {"name": "tent, collapsible shelter", "id": 2739, "trainId": 139},
    {"name": "canopy", "id": 389, "trainId": 140},
    {"name": "microwave, microwave oven", "id": 1551, "trainId": 141},
    {"name": "barrel, cask", "id": 131, "trainId": 142},
    {"name": "dirt track", "id": 738, "trainId": 143},
    {"name": "beam", "id": 161, "trainId": 144},
    {"name": "dishwasher, dish washer, dishwashing machine", "id": 747, "trainId": 145},
    {"name": "plate", "id": 1919, "trainId": 146},
    {"name": "screen, crt screen", "id": 3109, "trainId": 147},
    {"name": "ruins", "id": 2179, "trainId": 148},
    {"name": "washer, automatic washer, washing machine", "id": 2989, "trainId": 149},
    {"name": "blanket, cover", "id": 206, "trainId": 150},
    {"name": "plaything, toy", "id": 1930, "trainId": 151},
    {"name": "food, solid food", "id": 1002, "trainId": 152},
    {"name": "screen, silver screen, projection screen", "id": 2254, "trainId": 153},
    {"name": "oven", "id": 1708, "trainId": 154},
    {"name": "stage", "id": 2526, "trainId": 155},
    {"name": "beacon, lighthouse, beacon light, pharos", "id": 160, "trainId": 156},
    {"name": "umbrella", "id": 2901, "trainId": 157},
    {"name": "sculpture", "id": 2262, "trainId": 158},
    {"name": "aqueduct", "id": 44, "trainId": 159},
    {"name": "container", "id": 597, "trainId": 160},
    {"name": "scaffolding, staging", "id": 2235, "trainId": 161},
    {"name": "hood, exhaust hood", "id": 1260, "trainId": 162},
    {"name": "curb, curbing, kerb", "id": 682, "trainId": 163},
    {"name": "roller coaster", "id": 2151, "trainId": 164},
    {"name": "horse, equus caballus", "id": 3107, "trainId": 165},
    {"name": "catwalk", "id": 432, "trainId": 166},
    {"name": "glass, drinking glass", "id": 1098, "trainId": 167},
    {"name": "vase", "id": 2932, "trainId": 168},
    {"name": "central reservation", "id": 461, "trainId": 169},
    {"name": "carousel", "id": 410, "trainId": 170},
    {"name": "radiator", "id": 2046, "trainId": 171},
    {"name": "closet", "id": 533, "trainId": 172},
    {"name": "machine", "id": 1481, "trainId": 173},
    {"name": "pier, wharf, wharfage, dock", "id": 1858, "trainId": 174},
    {"name": "fan", "id": 894, "trainId": 175},
    {"name": "inflatable bounce game", "id": 1322, "trainId": 176},
    {"name": "pitch", "id": 1891, "trainId": 177},
    {"name": "paper", "id": 1756, "trainId": 178},
    {"name": "arcade, colonnade", "id": 49, "trainId": 179},
    {"name": "hot tub", "id": 1272, "trainId": 180},
    {"name": "helicopter", "id": 1229, "trainId": 181},
    {"name": "tray", "id": 2850, "trainId": 182},
    {"name": "partition, divider", "id": 1784, "trainId": 183},
    {"name": "vineyard", "id": 2962, "trainId": 184},
    {"name": "bowl", "id": 259, "trainId": 185},
    {"name": "bullring", "id": 319, "trainId": 186},
    {"name": "flag", "id": 954, "trainId": 187},
    {"name": "pot", "id": 1974, "trainId": 188},
    {"name": "footbridge, overcrossing, pedestrian bridge", "id": 1013, "trainId": 189},
    {"name": "shower", "id": 2356, "trainId": 190},
    {"name": "bag, traveling bag, travelling bag, grip, suitcase", "id": 97, "trainId": 191},
    {"name": "bulletin board, notice board", "id": 318, "trainId": 192},
    {"name": "confessional booth", "id": 592, "trainId": 193},
    {"name": "trunk, tree trunk, bole", "id": 2885, "trainId": 194},
    {"name": "forest", "id": 1017, "trainId": 195},
    {"name": "elevator door", "id": 851, "trainId": 196},
    {"name": "laptop, laptop computer", "id": 1407, "trainId": 197},
    {"name": "instrument panel", "id": 1332, "trainId": 198},
    {"name": "bucket, pail", "id": 303, "trainId": 199},
    {"name": "tapestry, tapis", "id": 2714, "trainId": 200},
    {"name": "platform", "id": 1924, "trainId": 201},
    {"name": "jacket", "id": 1346, "trainId": 202},
    {"name": "gate", "id": 1081, "trainId": 203},
    {"name": "monitor, monitoring device", "id": 1583, "trainId": 204},
    {
        "name": "telephone booth, phone booth, call box, telephone box, telephone kiosk",
        "id": 2727,
        "trainId": 205,
    },
    {"name": "spotlight, spot", "id": 2509, "trainId": 206},
    {"name": "ring", "id": 2123, "trainId": 207},
    {"name": "control panel", "id": 602, "trainId": 208},
    {"name": "blackboard, chalkboard", "id": 202, "trainId": 209},
    {"name": "air conditioner, air conditioning", "id": 10, "trainId": 210},
    {"name": "chest", "id": 490, "trainId": 211},
    {"name": "clock", "id": 530, "trainId": 212},
    {"name": "sand dune", "id": 2213, "trainId": 213},
    {"name": "pipe, pipage, piping", "id": 1884, "trainId": 214},
    {"name": "vault", "id": 2934, "trainId": 215},
    {"name": "table football", "id": 2687, "trainId": 216},
    {"name": "cannon", "id": 387, "trainId": 217},
    {"name": "swimming pool, swimming bath, natatorium", "id": 2668, "trainId": 218},
    {"name": "fluorescent, fluorescent fixture", "id": 982, "trainId": 219},
    {"name": "statue", "id": 2547, "trainId": 220},
    {
        "name": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
        "id": 1474,
        "trainId": 221,
    },
    {"name": "exhibitor", "id": 877, "trainId": 222},
    {"name": "ladder", "id": 1391, "trainId": 223},
    {"name": "carport", "id": 414, "trainId": 224},
    {"name": "dam", "id": 698, "trainId": 225},
    {"name": "pulpit", "id": 2019, "trainId": 226},
    {"name": "skylight, fanlight", "id": 2422, "trainId": 227},
    {"name": "water tower", "id": 3010, "trainId": 228},
    {"name": "grill, grille, grillwork", "id": 1139, "trainId": 229},
    {"name": "display board", "id": 753, "trainId": 230},
    {"name": "pane, pane of glass, window glass", "id": 1747, "trainId": 231},
    {"name": "rubbish, trash, scrap", "id": 2175, "trainId": 232},
    {"name": "ice rink", "id": 1301, "trainId": 233},
    {"name": "fruit", "id": 1033, "trainId": 234},
    {"name": "patio", "id": 1789, "trainId": 235},
    {"name": "vending machine", "id": 2939, "trainId": 236},
    {"name": "telephone, phone, telephone set", "id": 2730, "trainId": 237},
    {"name": "net", "id": 1652, "trainId": 238},
    {
        "name": "backpack, back pack, knapsack, packsack, rucksack, haversack",
        "id": 90,
        "trainId": 239,
    },
    {"name": "jar", "id": 1349, "trainId": 240},
    {"name": "track", "id": 2830, "trainId": 241},
    {"name": "magazine", "id": 1485, "trainId": 242},
    {"name": "shutter", "id": 2370, "trainId": 243},
    {"name": "roof", "id": 2155, "trainId": 244},
    {"name": "banner, streamer", "id": 118, "trainId": 245},
    {"name": "landfill", "id": 1402, "trainId": 246},
    {"name": "post", "id": 1957, "trainId": 247},
    {"name": "altarpiece, reredos", "id": 3130, "trainId": 248},
    {"name": "hat, chapeau, lid", "id": 1197, "trainId": 249},
    {"name": "arch, archway", "id": 52, "trainId": 250},
    {"name": "table game", "id": 2688, "trainId": 251},
    {"name": "bag, handbag, pocketbook, purse", "id": 96, "trainId": 252},
    {"name": "document, written document, papers", "id": 762, "trainId": 253},
    {"name": "dome", "id": 772, "trainId": 254},
    {"name": "pier", "id": 1857, "trainId": 255},
    {"name": "shanties", "id": 2315, "trainId": 256},
    {"name": "forecourt", "id": 1016, "trainId": 257},
    {"name": "crane", "id": 643, "trainId": 258},
    {"name": "dog, domestic dog, canis familiaris", "id": 3105, "trainId": 259},
    {"name": "piano, pianoforte, forte-piano", "id": 1849, "trainId": 260},
    {"name": "drawing", "id": 791, "trainId": 261},
    {"name": "cabin", "id": 349, "trainId": 262},
    {
        "name": "ad, advertisement, advertizement, advertising, advertizing, advert",
        "id": 6,
        "trainId": 263,
    },
    {"name": "amphitheater, amphitheatre, coliseum", "id": 3114, "trainId": 264},
    {"name": "monument", "id": 1587, "trainId": 265},
    {"name": "henhouse", "id": 1233, "trainId": 266},
    {"name": "cockpit", "id": 559, "trainId": 267},
    {"name": "heater, warmer", "id": 1223, "trainId": 268},
    {"name": "windmill, aerogenerator, wind generator", "id": 3049, "trainId": 269},
    {"name": "pool", "id": 1943, "trainId": 270},
    {"name": "elevator, lift", "id": 853, "trainId": 271},
    {"name": "decoration, ornament, ornamentation", "id": 709, "trainId": 272},
    {"name": "labyrinth", "id": 1390, "trainId": 273},
    {"name": "text, textual matter", "id": 2748, "trainId": 274},
    {"name": "printer", "id": 2007, "trainId": 275},
    {"name": "mezzanine, first balcony", "id": 1546, "trainId": 276},
    {"name": "mattress", "id": 1513, "trainId": 277},
    {"name": "straw", "id": 2600, "trainId": 278},
    {"name": "stalls", "id": 2538, "trainId": 279},
    {"name": "patio, terrace", "id": 1790, "trainId": 280},
    {"name": "billboard, hoarding", "id": 194, "trainId": 281},
    {"name": "bus stop", "id": 326, "trainId": 282},
    {"name": "trouser, pant", "id": 2877, "trainId": 283},
    {"name": "console table, console", "id": 594, "trainId": 284},
    {"name": "rack", "id": 2036, "trainId": 285},
    {"name": "notebook", "id": 1662, "trainId": 286},
    {"name": "shrine", "id": 2366, "trainId": 287},
    {"name": "pantry", "id": 1754, "trainId": 288},
    {"name": "cart", "id": 418, "trainId": 289},
    {"name": "steam shovel", "id": 2553, "trainId": 290},
    {"name": "porch", "id": 1951, "trainId": 291},
    {"name": "postbox, mailbox, letter box", "id": 1963, "trainId": 292},
    {"name": "figurine, statuette", "id": 918, "trainId": 293},
    {"name": "recycling bin", "id": 2086, "trainId": 294},
    {"name": "folding screen", "id": 997, "trainId": 295},
    {"name": "telescope", "id": 2731, "trainId": 296},
    {"name": "deck chair, beach chair", "id": 704, "trainId": 297},
    {"name": "kennel", "id": 1365, "trainId": 298},
    {"name": "coffee maker", "id": 569, "trainId": 299},
    {"name": "altar, communion table, lord's table", "id": 3108, "trainId": 300},
    {"name": "fish", "id": 948, "trainId": 301},
    {"name": "easel", "id": 839, "trainId": 302},
    {"name": "artificial golf green", "id": 63, "trainId": 303},
    {"name": "iceberg", "id": 1305, "trainId": 304},
    {"name": "candlestick, candle holder", "id": 378, "trainId": 305},
    {"name": "shower stall, shower bath", "id": 2362, "trainId": 306},
    {"name": "television stand", "id": 2734, "trainId": 307},
    {
        "name": "wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle",
        "id": 2982,
        "trainId": 308,
    },
    {"name": "skeleton", "id": 2398, "trainId": 309},
    {"name": "grand piano, grand", "id": 1119, "trainId": 310},
    {"name": "candy, confect", "id": 382, "trainId": 311},
    {"name": "grille door", "id": 1141, "trainId": 312},
    {"name": "pedestal, plinth, footstall", "id": 1805, "trainId": 313},
    {"name": "jersey, t-shirt, tee shirt", "id": 3102, "trainId": 314},
    {"name": "shoe", "id": 2341, "trainId": 315},
    {"name": "gravestone, headstone, tombstone", "id": 1131, "trainId": 316},
    {"name": "shanty", "id": 2316, "trainId": 317},
    {"name": "structure", "id": 2626, "trainId": 318},
    {"name": "rocking chair, rocker", "id": 3104, "trainId": 319},
    {"name": "bird", "id": 198, "trainId": 320},
    {"name": "place mat", "id": 1896, "trainId": 321},
    {"name": "tomb", "id": 2800, "trainId": 322},
    {"name": "big top", "id": 190, "trainId": 323},
    {"name": "gas pump, gasoline pump, petrol pump, island dispenser", "id": 3131, "trainId": 324},
    {"name": "lockers", "id": 1463, "trainId": 325},
    {"name": "cage", "id": 357, "trainId": 326},
    {"name": "finger", "id": 929, "trainId": 327},
    {"name": "bleachers", "id": 209, "trainId": 328},
    {"name": "ferris wheel", "id": 912, "trainId": 329},
    {"name": "hairdresser chair", "id": 1164, "trainId": 330},
    {"name": "mat", "id": 1509, "trainId": 331},
    {"name": "stands", "id": 2539, "trainId": 332},
    {"name": "aquarium, fish tank, marine museum", "id": 3116, "trainId": 333},
    {"name": "streetcar, tram, tramcar, trolley, trolley car", "id": 2615, "trainId": 334},
    {"name": "napkin, table napkin, serviette", "id": 1644, "trainId": 335},
    {"name": "dummy", "id": 818, "trainId": 336},
    {"name": "booklet, brochure, folder, leaflet, pamphlet", "id": 242, "trainId": 337},
    {"name": "sand trap", "id": 2217, "trainId": 338},
    {"name": "shop, store", "id": 2347, "trainId": 339},
    {"name": "table cloth", "id": 2686, "trainId": 340},
    {"name": "service station", "id": 2300, "trainId": 341},
    {"name": "coffin", "id": 572, "trainId": 342},
    {"name": "drawer", "id": 789, "trainId": 343},
    {"name": "cages", "id": 358, "trainId": 344},
    {"name": "slot machine, coin machine", "id": 2443, "trainId": 345},
    {"name": "balcony", "id": 101, "trainId": 346},
    {"name": "volleyball court", "id": 2969, "trainId": 347},
    {"name": "table tennis", "id": 2692, "trainId": 348},
    {"name": "control table", "id": 606, "trainId": 349},
    {"name": "shirt", "id": 2339, "trainId": 350},
    {"name": "merchandise, ware, product", "id": 1533, "trainId": 351},
    {"name": "railway", "id": 2060, "trainId": 352},
    {"name": "parterre", "id": 1782, "trainId": 353},
    {"name": "chimney", "id": 495, "trainId": 354},
    {"name": "can, tin, tin can", "id": 371, "trainId": 355},
    {"name": "tanks", "id": 2707, "trainId": 356},
    {"name": "fabric, cloth, material, textile", "id": 889, "trainId": 357},
    {"name": "alga, algae", "id": 3156, "trainId": 358},
    {"name": "system", "id": 2683, "trainId": 359},
    {"name": "map", "id": 1499, "trainId": 360},
    {"name": "greenhouse", "id": 1135, "trainId": 361},
    {"name": "mug", "id": 1619, "trainId": 362},
    {"name": "barbecue", "id": 125, "trainId": 363},
    {"name": "trailer", "id": 2838, "trainId": 364},
    {"name": "toilet tissue, toilet paper, bathroom tissue", "id": 2792, "trainId": 365},
    {"name": "organ", "id": 1695, "trainId": 366},
    {"name": "dishrag, dishcloth", "id": 746, "trainId": 367},
    {"name": "island", "id": 1343, "trainId": 368},
    {"name": "keyboard", "id": 1370, "trainId": 369},
    {"name": "trench", "id": 2858, "trainId": 370},
    {"name": "basket, basketball hoop, hoop", "id": 145, "trainId": 371},
    {"name": "steering wheel, wheel", "id": 2565, "trainId": 372},
    {"name": "pitcher, ewer", "id": 1892, "trainId": 373},
    {"name": "goal", "id": 1103, "trainId": 374},
    {"name": "bread, breadstuff, staff of life", "id": 286, "trainId": 375},
    {"name": "beds", "id": 170, "trainId": 376},
    {"name": "wood", "id": 3073, "trainId": 377},
    {"name": "file cabinet", "id": 922, "trainId": 378},
    {"name": "newspaper, paper", "id": 1655, "trainId": 379},
    {"name": "motorboat", "id": 1602, "trainId": 380},
    {"name": "rope", "id": 2160, "trainId": 381},
    {"name": "guitar", "id": 1151, "trainId": 382},
    {"name": "rubble", "id": 2176, "trainId": 383},
    {"name": "scarf", "id": 2239, "trainId": 384},
    {"name": "barrels", "id": 132, "trainId": 385},
    {"name": "cap", "id": 394, "trainId": 386},
    {"name": "leaves", "id": 1424, "trainId": 387},
    {"name": "control tower", "id": 607, "trainId": 388},
    {"name": "dashboard", "id": 700, "trainId": 389},
    {"name": "bandstand", "id": 116, "trainId": 390},
    {"name": "lectern", "id": 1425, "trainId": 391},
    {"name": "switch, electric switch, electrical switch", "id": 2676, "trainId": 392},
    {"name": "baseboard, mopboard, skirting board", "id": 141, "trainId": 393},
    {"name": "shower room", "id": 2360, "trainId": 394},
    {"name": "smoke", "id": 2449, "trainId": 395},
    {"name": "faucet, spigot", "id": 897, "trainId": 396},
    {"name": "bulldozer", "id": 317, "trainId": 397},
    {"name": "saucepan", "id": 2228, "trainId": 398},
    {"name": "shops", "id": 2351, "trainId": 399},
    {"name": "meter", "id": 1543, "trainId": 400},
    {"name": "crevasse", "id": 656, "trainId": 401},
    {"name": "gear", "id": 1088, "trainId": 402},
    {"name": "candelabrum, candelabra", "id": 373, "trainId": 403},
    {"name": "sofa bed", "id": 2472, "trainId": 404},
    {"name": "tunnel", "id": 2892, "trainId": 405},
    {"name": "pallet", "id": 1740, "trainId": 406},
    {"name": "wire, conducting wire", "id": 3067, "trainId": 407},
    {"name": "kettle, boiler", "id": 1367, "trainId": 408},
    {"name": "bidet", "id": 188, "trainId": 409},
    {
        "name": "baby buggy, baby carriage, carriage, perambulator, pram, stroller, go-cart, pushchair, pusher",
        "id": 79,
        "trainId": 410,
    },
    {"name": "music stand", "id": 1633, "trainId": 411},
    {"name": "pipe, tube", "id": 1885, "trainId": 412},
    {"name": "cup", "id": 677, "trainId": 413},
    {"name": "parking meter", "id": 1779, "trainId": 414},
    {"name": "ice hockey rink", "id": 1297, "trainId": 415},
    {"name": "shelter", "id": 2334, "trainId": 416},
    {"name": "weeds", "id": 3027, "trainId": 417},
    {"name": "temple", "id": 2735, "trainId": 418},
    {"name": "patty, cake", "id": 1791, "trainId": 419},
    {"name": "ski slope", "id": 2405, "trainId": 420},
    {"name": "panel", "id": 1748, "trainId": 421},
    {"name": "wallet", "id": 2983, "trainId": 422},
    {"name": "wheel", "id": 3035, "trainId": 423},
    {"name": "towel rack, towel horse", "id": 2824, "trainId": 424},
    {"name": "roundabout", "id": 2168, "trainId": 425},
    {"name": "canister, cannister, tin", "id": 385, "trainId": 426},
    {"name": "rod", "id": 2148, "trainId": 427},
    {"name": "soap dispenser", "id": 2465, "trainId": 428},
    {"name": "bell", "id": 175, "trainId": 429},
    {"name": "canvas", "id": 390, "trainId": 430},
    {"name": "box office, ticket office, ticket booth", "id": 268, "trainId": 431},
    {"name": "teacup", "id": 2722, "trainId": 432},
    {"name": "trellis", "id": 2857, "trainId": 433},
    {"name": "workbench", "id": 3088, "trainId": 434},
    {"name": "valley, vale", "id": 2926, "trainId": 435},
    {"name": "toaster", "id": 2782, "trainId": 436},
    {"name": "knife", "id": 1378, "trainId": 437},
    {"name": "podium", "id": 1934, "trainId": 438},
    {"name": "ramp", "id": 2072, "trainId": 439},
    {"name": "tumble dryer", "id": 2889, "trainId": 440},
    {"name": "fireplug, fire hydrant, plug", "id": 944, "trainId": 441},
    {"name": "gym shoe, sneaker, tennis shoe", "id": 1158, "trainId": 442},
    {"name": "lab bench", "id": 1383, "trainId": 443},
    {"name": "equipment", "id": 867, "trainId": 444},
    {"name": "rocky formation", "id": 2145, "trainId": 445},
    {"name": "plastic", "id": 1915, "trainId": 446},
    {"name": "calendar", "id": 361, "trainId": 447},
    {"name": "caravan", "id": 402, "trainId": 448},
    {"name": "check-in-desk", "id": 482, "trainId": 449},
    {"name": "ticket counter", "id": 2761, "trainId": 450},
    {"name": "brush", "id": 300, "trainId": 451},
    {"name": "mill", "id": 1554, "trainId": 452},
    {"name": "covered bridge", "id": 636, "trainId": 453},
    {"name": "bowling alley", "id": 260, "trainId": 454},
    {"name": "hanger", "id": 1186, "trainId": 455},
    {"name": "excavator", "id": 871, "trainId": 456},
    {"name": "trestle", "id": 2859, "trainId": 457},
    {"name": "revolving door", "id": 2103, "trainId": 458},
    {"name": "blast furnace", "id": 208, "trainId": 459},
    {"name": "scale, weighing machine", "id": 2236, "trainId": 460},
    {"name": "projector", "id": 2012, "trainId": 461},
    {"name": "soap", "id": 2462, "trainId": 462},
    {"name": "locker", "id": 1462, "trainId": 463},
    {"name": "tractor", "id": 2832, "trainId": 464},
    {"name": "stretcher", "id": 2617, "trainId": 465},
    {"name": "frame", "id": 1024, "trainId": 466},
    {"name": "grating", "id": 1129, "trainId": 467},
    {"name": "alembic", "id": 18, "trainId": 468},
    {"name": "candle, taper, wax light", "id": 376, "trainId": 469},
    {"name": "barrier", "id": 134, "trainId": 470},
    {"name": "cardboard", "id": 407, "trainId": 471},
    {"name": "cave", "id": 434, "trainId": 472},
    {"name": "puddle", "id": 2017, "trainId": 473},
    {"name": "tarp", "id": 2717, "trainId": 474},
    {"name": "price tag", "id": 2005, "trainId": 475},
    {"name": "watchtower", "id": 2993, "trainId": 476},
    {"name": "meters", "id": 1545, "trainId": 477},
    {
        "name": "light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb",
        "id": 1445,
        "trainId": 478,
    },
    {"name": "tracks", "id": 2831, "trainId": 479},
    {"name": "hair dryer", "id": 1161, "trainId": 480},
    {"name": "skirt", "id": 2411, "trainId": 481},
    {"name": "viaduct", "id": 2949, "trainId": 482},
    {"name": "paper towel", "id": 1769, "trainId": 483},
    {"name": "coat", "id": 552, "trainId": 484},
    {"name": "sheet", "id": 2327, "trainId": 485},
    {"name": "fire extinguisher, extinguisher, asphyxiator", "id": 939, "trainId": 486},
    {"name": "water wheel", "id": 3013, "trainId": 487},
    {"name": "pottery, clayware", "id": 1986, "trainId": 488},
    {"name": "magazine rack", "id": 1486, "trainId": 489},
    {"name": "teapot", "id": 2723, "trainId": 490},
    {"name": "microphone, mike", "id": 1549, "trainId": 491},
    {"name": "support", "id": 2649, "trainId": 492},
    {"name": "forklift", "id": 1020, "trainId": 493},
    {"name": "canyon", "id": 392, "trainId": 494},
    {"name": "cash register, register", "id": 422, "trainId": 495},
    {"name": "leaf, leafage, foliage", "id": 1419, "trainId": 496},
    {"name": "remote control, remote", "id": 2099, "trainId": 497},
    {"name": "soap dish", "id": 2464, "trainId": 498},
    {"name": "windshield, windscreen", "id": 3058, "trainId": 499},
    {"name": "cat", "id": 430, "trainId": 500},
    {"name": "cue, cue stick, pool cue, pool stick", "id": 675, "trainId": 501},
    {"name": "vent, venthole, vent-hole, blowhole", "id": 2941, "trainId": 502},
    {"name": "videos", "id": 2955, "trainId": 503},
    {"name": "shovel", "id": 2355, "trainId": 504},
    {"name": "eaves", "id": 840, "trainId": 505},
    {"name": "antenna, aerial, transmitting aerial", "id": 32, "trainId": 506},
    {"name": "shipyard", "id": 2338, "trainId": 507},
    {"name": "hen, biddy", "id": 1232, "trainId": 508},
    {"name": "traffic cone", "id": 2834, "trainId": 509},
    {"name": "washing machines", "id": 2991, "trainId": 510},
    {"name": "truck crane", "id": 2879, "trainId": 511},
    {"name": "cds", "id": 444, "trainId": 512},
    {"name": "niche", "id": 1657, "trainId": 513},
    {"name": "scoreboard", "id": 2246, "trainId": 514},
    {"name": "briefcase", "id": 296, "trainId": 515},
    {"name": "boot", "id": 245, "trainId": 516},
    {"name": "sweater, jumper", "id": 2661, "trainId": 517},
    {"name": "hay", "id": 1202, "trainId": 518},
    {"name": "pack", "id": 1714, "trainId": 519},
    {"name": "bottle rack", "id": 251, "trainId": 520},
    {"name": "glacier", "id": 1095, "trainId": 521},
    {"name": "pergola", "id": 1828, "trainId": 522},
    {"name": "building materials", "id": 311, "trainId": 523},
    {"name": "television camera", "id": 2732, "trainId": 524},
    {"name": "first floor", "id": 947, "trainId": 525},
    {"name": "rifle", "id": 2115, "trainId": 526},
    {"name": "tennis table", "id": 2738, "trainId": 527},
    {"name": "stadium", "id": 2525, "trainId": 528},
    {"name": "safety belt", "id": 2194, "trainId": 529},
    {"name": "cover", "id": 634, "trainId": 530},
    {"name": "dish rack", "id": 740, "trainId": 531},
    {"name": "synthesizer", "id": 2682, "trainId": 532},
    {"name": "pumpkin", "id": 2020, "trainId": 533},
    {"name": "gutter", "id": 1156, "trainId": 534},
    {"name": "fruit stand", "id": 1036, "trainId": 535},
    {"name": "ice floe, floe", "id": 1295, "trainId": 536},
    {"name": "handle, grip, handgrip, hold", "id": 1181, "trainId": 537},
    {"name": "wheelchair", "id": 3037, "trainId": 538},
    {"name": "mousepad, mouse mat", "id": 1614, "trainId": 539},
|
604 |
+
{"name": "diploma", "id": 736, "trainId": 540},
|
605 |
+
{"name": "fairground ride", "id": 893, "trainId": 541},
|
606 |
+
{"name": "radio", "id": 2047, "trainId": 542},
|
607 |
+
{"name": "hotplate", "id": 1274, "trainId": 543},
|
608 |
+
{"name": "junk", "id": 1361, "trainId": 544},
|
609 |
+
{"name": "wheelbarrow", "id": 3036, "trainId": 545},
|
610 |
+
{"name": "stream", "id": 2606, "trainId": 546},
|
611 |
+
{"name": "toll plaza", "id": 2797, "trainId": 547},
|
612 |
+
{"name": "punching bag", "id": 2022, "trainId": 548},
|
613 |
+
{"name": "trough", "id": 2876, "trainId": 549},
|
614 |
+
{"name": "throne", "id": 2758, "trainId": 550},
|
615 |
+
{"name": "chair desk", "id": 472, "trainId": 551},
|
616 |
+
{"name": "weighbridge", "id": 3028, "trainId": 552},
|
617 |
+
{"name": "extractor fan", "id": 882, "trainId": 553},
|
618 |
+
{"name": "hanging clothes", "id": 1189, "trainId": 554},
|
619 |
+
{"name": "dish, dish aerial, dish antenna, saucer", "id": 743, "trainId": 555},
|
620 |
+
{"name": "alarm clock, alarm", "id": 3122, "trainId": 556},
|
621 |
+
{"name": "ski lift", "id": 2401, "trainId": 557},
|
622 |
+
{"name": "chain", "id": 468, "trainId": 558},
|
623 |
+
{"name": "garage", "id": 1061, "trainId": 559},
|
624 |
+
{"name": "mechanical shovel", "id": 1523, "trainId": 560},
|
625 |
+
{"name": "wine rack", "id": 3059, "trainId": 561},
|
626 |
+
{"name": "tramway", "id": 2843, "trainId": 562},
|
627 |
+
{"name": "treadmill", "id": 2853, "trainId": 563},
|
628 |
+
{"name": "menu", "id": 1529, "trainId": 564},
|
629 |
+
{"name": "block", "id": 214, "trainId": 565},
|
630 |
+
{"name": "well", "id": 3032, "trainId": 566},
|
631 |
+
{"name": "witness stand", "id": 3071, "trainId": 567},
|
632 |
+
{"name": "branch", "id": 277, "trainId": 568},
|
633 |
+
{"name": "duck", "id": 813, "trainId": 569},
|
634 |
+
{"name": "casserole", "id": 426, "trainId": 570},
|
635 |
+
{"name": "frying pan", "id": 1039, "trainId": 571},
|
636 |
+
{"name": "desk organizer", "id": 727, "trainId": 572},
|
637 |
+
{"name": "mast", "id": 1508, "trainId": 573},
|
638 |
+
{"name": "spectacles, specs, eyeglasses, glasses", "id": 2490, "trainId": 574},
|
639 |
+
{"name": "service elevator", "id": 2299, "trainId": 575},
|
640 |
+
{"name": "dollhouse", "id": 768, "trainId": 576},
|
641 |
+
{"name": "hammock", "id": 1172, "trainId": 577},
|
642 |
+
{"name": "clothes hanging", "id": 537, "trainId": 578},
|
643 |
+
{"name": "photocopier", "id": 1847, "trainId": 579},
|
644 |
+
{"name": "notepad", "id": 1664, "trainId": 580},
|
645 |
+
{"name": "golf cart", "id": 1110, "trainId": 581},
|
646 |
+
{"name": "footpath", "id": 1014, "trainId": 582},
|
647 |
+
{"name": "cross", "id": 662, "trainId": 583},
|
648 |
+
{"name": "baptismal font", "id": 121, "trainId": 584},
|
649 |
+
{"name": "boiler", "id": 227, "trainId": 585},
|
650 |
+
{"name": "skip", "id": 2410, "trainId": 586},
|
651 |
+
{"name": "rotisserie", "id": 2165, "trainId": 587},
|
652 |
+
{"name": "tables", "id": 2696, "trainId": 588},
|
653 |
+
{"name": "water mill", "id": 3005, "trainId": 589},
|
654 |
+
{"name": "helmet", "id": 1231, "trainId": 590},
|
655 |
+
{"name": "cover curtain", "id": 635, "trainId": 591},
|
656 |
+
{"name": "brick", "id": 292, "trainId": 592},
|
657 |
+
{"name": "table runner", "id": 2690, "trainId": 593},
|
658 |
+
{"name": "ashtray", "id": 65, "trainId": 594},
|
659 |
+
{"name": "street box", "id": 2607, "trainId": 595},
|
660 |
+
{"name": "stick", "id": 2574, "trainId": 596},
|
661 |
+
{"name": "hangers", "id": 1188, "trainId": 597},
|
662 |
+
{"name": "cells", "id": 456, "trainId": 598},
|
663 |
+
{"name": "urinal", "id": 2913, "trainId": 599},
|
664 |
+
{"name": "centerpiece", "id": 459, "trainId": 600},
|
665 |
+
{"name": "portable fridge", "id": 1955, "trainId": 601},
|
666 |
+
{"name": "dvds", "id": 827, "trainId": 602},
|
667 |
+
{"name": "golf club", "id": 1111, "trainId": 603},
|
668 |
+
{"name": "skirting board", "id": 2412, "trainId": 604},
|
669 |
+
{"name": "water cooler", "id": 2997, "trainId": 605},
|
670 |
+
{"name": "clipboard", "id": 528, "trainId": 606},
|
671 |
+
{"name": "camera, photographic camera", "id": 366, "trainId": 607},
|
672 |
+
{"name": "pigeonhole", "id": 1863, "trainId": 608},
|
673 |
+
{"name": "chips", "id": 500, "trainId": 609},
|
674 |
+
{"name": "food processor", "id": 1001, "trainId": 610},
|
675 |
+
{"name": "post box", "id": 1958, "trainId": 611},
|
676 |
+
{"name": "lid", "id": 1441, "trainId": 612},
|
677 |
+
{"name": "drum", "id": 809, "trainId": 613},
|
678 |
+
{"name": "blender", "id": 210, "trainId": 614},
|
679 |
+
{"name": "cave entrance", "id": 435, "trainId": 615},
|
680 |
+
{"name": "dental chair", "id": 718, "trainId": 616},
|
681 |
+
{"name": "obelisk", "id": 1674, "trainId": 617},
|
682 |
+
{"name": "canoe", "id": 388, "trainId": 618},
|
683 |
+
{"name": "mobile", "id": 1572, "trainId": 619},
|
684 |
+
{"name": "monitors", "id": 1584, "trainId": 620},
|
685 |
+
{"name": "pool ball", "id": 1944, "trainId": 621},
|
686 |
+
{"name": "cue rack", "id": 674, "trainId": 622},
|
687 |
+
{"name": "baggage carts", "id": 99, "trainId": 623},
|
688 |
+
{"name": "shore", "id": 2352, "trainId": 624},
|
689 |
+
{"name": "fork", "id": 1019, "trainId": 625},
|
690 |
+
{"name": "paper filer", "id": 1763, "trainId": 626},
|
691 |
+
{"name": "bicycle rack", "id": 185, "trainId": 627},
|
692 |
+
{"name": "coat rack", "id": 554, "trainId": 628},
|
693 |
+
{"name": "garland", "id": 1066, "trainId": 629},
|
694 |
+
{"name": "sports bag", "id": 2508, "trainId": 630},
|
695 |
+
{"name": "fish tank", "id": 951, "trainId": 631},
|
696 |
+
{"name": "towel dispenser", "id": 2822, "trainId": 632},
|
697 |
+
{"name": "carriage", "id": 415, "trainId": 633},
|
698 |
+
{"name": "brochure", "id": 297, "trainId": 634},
|
699 |
+
{"name": "plaque", "id": 1914, "trainId": 635},
|
700 |
+
{"name": "stringer", "id": 2619, "trainId": 636},
|
701 |
+
{"name": "iron", "id": 1338, "trainId": 637},
|
702 |
+
{"name": "spoon", "id": 2505, "trainId": 638},
|
703 |
+
{"name": "flag pole", "id": 955, "trainId": 639},
|
704 |
+
{"name": "toilet brush", "id": 2786, "trainId": 640},
|
705 |
+
{"name": "book stand", "id": 238, "trainId": 641},
|
706 |
+
{"name": "water faucet, water tap, tap, hydrant", "id": 3000, "trainId": 642},
|
707 |
+
{"name": "ticket office", "id": 2763, "trainId": 643},
|
708 |
+
{"name": "broom", "id": 299, "trainId": 644},
|
709 |
+
{"name": "dvd", "id": 822, "trainId": 645},
|
710 |
+
{"name": "ice bucket", "id": 1288, "trainId": 646},
|
711 |
+
{"name": "carapace, shell, cuticle, shield", "id": 3101, "trainId": 647},
|
712 |
+
{"name": "tureen", "id": 2894, "trainId": 648},
|
713 |
+
{"name": "folders", "id": 992, "trainId": 649},
|
714 |
+
{"name": "chess", "id": 489, "trainId": 650},
|
715 |
+
{"name": "root", "id": 2157, "trainId": 651},
|
716 |
+
{"name": "sewing machine", "id": 2309, "trainId": 652},
|
717 |
+
{"name": "model", "id": 1576, "trainId": 653},
|
718 |
+
{"name": "pen", "id": 1810, "trainId": 654},
|
719 |
+
{"name": "violin", "id": 2964, "trainId": 655},
|
720 |
+
{"name": "sweatshirt", "id": 2662, "trainId": 656},
|
721 |
+
{"name": "recycling materials", "id": 2087, "trainId": 657},
|
722 |
+
{"name": "mitten", "id": 1569, "trainId": 658},
|
723 |
+
{"name": "chopping board, cutting board", "id": 503, "trainId": 659},
|
724 |
+
{"name": "mask", "id": 1505, "trainId": 660},
|
725 |
+
{"name": "log", "id": 1468, "trainId": 661},
|
726 |
+
{"name": "mouse, computer mouse", "id": 1613, "trainId": 662},
|
727 |
+
{"name": "grill", "id": 1138, "trainId": 663},
|
728 |
+
{"name": "hole", "id": 1256, "trainId": 664},
|
729 |
+
{"name": "target", "id": 2715, "trainId": 665},
|
730 |
+
{"name": "trash bag", "id": 2846, "trainId": 666},
|
731 |
+
{"name": "chalk", "id": 477, "trainId": 667},
|
732 |
+
{"name": "sticks", "id": 2576, "trainId": 668},
|
733 |
+
{"name": "balloon", "id": 108, "trainId": 669},
|
734 |
+
{"name": "score", "id": 2245, "trainId": 670},
|
735 |
+
{"name": "hair spray", "id": 1162, "trainId": 671},
|
736 |
+
{"name": "roll", "id": 2149, "trainId": 672},
|
737 |
+
{"name": "runner", "id": 2183, "trainId": 673},
|
738 |
+
{"name": "engine", "id": 858, "trainId": 674},
|
739 |
+
{"name": "inflatable glove", "id": 1324, "trainId": 675},
|
740 |
+
{"name": "games", "id": 1055, "trainId": 676},
|
741 |
+
{"name": "pallets", "id": 1741, "trainId": 677},
|
742 |
+
{"name": "baskets", "id": 149, "trainId": 678},
|
743 |
+
{"name": "coop", "id": 615, "trainId": 679},
|
744 |
+
{"name": "dvd player", "id": 825, "trainId": 680},
|
745 |
+
{"name": "rocking horse", "id": 2143, "trainId": 681},
|
746 |
+
{"name": "buckets", "id": 304, "trainId": 682},
|
747 |
+
{"name": "bread rolls", "id": 283, "trainId": 683},
|
748 |
+
{"name": "shawl", "id": 2322, "trainId": 684},
|
749 |
+
{"name": "watering can", "id": 3017, "trainId": 685},
|
750 |
+
{"name": "spotlights", "id": 2510, "trainId": 686},
|
751 |
+
{"name": "post-it", "id": 1960, "trainId": 687},
|
752 |
+
{"name": "bowls", "id": 265, "trainId": 688},
|
753 |
+
{"name": "security camera", "id": 2282, "trainId": 689},
|
754 |
+
{"name": "runner cloth", "id": 2184, "trainId": 690},
|
755 |
+
{"name": "lock", "id": 1461, "trainId": 691},
|
756 |
+
{"name": "alarm, warning device, alarm system", "id": 3113, "trainId": 692},
|
757 |
+
{"name": "side", "id": 2372, "trainId": 693},
|
758 |
+
{"name": "roulette", "id": 2166, "trainId": 694},
|
759 |
+
{"name": "bone", "id": 232, "trainId": 695},
|
760 |
+
{"name": "cutlery", "id": 693, "trainId": 696},
|
761 |
+
{"name": "pool balls", "id": 1945, "trainId": 697},
|
762 |
+
{"name": "wheels", "id": 3039, "trainId": 698},
|
763 |
+
{"name": "spice rack", "id": 2494, "trainId": 699},
|
764 |
+
{"name": "plant pots", "id": 1908, "trainId": 700},
|
765 |
+
{"name": "towel ring", "id": 2827, "trainId": 701},
|
766 |
+
{"name": "bread box", "id": 280, "trainId": 702},
|
767 |
+
{"name": "video", "id": 2950, "trainId": 703},
|
768 |
+
{"name": "funfair", "id": 1044, "trainId": 704},
|
769 |
+
{"name": "breads", "id": 288, "trainId": 705},
|
770 |
+
{"name": "tripod", "id": 2863, "trainId": 706},
|
771 |
+
{"name": "ironing board", "id": 1342, "trainId": 707},
|
772 |
+
{"name": "skimmer", "id": 2409, "trainId": 708},
|
773 |
+
{"name": "hollow", "id": 1258, "trainId": 709},
|
774 |
+
{"name": "scratching post", "id": 2249, "trainId": 710},
|
775 |
+
{"name": "tricycle", "id": 2862, "trainId": 711},
|
776 |
+
{"name": "file box", "id": 920, "trainId": 712},
|
777 |
+
{"name": "mountain pass", "id": 1607, "trainId": 713},
|
778 |
+
{"name": "tombstones", "id": 2802, "trainId": 714},
|
779 |
+
{"name": "cooker", "id": 610, "trainId": 715},
|
780 |
+
{"name": "card game, cards", "id": 3129, "trainId": 716},
|
781 |
+
{"name": "golf bag", "id": 1108, "trainId": 717},
|
782 |
+
{"name": "towel paper", "id": 2823, "trainId": 718},
|
783 |
+
{"name": "chaise lounge", "id": 476, "trainId": 719},
|
784 |
+
{"name": "sun", "id": 2641, "trainId": 720},
|
785 |
+
{"name": "toilet paper holder", "id": 2788, "trainId": 721},
|
786 |
+
{"name": "rake", "id": 2070, "trainId": 722},
|
787 |
+
{"name": "key", "id": 1368, "trainId": 723},
|
788 |
+
{"name": "umbrella stand", "id": 2903, "trainId": 724},
|
789 |
+
{"name": "dartboard", "id": 699, "trainId": 725},
|
790 |
+
{"name": "transformer", "id": 2844, "trainId": 726},
|
791 |
+
{"name": "fireplace utensils", "id": 942, "trainId": 727},
|
792 |
+
{"name": "sweatshirts", "id": 2663, "trainId": 728},
|
793 |
+
{
|
794 |
+
"name": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
|
795 |
+
"id": 457,
|
796 |
+
"trainId": 729,
|
797 |
+
},
|
798 |
+
{"name": "tallboy", "id": 2701, "trainId": 730},
|
799 |
+
{"name": "stapler", "id": 2540, "trainId": 731},
|
800 |
+
{"name": "sauna", "id": 2231, "trainId": 732},
|
801 |
+
{"name": "test tube", "id": 2746, "trainId": 733},
|
802 |
+
{"name": "palette", "id": 1738, "trainId": 734},
|
803 |
+
{"name": "shopping carts", "id": 2350, "trainId": 735},
|
804 |
+
{"name": "tools", "id": 2808, "trainId": 736},
|
805 |
+
{"name": "push button, push, button", "id": 2025, "trainId": 737},
|
806 |
+
{"name": "star", "id": 2541, "trainId": 738},
|
807 |
+
{"name": "roof rack", "id": 2156, "trainId": 739},
|
808 |
+
{"name": "barbed wire", "id": 126, "trainId": 740},
|
809 |
+
{"name": "spray", "id": 2512, "trainId": 741},
|
810 |
+
{"name": "ear", "id": 831, "trainId": 742},
|
811 |
+
{"name": "sponge", "id": 2503, "trainId": 743},
|
812 |
+
{"name": "racket", "id": 2039, "trainId": 744},
|
813 |
+
{"name": "tins", "id": 2774, "trainId": 745},
|
814 |
+
{"name": "eyeglasses", "id": 886, "trainId": 746},
|
815 |
+
{"name": "file", "id": 919, "trainId": 747},
|
816 |
+
{"name": "scarfs", "id": 2240, "trainId": 748},
|
817 |
+
{"name": "sugar bowl", "id": 2636, "trainId": 749},
|
818 |
+
{"name": "flip flop", "id": 963, "trainId": 750},
|
819 |
+
{"name": "headstones", "id": 1218, "trainId": 751},
|
820 |
+
{"name": "laptop bag", "id": 1406, "trainId": 752},
|
821 |
+
{"name": "leash", "id": 1420, "trainId": 753},
|
822 |
+
{"name": "climbing frame", "id": 526, "trainId": 754},
|
823 |
+
{"name": "suit hanger", "id": 2639, "trainId": 755},
|
824 |
+
{"name": "floor spotlight", "id": 975, "trainId": 756},
|
825 |
+
{"name": "plate rack", "id": 1921, "trainId": 757},
|
826 |
+
{"name": "sewer", "id": 2305, "trainId": 758},
|
827 |
+
{"name": "hard drive", "id": 1193, "trainId": 759},
|
828 |
+
{"name": "sprinkler", "id": 2517, "trainId": 760},
|
829 |
+
{"name": "tools box", "id": 2809, "trainId": 761},
|
830 |
+
{"name": "necklace", "id": 1647, "trainId": 762},
|
831 |
+
{"name": "bulbs", "id": 314, "trainId": 763},
|
832 |
+
{"name": "steel industry", "id": 2560, "trainId": 764},
|
833 |
+
{"name": "club", "id": 545, "trainId": 765},
|
834 |
+
{"name": "jack", "id": 1345, "trainId": 766},
|
835 |
+
{"name": "door bars", "id": 775, "trainId": 767},
|
836 |
+
{
|
837 |
+
"name": "control panel, instrument panel, control board, board, panel",
|
838 |
+
"id": 603,
|
839 |
+
"trainId": 768,
|
840 |
+
},
|
841 |
+
{"name": "hairbrush", "id": 1163, "trainId": 769},
|
842 |
+
{"name": "napkin holder", "id": 1641, "trainId": 770},
|
843 |
+
{"name": "office", "id": 1678, "trainId": 771},
|
844 |
+
{"name": "smoke detector", "id": 2450, "trainId": 772},
|
845 |
+
{"name": "utensils", "id": 2915, "trainId": 773},
|
846 |
+
{"name": "apron", "id": 42, "trainId": 774},
|
847 |
+
{"name": "scissors", "id": 2242, "trainId": 775},
|
848 |
+
{"name": "terminal", "id": 2741, "trainId": 776},
|
849 |
+
{"name": "grinder", "id": 1143, "trainId": 777},
|
850 |
+
{"name": "entry phone", "id": 862, "trainId": 778},
|
851 |
+
{"name": "newspaper stand", "id": 1654, "trainId": 779},
|
852 |
+
{"name": "pepper shaker", "id": 1826, "trainId": 780},
|
853 |
+
{"name": "onions", "id": 1689, "trainId": 781},
|
854 |
+
{
|
855 |
+
"name": "central processing unit, cpu, c p u , central processor, processor, mainframe",
|
856 |
+
"id": 3124,
|
857 |
+
"trainId": 782,
|
858 |
+
},
|
859 |
+
{"name": "tape", "id": 2710, "trainId": 783},
|
860 |
+
{"name": "bat", "id": 152, "trainId": 784},
|
861 |
+
{"name": "coaster", "id": 549, "trainId": 785},
|
862 |
+
{"name": "calculator", "id": 360, "trainId": 786},
|
863 |
+
{"name": "potatoes", "id": 1982, "trainId": 787},
|
864 |
+
{"name": "luggage rack", "id": 1478, "trainId": 788},
|
865 |
+
{"name": "salt", "id": 2203, "trainId": 789},
|
866 |
+
{"name": "street number", "id": 2612, "trainId": 790},
|
867 |
+
{"name": "viewpoint", "id": 2956, "trainId": 791},
|
868 |
+
{"name": "sword", "id": 2681, "trainId": 792},
|
869 |
+
{"name": "cd", "id": 437, "trainId": 793},
|
870 |
+
{"name": "rowing machine", "id": 2171, "trainId": 794},
|
871 |
+
{"name": "plug", "id": 1933, "trainId": 795},
|
872 |
+
{"name": "andiron, firedog, dog, dog-iron", "id": 3110, "trainId": 796},
|
873 |
+
{"name": "pepper", "id": 1824, "trainId": 797},
|
874 |
+
{"name": "tongs", "id": 2803, "trainId": 798},
|
875 |
+
{"name": "bonfire", "id": 234, "trainId": 799},
|
876 |
+
{"name": "dog dish", "id": 764, "trainId": 800},
|
877 |
+
{"name": "belt", "id": 177, "trainId": 801},
|
878 |
+
{"name": "dumbbells", "id": 817, "trainId": 802},
|
879 |
+
{"name": "videocassette recorder, vcr", "id": 3145, "trainId": 803},
|
880 |
+
{"name": "hook", "id": 1262, "trainId": 804},
|
881 |
+
{"name": "envelopes", "id": 864, "trainId": 805},
|
882 |
+
{"name": "shower faucet", "id": 2359, "trainId": 806},
|
883 |
+
{"name": "watch", "id": 2992, "trainId": 807},
|
884 |
+
{"name": "padlock", "id": 1725, "trainId": 808},
|
885 |
+
{"name": "swimming pool ladder", "id": 2667, "trainId": 809},
|
886 |
+
{"name": "spanners", "id": 2484, "trainId": 810},
|
887 |
+
{"name": "gravy boat", "id": 1133, "trainId": 811},
|
888 |
+
{"name": "notice board", "id": 1667, "trainId": 812},
|
889 |
+
{"name": "trash bags", "id": 2847, "trainId": 813},
|
890 |
+
{"name": "fire alarm", "id": 932, "trainId": 814},
|
891 |
+
{"name": "ladle", "id": 1392, "trainId": 815},
|
892 |
+
{"name": "stethoscope", "id": 2573, "trainId": 816},
|
893 |
+
{"name": "rocket", "id": 2140, "trainId": 817},
|
894 |
+
{"name": "funnel", "id": 1046, "trainId": 818},
|
895 |
+
{"name": "bowling pins", "id": 264, "trainId": 819},
|
896 |
+
{"name": "valve", "id": 2927, "trainId": 820},
|
897 |
+
{"name": "thermometer", "id": 2752, "trainId": 821},
|
898 |
+
{"name": "cups", "id": 679, "trainId": 822},
|
899 |
+
{"name": "spice jar", "id": 2493, "trainId": 823},
|
900 |
+
{"name": "night light", "id": 1658, "trainId": 824},
|
901 |
+
{"name": "soaps", "id": 2466, "trainId": 825},
|
902 |
+
{"name": "games table", "id": 1057, "trainId": 826},
|
903 |
+
{"name": "slotted spoon", "id": 2444, "trainId": 827},
|
904 |
+
{"name": "reel", "id": 2093, "trainId": 828},
|
905 |
+
{"name": "scourer", "id": 2248, "trainId": 829},
|
906 |
+
{"name": "sleeping robe", "id": 2432, "trainId": 830},
|
907 |
+
{"name": "desk mat", "id": 726, "trainId": 831},
|
908 |
+
{"name": "dumbbell", "id": 816, "trainId": 832},
|
909 |
+
{"name": "hammer", "id": 1171, "trainId": 833},
|
910 |
+
{"name": "tie", "id": 2766, "trainId": 834},
|
911 |
+
{"name": "typewriter", "id": 2900, "trainId": 835},
|
912 |
+
{"name": "shaker", "id": 2313, "trainId": 836},
|
913 |
+
{"name": "cheese dish", "id": 488, "trainId": 837},
|
914 |
+
{"name": "sea star", "id": 2265, "trainId": 838},
|
915 |
+
{"name": "racquet", "id": 2043, "trainId": 839},
|
916 |
+
{"name": "butane gas cylinder", "id": 332, "trainId": 840},
|
917 |
+
{"name": "paper weight", "id": 1771, "trainId": 841},
|
918 |
+
{"name": "shaving brush", "id": 2320, "trainId": 842},
|
919 |
+
{"name": "sunglasses", "id": 2646, "trainId": 843},
|
920 |
+
{"name": "gear shift", "id": 1089, "trainId": 844},
|
921 |
+
{"name": "towel rail", "id": 2826, "trainId": 845},
|
922 |
+
{"name": "adding machine, totalizer, totaliser", "id": 3148, "trainId": 846},
|
923 |
+
]
+
+
+def _get_ade20k_full_meta():
+    # Id 0 is reserved for ignore_label; we remap ignore_label from 0
+    # to 255 in our pre-processing, so all ids are shifted by 1.
+    stuff_ids = [k["id"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
+    assert len(stuff_ids) == 847, len(stuff_ids)
+
+    # For semantic segmentation, this mapping maps from contiguous stuff id
+    # (in [0, 847), used in models) to ids in the dataset (used for processing results)
+    stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
+    stuff_classes = [k["name"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
+
+    ret = {
+        "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
+        "stuff_classes": stuff_classes,
+    }
+    return ret
+
+
+def register_all_ade20k_full(root):
+    root = os.path.join(root, "ADE20K_2021_17_01")
+    meta = _get_ade20k_full_meta()
+    for name, dirname in [("train", "training"), ("val", "validation")]:
+        image_dir = os.path.join(root, "images_detectron2", dirname)
+        gt_dir = os.path.join(root, "annotations_detectron2", dirname)
+        name = f"ade20k_full_sem_seg_{name}"
+        DatasetCatalog.register(
+            name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="tif", image_ext="jpg")
+        )
+        MetadataCatalog.get(name).set(
+            stuff_classes=meta["stuff_classes"][:],
+            image_root=image_dir,
+            sem_seg_root=gt_dir,
+            evaluator_type="sem_seg",
+            ignore_label=65535,  # NOTE: gt is saved in 16-bit TIFF images
+        )
+
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_all_ade20k_full(_root)
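
A minimal usage sketch (assumed, not part of this commit): once this module is imported, the splits registered above can be fetched through detectron2's catalogs. The split name and class count come from the code above; everything else assumes a standard detectron2 install with the dataset present under $DETECTRON2_DATASETS.

from detectron2.data import DatasetCatalog, MetadataCatalog

# Fetch the registered validation split: a list of detectron2-format dicts.
dataset_dicts = DatasetCatalog.get("ade20k_full_sem_seg_val")
meta = MetadataCatalog.get("ade20k_full_sem_seg_val")
print(len(dataset_dicts), "images,", len(meta.stuff_classes), "classes")  # expect 847 classes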
mask_former/data/datasets/register_ade20k_panoptic.py
ADDED
@@ -0,0 +1,387 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import json
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.utils.file_io import PathManager
+
+ADE20K_150_CATEGORIES = [
+    {"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"},
+    {"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"},
+    {"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"},
+    {"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"},
+    {"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"},
+    {"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"},
+    {"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"},
+    {"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"},
+    {"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "},
+    {"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"},
+    {"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"},
+    {"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"},
+    {"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"},
+    {"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"},
+    {"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"},
+    {"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"},
+    {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"},
+    {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"},
+    {"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"},
+    {"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"},
+    {"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"},
+    {"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"},
+    {"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"},
+    {"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"},
+    {"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"},
+    {"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"},
+    {"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"},
+    {"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"},
+    {"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"},
+    {"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"},
+    {"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"},
+    {"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"},
+    {"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"},
+    {"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"},
+    {"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"},
+    {"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"},
+    {"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"},
+    {"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"},
+    {"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"},
+    {"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"},
+    {"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"},
+    {"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"},
+    {"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"},
+    {"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"},
+    {
+        "color": [6, 51, 255],
+        "id": 44,
+        "isthing": 1,
+        "name": "chest of drawers, chest, bureau, dresser",
+    },
+    {"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"},
+    {"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"},
+    {"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"},
+    {"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"},
+    {"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"},
+    {"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"},
+    {"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"},
+    {"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"},
+    {"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"},
+    {"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"},
+    {"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"},
+    {
+        "color": [255, 71, 0],
+        "id": 56,
+        "isthing": 1,
+        "name": "pool table, billiard table, snooker table",
+    },
+    {"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"},
+    {"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"},
+    {"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"},
+    {"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"},
+    {"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"},
+    {"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"},
+    {"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"},
+    {"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"},
+    {
+        "color": [0, 255, 133],
+        "id": 65,
+        "isthing": 1,
+        "name": "toilet, can, commode, crapper, pot, potty, stool, throne",
+    },
+    {"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"},
+    {"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"},
+    {"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"},
+    {"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"},
+    {"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"},
+    {"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"},
+    {"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"},
+    {"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"},
+    {"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"},
+    {"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"},
+    {"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"},
+    {"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"},
+    {"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"},
+    {"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"},
+    {"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"},
+    {"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"},
+    {"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"},
+    {"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"},
+    {"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"},
+    {"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"},
+    {"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"},
+    {"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"},
+    {"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"},
+    {"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"},
+    {"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"},
+    {"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"},
+    {"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"},
+    {"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"},
+    {"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"},
+    {
+        "color": [0, 122, 255],
+        "id": 95,
+        "isthing": 1,
+        "name": "bannister, banister, balustrade, balusters, handrail",
+    },
+    {
+        "color": [0, 255, 163],
+        "id": 96,
+        "isthing": 0,
+        "name": "escalator, moving staircase, moving stairway",
+    },
+    {
+        "color": [255, 153, 0],
+        "id": 97,
+        "isthing": 1,
+        "name": "ottoman, pouf, pouffe, puff, hassock",
+    },
+    {"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"},
+    {"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"},
+    {
+        "color": [143, 255, 0],
+        "id": 100,
+        "isthing": 0,
+        "name": "poster, posting, placard, notice, bill, card",
+    },
+    {"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"},
+    {"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"},
+    {"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"},
+    {"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"},
+    {
+        "color": [133, 0, 255],
+        "id": 105,
+        "isthing": 0,
+        "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
+    },
+    {"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"},
+    {
+        "color": [184, 0, 255],
+        "id": 107,
+        "isthing": 1,
+        "name": "washer, automatic washer, washing machine",
+    },
+    {"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"},
+    {"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"},
+    {"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"},
+    {"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"},
+    {"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"},
+    {"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"},
+    {"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"},
+    {"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"},
+    {"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"},
+    {"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"},
+    {"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"},
+    {"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"},
+    {"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"},
+    {"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"},
+    {"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"},
+    {"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"},
+    {"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"},
+    {"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"},
+    {"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"},
+    {"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"},
+    {"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"},
+    {"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"},
+    {"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"},
+    {"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"},
+    {"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"},
+    {"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"},
+    {"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"},
+    {"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"},
+    {"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"},
+    {"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"},
+    {"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"},
+    {"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"},
+    {"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"},
+    {"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"},
+    {"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"},
+    {"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"},
+    {"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"},
+    {"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"},
+    {"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"},
+    {"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"},
+    {"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"},
+    {"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"},
+]
+
+ADE20k_COLORS = [k["color"] for k in ADE20K_150_CATEGORIES]
+
+MetadataCatalog.get("ade20k_sem_seg_train").set(
+    stuff_colors=ADE20k_COLORS[:],
+)
+
+MetadataCatalog.get("ade20k_sem_seg_val").set(
+    stuff_colors=ADE20k_COLORS[:],
+)
+
+
+def load_ade20k_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, meta):
+    """
+    Args:
+        image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
+        gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
+        json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
+    Returns:
+        list[dict]: a list of dicts in Detectron2 standard format. (See
+        `Using Custom Datasets </tutorials/datasets.html>`_ )
+    """
+
+    def _convert_category_id(segment_info, meta):
+        if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
+            segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
+                segment_info["category_id"]
+            ]
+            segment_info["isthing"] = True
+        else:
+            segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
+                segment_info["category_id"]
+            ]
+            segment_info["isthing"] = False
+        return segment_info
+
+    with PathManager.open(json_file) as f:
+        json_info = json.load(f)
+
+    ret = []
+    for ann in json_info["annotations"]:
+        image_id = ann["image_id"]
+        # TODO: currently we assume image and label have the same filename but
+        # different extension, and images have extension ".jpg" for COCO. Need
+        # to make image extension a user-provided argument if we extend this
+        # function to support other COCO-like datasets.
+        image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
+        label_file = os.path.join(gt_dir, ann["file_name"])
+        sem_label_file = os.path.join(semseg_dir, ann["file_name"])
+        segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
+        ret.append(
+            {
+                "file_name": image_file,
+                "image_id": image_id,
+                "pan_seg_file_name": label_file,
+                "sem_seg_file_name": sem_label_file,
+                "segments_info": segments_info,
+            }
+        )
+    assert len(ret), f"No images found in {image_dir}!"
+    assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
+    assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
+    assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"]
+    return ret
+
+
+def register_ade20k_panoptic(
+    name, metadata, image_root, panoptic_root, semantic_root, panoptic_json, instances_json=None
+):
+    """
+    Register a "standard" version of ADE20k panoptic segmentation dataset named `name`.
+    The dictionaries in this registered dataset follow detectron2's standard format.
+    Hence it's called "standard".
+    Args:
+        name (str): the name that identifies a dataset,
+            e.g. "ade20k_panoptic_train"
+        metadata (dict): extra metadata associated with this dataset.
+        image_root (str): directory which contains all the images
+        panoptic_root (str): directory which contains panoptic annotation images in COCO format
+        panoptic_json (str): path to the json panoptic annotation file in COCO format
+        sem_seg_root (none): not used, to be consistent with
+            `register_coco_panoptic_separated`.
+        instances_json (str): path to the json instance annotation file
+    """
+    panoptic_name = name
+    DatasetCatalog.register(
+        panoptic_name,
+        lambda: load_ade20k_panoptic_json(
+            panoptic_json, image_root, panoptic_root, semantic_root, metadata
+        ),
+    )
+    MetadataCatalog.get(panoptic_name).set(
+        panoptic_root=panoptic_root,
+        image_root=image_root,
+        panoptic_json=panoptic_json,
+        json_file=instances_json,
+        evaluator_type="ade20k_panoptic_seg",
+        ignore_label=255,
+        label_divisor=1000,
+        **metadata,
+    )
+
+
+_PREDEFINED_SPLITS_ADE20K_PANOPTIC = {
+    "ade20k_panoptic_train": (
+        "ADEChallengeData2016/images/training",
+        "ADEChallengeData2016/ade20k_panoptic_train",
+        "ADEChallengeData2016/ade20k_panoptic_train.json",
+        "ADEChallengeData2016/annotations_detectron2/training",
+    ),
+    "ade20k_panoptic_val": (
+        "ADEChallengeData2016/images/validation",
+        "ADEChallengeData2016/ade20k_panoptic_val",
+        "ADEChallengeData2016/ade20k_panoptic_val.json",
+        "ADEChallengeData2016/annotations_detectron2/validation",
+    ),
+}
+
+
+def get_metadata():
+    meta = {}
+    # The following metadata maps contiguous id from [0, #thing categories +
+    # #stuff categories) to their names and colors. We keep a replica of the
+    # same name and color under "thing_*" and "stuff_*" because the current
+    # visualization function in D2 handles thing and stuff classes differently
+    # due to some heuristic used in Panoptic FPN. We keep the same naming to
+    # enable reusing existing visualization functions.
+    thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES]
+    thing_colors = [k["color"] for k in ADE20K_150_CATEGORIES]
+    stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES]
+    stuff_colors = [k["color"] for k in ADE20K_150_CATEGORIES]
+
+    meta["thing_classes"] = thing_classes
+    meta["thing_colors"] = thing_colors
+    meta["stuff_classes"] = stuff_classes
+    meta["stuff_colors"] = stuff_colors
+
+    # Convert category id for training:
+    # category id: like semantic segmentation, it is the class id for each
+    # pixel. Since there are some classes not used in evaluation, the category
+    # id is not always contiguous and thus we have two sets of category ids:
+    #   - original category id: category id in the original dataset, mainly
+    #     used for evaluation.
+    #   - contiguous category id: [0, #classes), in order to train the linear
+    #     softmax classifier.
+    thing_dataset_id_to_contiguous_id = {}
+    stuff_dataset_id_to_contiguous_id = {}
+
+    for i, cat in enumerate(ADE20K_150_CATEGORIES):
+        if cat["isthing"]:
+            thing_dataset_id_to_contiguous_id[cat["id"]] = i
+        # else:
+        #     stuff_dataset_id_to_contiguous_id[cat["id"]] = i
+
+        # in order to use sem_seg evaluator
+        stuff_dataset_id_to_contiguous_id[cat["id"]] = i
+
+    meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
+    meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
+
+    return meta
+
+
+def register_all_ade20k_panoptic(root):
+    metadata = get_metadata()
+    for (
+        prefix,
+        (image_root, panoptic_root, panoptic_json, semantic_root),
+    ) in _PREDEFINED_SPLITS_ADE20K_PANOPTIC.items():
+        # The "standard" version of COCO panoptic segmentation dataset,
+        # e.g. used by Panoptic-DeepLab
+        register_ade20k_panoptic(
+            prefix,
+            metadata,
+            os.path.join(root, image_root),
+            os.path.join(root, panoptic_root),
+            os.path.join(root, semantic_root),
+            os.path.join(root, panoptic_json),
+        )
+
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_all_ade20k_panoptic(_root)
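
For reference, a minimal sketch (assumed usage, not part of this commit) of how the mapping built by get_metadata() converts a raw segment, mirroring _convert_category_id above. ADE20K's 150 category ids are already contiguous, so the mapping is the identity here; its real job is flagging thing vs. stuff:

meta = get_metadata()

# 20 is "car" in ADE20K_150_CATEGORIES, an "isthing" category.
segment = {"category_id": 20}
if segment["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
    segment["category_id"] = meta["thing_dataset_id_to_contiguous_id"][segment["category_id"]]
    segment["isthing"] = True
print(segment)  # {'category_id': 20, 'isthing': True}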
mask_former/data/datasets/register_coco_stuff_10k.py
ADDED
@@ -0,0 +1,223 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+
+COCO_CATEGORIES = [
+    {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
+    {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
+    {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
+    {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
+    {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
+    {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
+    {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
+    {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
+    {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
+    {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
+    {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
+    {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
+    {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
+    {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
+    {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
+    {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
+    {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
+    {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
+    {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
+    {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
+    {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
+    {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
+    {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
+    {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
+    {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
+    {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
+    {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
+    {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
+    {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
+    {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
+    {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
+    {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
+    {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
+    {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
+    {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
+    {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
+    {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
+    {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
+    {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
+    {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
+    {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
+    {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
+    {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
+    {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
+    {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
+    {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
+    {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
+    {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
+    {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
+    {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
+    {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
+    {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
+    {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
+    {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
+    {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
+    {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
+    {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
+    {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
+    {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
+    {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
+    {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
+    {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
+    {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
+    {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
+    {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
+    {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
+    {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
+    {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
+    {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
+    {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
+    {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
+    {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
+    {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
+    {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
+    {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
+    {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
+    {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
+    {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
+    {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
+    {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
+    {"id": 92, "name": "banner", "supercategory": "textile"},
+    {"id": 93, "name": "blanket", "supercategory": "textile"},
+    {"id": 94, "name": "branch", "supercategory": "plant"},
+    {"id": 95, "name": "bridge", "supercategory": "building"},
+    {"id": 96, "name": "building-other", "supercategory": "building"},
+    {"id": 97, "name": "bush", "supercategory": "plant"},
+    {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
+    {"id": 99, "name": "cage", "supercategory": "structural"},
+    {"id": 100, "name": "cardboard", "supercategory": "raw-material"},
+    {"id": 101, "name": "carpet", "supercategory": "floor"},
+    {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
+    {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
+    {"id": 104, "name": "cloth", "supercategory": "textile"},
+    {"id": 105, "name": "clothes", "supercategory": "textile"},
+    {"id": 106, "name": "clouds", "supercategory": "sky"},
+    {"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
+    {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
+    {"id": 109, "name": "curtain", "supercategory": "textile"},
+    {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
+    {"id": 111, "name": "dirt", "supercategory": "ground"},
+    {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
+    {"id": 113, "name": "fence", "supercategory": "structural"},
+    {"id": 114, "name": "floor-marble", "supercategory": "floor"},
+    {"id": 115, "name": "floor-other", "supercategory": "floor"},
+    {"id": 116, "name": "floor-stone", "supercategory": "floor"},
+    {"id": 117, "name": "floor-tile", "supercategory": "floor"},
+    {"id": 118, "name": "floor-wood", "supercategory": "floor"},
+    {"id": 119, "name": "flower", "supercategory": "plant"},
+    {"id": 120, "name": "fog", "supercategory": "water"},
+    {"id": 121, "name": "food-other", "supercategory": "food-stuff"},
+    {"id": 122, "name": "fruit", "supercategory": "food-stuff"},
+    {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
+    {"id": 124, "name": "grass", "supercategory": "plant"},
+    {"id": 125, "name": "gravel", "supercategory": "ground"},
+    {"id": 126, "name": "ground-other", "supercategory": "ground"},
+    {"id": 127, "name": "hill", "supercategory": "solid"},
+    {"id": 128, "name": "house", "supercategory": "building"},
+    {"id": 129, "name": "leaves", "supercategory": "plant"},
+    {"id": 130, "name": "light", "supercategory": "furniture-stuff"},
+    {"id": 131, "name": "mat", "supercategory": "textile"},
+    {"id": 132, "name": "metal", "supercategory": "raw-material"},
+    {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
+    {"id": 134, "name": "moss", "supercategory": "plant"},
+    {"id": 135, "name": "mountain", "supercategory": "solid"},
+    {"id": 136, "name": "mud", "supercategory": "ground"},
+    {"id": 137, "name": "napkin", "supercategory": "textile"},
+    {"id": 138, "name": "net", "supercategory": "structural"},
+    {"id": 139, "name": "paper", "supercategory": "raw-material"},
+    {"id": 140, "name": "pavement", "supercategory": "ground"},
+    {"id": 141, "name": "pillow", "supercategory": "textile"},
+    {"id": 142, "name": "plant-other", "supercategory": "plant"},
+    {"id": 143, "name": "plastic", "supercategory": "raw-material"},
+    {"id": 144, "name": "platform", "supercategory": "ground"},
+    {"id": 145, "name": "playingfield", "supercategory": "ground"},
+    {"id": 146, "name": "railing", "supercategory": "structural"},
+    {"id": 147, "name": "railroad", "supercategory": "ground"},
+    {"id": 148, "name": "river", "supercategory": "water"},
+    {"id": 149, "name": "road", "supercategory": "ground"},
+    {"id": 150, "name": "rock", "supercategory": "solid"},
+    {"id": 151, "name": "roof", "supercategory": "building"},
+    {"id": 152, "name": "rug", "supercategory": "textile"},
+    {"id": 153, "name": "salad", "supercategory": "food-stuff"},
+    {"id": 154, "name": "sand", "supercategory": "ground"},
+    {"id": 155, "name": "sea", "supercategory": "water"},
+    {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
+    {"id": 157, "name": "sky-other", "supercategory": "sky"},
+    {"id": 158, "name": "skyscraper", "supercategory": "building"},
+    {"id": 159, "name": "snow", "supercategory": "ground"},
+    {"id": 160, "name": "solid-other", "supercategory": "solid"},
+    {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
+    {"id": 162, "name": "stone", "supercategory": "solid"},
+    {"id": 163, "name": "straw", "supercategory": "plant"},
+    {"id": 164, "name": "structural-other", "supercategory": "structural"},
+    {"id": 165, "name": "table", "supercategory": "furniture-stuff"},
+    {"id": 166, "name": "tent", "supercategory": "building"},
+    {"id": 167, "name": "textile-other", "supercategory": "textile"},
+    {"id": 168, "name": "towel", "supercategory": "textile"},
+    {"id": 169, "name": "tree", "supercategory": "plant"},
+    {"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
+    {"id": 171, "name": "wall-brick", "supercategory": "wall"},
|
168 |
+
{"id": 172, "name": "wall-concrete", "supercategory": "wall"},
|
169 |
+
{"id": 173, "name": "wall-other", "supercategory": "wall"},
|
170 |
+
{"id": 174, "name": "wall-panel", "supercategory": "wall"},
|
171 |
+
{"id": 175, "name": "wall-stone", "supercategory": "wall"},
|
172 |
+
{"id": 176, "name": "wall-tile", "supercategory": "wall"},
|
173 |
+
{"id": 177, "name": "wall-wood", "supercategory": "wall"},
|
174 |
+
{"id": 178, "name": "water-other", "supercategory": "water"},
|
175 |
+
{"id": 179, "name": "waterdrops", "supercategory": "water"},
|
176 |
+
{"id": 180, "name": "window-blind", "supercategory": "window"},
|
177 |
+
{"id": 181, "name": "window-other", "supercategory": "window"},
|
178 |
+
{"id": 182, "name": "wood", "supercategory": "solid"},
|
179 |
+
]
|
180 |
+
+
+
+def _get_coco_stuff_meta():
+    # Id 0 is reserved for ignore_label; we change ignore_label from 0
+    # to 255 in our pre-processing.
+    stuff_ids = [k["id"] for k in COCO_CATEGORIES]
+    assert len(stuff_ids) == 171, len(stuff_ids)
+
+    # For semantic segmentation, this mapping maps from contiguous stuff id
+    # (in [0, 170], used in models) to ids in the dataset (used for processing results)
+    stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
+    stuff_classes = [k["name"] for k in COCO_CATEGORIES]
+
+    ret = {
+        "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
+        "stuff_classes": stuff_classes,
+    }
+    return ret
+
+
+def register_all_coco_stuff_10k(root):
+    root = os.path.join(root, "coco", "coco_stuff_10k")
+    meta = _get_coco_stuff_meta()
+    for name, image_dirname, sem_seg_dirname in [
+        ("train", "images_detectron2/train", "annotations_detectron2/train"),
+        ("test", "images_detectron2/test", "annotations_detectron2/test"),
+    ]:
+        image_dir = os.path.join(root, image_dirname)
+        gt_dir = os.path.join(root, sem_seg_dirname)
+        name = f"coco_2017_{name}_stuff_10k_sem_seg"
+        DatasetCatalog.register(
+            name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg")
+        )
+        MetadataCatalog.get(name).set(
+            image_root=image_dir,
+            sem_seg_root=gt_dir,
+            evaluator_type="sem_seg",
+            ignore_label=255,
+            **meta,
+        )
+
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_all_coco_stuff_10k(_root)
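Once this module is imported, both splits are retrievable by name through the standard detectron2 catalogs. A minimal lookup sketch (illustration only; assumes the data files exist under $DETECTRON2_DATASETS):

from detectron2.data import DatasetCatalog, MetadataCatalog

# each dict carries the "file_name"/"sem_seg_file_name" paths produced by load_sem_seg
dicts = DatasetCatalog.get("coco_2017_train_stuff_10k_sem_seg")
meta = MetadataCatalog.get("coco_2017_train_stuff_10k_sem_seg")
assert len(meta.stuff_classes) == 171 and meta.ignore_label == 255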
mask_former/data/datasets/register_mapillary_vistas.py
ADDED
@@ -0,0 +1,507 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+
+MAPILLARY_VISTAS_SEM_SEG_CATEGORIES = [
+    {"color": [165, 42, 42], "instances": True, "readable": "Bird", "name": "animal--bird", "evaluate": True},
+    {"color": [0, 192, 0], "instances": True, "readable": "Ground Animal", "name": "animal--ground-animal", "evaluate": True},
+    {"color": [196, 196, 196], "instances": False, "readable": "Curb", "name": "construction--barrier--curb", "evaluate": True},
+    {"color": [190, 153, 153], "instances": False, "readable": "Fence", "name": "construction--barrier--fence", "evaluate": True},
+    {"color": [180, 165, 180], "instances": False, "readable": "Guard Rail", "name": "construction--barrier--guard-rail", "evaluate": True},
+    {"color": [90, 120, 150], "instances": False, "readable": "Barrier", "name": "construction--barrier--other-barrier", "evaluate": True},
+    {"color": [102, 102, 156], "instances": False, "readable": "Wall", "name": "construction--barrier--wall", "evaluate": True},
+    {"color": [128, 64, 255], "instances": False, "readable": "Bike Lane", "name": "construction--flat--bike-lane", "evaluate": True},
+    {"color": [140, 140, 200], "instances": True, "readable": "Crosswalk - Plain", "name": "construction--flat--crosswalk-plain", "evaluate": True},
+    {"color": [170, 170, 170], "instances": False, "readable": "Curb Cut", "name": "construction--flat--curb-cut", "evaluate": True},
+    {"color": [250, 170, 160], "instances": False, "readable": "Parking", "name": "construction--flat--parking", "evaluate": True},
+    {"color": [96, 96, 96], "instances": False, "readable": "Pedestrian Area", "name": "construction--flat--pedestrian-area", "evaluate": True},
+    {"color": [230, 150, 140], "instances": False, "readable": "Rail Track", "name": "construction--flat--rail-track", "evaluate": True},
+    {"color": [128, 64, 128], "instances": False, "readable": "Road", "name": "construction--flat--road", "evaluate": True},
+    {"color": [110, 110, 110], "instances": False, "readable": "Service Lane", "name": "construction--flat--service-lane", "evaluate": True},
+    {"color": [244, 35, 232], "instances": False, "readable": "Sidewalk", "name": "construction--flat--sidewalk", "evaluate": True},
+    {"color": [150, 100, 100], "instances": False, "readable": "Bridge", "name": "construction--structure--bridge", "evaluate": True},
+    {"color": [70, 70, 70], "instances": False, "readable": "Building", "name": "construction--structure--building", "evaluate": True},
+    {"color": [150, 120, 90], "instances": False, "readable": "Tunnel", "name": "construction--structure--tunnel", "evaluate": True},
+    {"color": [220, 20, 60], "instances": True, "readable": "Person", "name": "human--person", "evaluate": True},
+    {"color": [255, 0, 0], "instances": True, "readable": "Bicyclist", "name": "human--rider--bicyclist", "evaluate": True},
+    {"color": [255, 0, 100], "instances": True, "readable": "Motorcyclist", "name": "human--rider--motorcyclist", "evaluate": True},
+    {"color": [255, 0, 200], "instances": True, "readable": "Other Rider", "name": "human--rider--other-rider", "evaluate": True},
+    {"color": [200, 128, 128], "instances": True, "readable": "Lane Marking - Crosswalk", "name": "marking--crosswalk-zebra", "evaluate": True},
+    {"color": [255, 255, 255], "instances": False, "readable": "Lane Marking - General", "name": "marking--general", "evaluate": True},
+    {"color": [64, 170, 64], "instances": False, "readable": "Mountain", "name": "nature--mountain", "evaluate": True},
+    {"color": [230, 160, 50], "instances": False, "readable": "Sand", "name": "nature--sand", "evaluate": True},
+    {"color": [70, 130, 180], "instances": False, "readable": "Sky", "name": "nature--sky", "evaluate": True},
+    {"color": [190, 255, 255], "instances": False, "readable": "Snow", "name": "nature--snow", "evaluate": True},
+    {"color": [152, 251, 152], "instances": False, "readable": "Terrain", "name": "nature--terrain", "evaluate": True},
+    {"color": [107, 142, 35], "instances": False, "readable": "Vegetation", "name": "nature--vegetation", "evaluate": True},
+    {"color": [0, 170, 30], "instances": False, "readable": "Water", "name": "nature--water", "evaluate": True},
+    {"color": [255, 255, 128], "instances": True, "readable": "Banner", "name": "object--banner", "evaluate": True},
+    {"color": [250, 0, 30], "instances": True, "readable": "Bench", "name": "object--bench", "evaluate": True},
+    {"color": [100, 140, 180], "instances": True, "readable": "Bike Rack", "name": "object--bike-rack", "evaluate": True},
+    {"color": [220, 220, 220], "instances": True, "readable": "Billboard", "name": "object--billboard", "evaluate": True},
+    {"color": [220, 128, 128], "instances": True, "readable": "Catch Basin", "name": "object--catch-basin", "evaluate": True},
+    {"color": [222, 40, 40], "instances": True, "readable": "CCTV Camera", "name": "object--cctv-camera", "evaluate": True},
+    {"color": [100, 170, 30], "instances": True, "readable": "Fire Hydrant", "name": "object--fire-hydrant", "evaluate": True},
+    {"color": [40, 40, 40], "instances": True, "readable": "Junction Box", "name": "object--junction-box", "evaluate": True},
+    {"color": [33, 33, 33], "instances": True, "readable": "Mailbox", "name": "object--mailbox", "evaluate": True},
+    {"color": [100, 128, 160], "instances": True, "readable": "Manhole", "name": "object--manhole", "evaluate": True},
+    {"color": [142, 0, 0], "instances": True, "readable": "Phone Booth", "name": "object--phone-booth", "evaluate": True},
+    {"color": [70, 100, 150], "instances": False, "readable": "Pothole", "name": "object--pothole", "evaluate": True},
+    {"color": [210, 170, 100], "instances": True, "readable": "Street Light", "name": "object--street-light", "evaluate": True},
+    {"color": [153, 153, 153], "instances": True, "readable": "Pole", "name": "object--support--pole", "evaluate": True},
+    {"color": [128, 128, 128], "instances": True, "readable": "Traffic Sign Frame", "name": "object--support--traffic-sign-frame", "evaluate": True},
+    {"color": [0, 0, 80], "instances": True, "readable": "Utility Pole", "name": "object--support--utility-pole", "evaluate": True},
+    {"color": [250, 170, 30], "instances": True, "readable": "Traffic Light", "name": "object--traffic-light", "evaluate": True},
+    {"color": [192, 192, 192], "instances": True, "readable": "Traffic Sign (Back)", "name": "object--traffic-sign--back", "evaluate": True},
+    {"color": [220, 220, 0], "instances": True, "readable": "Traffic Sign (Front)", "name": "object--traffic-sign--front", "evaluate": True},
+    {"color": [140, 140, 20], "instances": True, "readable": "Trash Can", "name": "object--trash-can", "evaluate": True},
+    {"color": [119, 11, 32], "instances": True, "readable": "Bicycle", "name": "object--vehicle--bicycle", "evaluate": True},
+    {"color": [150, 0, 255], "instances": True, "readable": "Boat", "name": "object--vehicle--boat", "evaluate": True},
+    {"color": [0, 60, 100], "instances": True, "readable": "Bus", "name": "object--vehicle--bus", "evaluate": True},
+    {"color": [0, 0, 142], "instances": True, "readable": "Car", "name": "object--vehicle--car", "evaluate": True},
+    {"color": [0, 0, 90], "instances": True, "readable": "Caravan", "name": "object--vehicle--caravan", "evaluate": True},
+    {"color": [0, 0, 230], "instances": True, "readable": "Motorcycle", "name": "object--vehicle--motorcycle", "evaluate": True},
+    {"color": [0, 80, 100], "instances": False, "readable": "On Rails", "name": "object--vehicle--on-rails", "evaluate": True},
+    {"color": [128, 64, 64], "instances": True, "readable": "Other Vehicle", "name": "object--vehicle--other-vehicle", "evaluate": True},
+    {"color": [0, 0, 110], "instances": True, "readable": "Trailer", "name": "object--vehicle--trailer", "evaluate": True},
+    {"color": [0, 0, 70], "instances": True, "readable": "Truck", "name": "object--vehicle--truck", "evaluate": True},
+    {"color": [0, 0, 192], "instances": True, "readable": "Wheeled Slow", "name": "object--vehicle--wheeled-slow", "evaluate": True},
+    {"color": [32, 32, 32], "instances": False, "readable": "Car Mount", "name": "void--car-mount", "evaluate": True},
+    {"color": [120, 10, 10], "instances": False, "readable": "Ego Vehicle", "name": "void--ego-vehicle", "evaluate": True},
+    {"color": [0, 0, 0], "instances": False, "readable": "Unlabeled", "name": "void--unlabeled", "evaluate": False},
+]
+
+
+def _get_mapillary_vistas_meta():
+    stuff_classes = [k["readable"] for k in MAPILLARY_VISTAS_SEM_SEG_CATEGORIES if k["evaluate"]]
+    assert len(stuff_classes) == 65
+
+    stuff_colors = [k["color"] for k in MAPILLARY_VISTAS_SEM_SEG_CATEGORIES if k["evaluate"]]
+    assert len(stuff_colors) == 65
+
+    ret = {
+        "stuff_classes": stuff_classes,
+        "stuff_colors": stuff_colors,
+    }
+    return ret
+
+
+def register_all_mapillary_vistas(root):
+    root = os.path.join(root, "mapillary_vistas")
+    meta = _get_mapillary_vistas_meta()
+    for name, dirname in [("train", "training"), ("val", "validation")]:
+        image_dir = os.path.join(root, dirname, "images")
+        gt_dir = os.path.join(root, dirname, "labels")
+        name = f"mapillary_vistas_sem_seg_{name}"
+        DatasetCatalog.register(
+            name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg")
+        )
+        MetadataCatalog.get(name).set(
+            image_root=image_dir,
+            sem_seg_root=gt_dir,
+            evaluator_type="sem_seg",
+            ignore_label=65,  # different from other datasets, Mapillary Vistas sets ignore_label to 65
+            **meta,
+        )
+
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_all_mapillary_vistas(_root)
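Only `void--unlabeled` carries `"evaluate": False`, so the 66-entry table yields exactly the 65 classes asserted in `_get_mapillary_vistas_meta`. A small sanity sketch (illustration only):

evaluated = [c for c in MAPILLARY_VISTAS_SEM_SEG_CATEGORIES if c["evaluate"]]
assert len(evaluated) == 65
assert MAPILLARY_VISTAS_SEM_SEG_CATEGORIES[-1]["name"] == "void--unlabeled"
# the excluded class presumably occupies label 65 in the ground-truth PNGs,
# which would be why ignore_label is set to 65 during registration above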
mask_former/mask_former_model.py
ADDED
@@ -0,0 +1,355 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from typing import Tuple
+
+import torch
+from detectron2.config import configurable
+from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
+from detectron2.modeling.backbone import Backbone
+from detectron2.structures import ImageList
+from torch import nn
+from torch.nn import functional as F
+from torchvision.transforms import functional as Ftv
+
+from utils.log import getLogger
+from .modeling.criterion import SetCriterion
+from .modeling.matcher import HungarianMatcher
+
+logger = getLogger(__name__)
+
+
+def interpolate_or_crop(img,
+                        size=(128, 128),
+                        mode="bilinear",
+                        align_corners=False,
+                        tol=1.1):
+    h, w = img.shape[-2:]
+    H, W = size
+    if h == H and w == W:
+        return img
+    if H <= h < tol * H and W <= w < tol * W:
+        logger.info_once("Using center cropping instead of interpolation")
+        return Ftv.center_crop(img, output_size=size)
+    return F.interpolate(img, size=size, mode=mode, align_corners=align_corners)
+
+
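A minimal behaviour sketch for `interpolate_or_crop` (hypothetical shapes): with the default `tol=1.1`, inputs up to 10% larger than the target on both sides are center-cropped, anything else is resized.

import torch

x = torch.randn(1, 3, 130, 130)
y = interpolate_or_crop(x, size=(128, 128))  # 128 <= 130 < 140.8 -> center crop
z = interpolate_or_crop(torch.randn(1, 3, 256, 256), size=(128, 128))  # -> bilinear resize
assert y.shape == z.shape == (1, 3, 128, 128)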
+@META_ARCH_REGISTRY.register()
+class MaskFormer(nn.Module):
+    """
+    Main class for mask classification semantic segmentation architectures.
+    """
+
+    @configurable
+    def __init__(
+        self,
+        *,
+        backbone: Backbone,
+        sem_seg_head: nn.Module,
+        criterion: nn.Module,
+        num_queries: int,
+        panoptic_on: bool,
+        object_mask_threshold: float,
+        overlap_threshold: float,
+        metadata,
+        size_divisibility: int,
+        sem_seg_postprocess_before_inference: bool,
+        pixel_mean: Tuple[float],
+        pixel_std: Tuple[float],
+        crop_not_upsample: bool = True
+    ):
+        """
+        Args:
+            backbone: a backbone module, must follow detectron2's backbone interface
+            sem_seg_head: a module that predicts semantic segmentation from backbone features
+            criterion: a module that defines the loss
+            num_queries: int, number of queries
+            panoptic_on: bool, whether to output panoptic segmentation predictions
+            object_mask_threshold: float, threshold to filter queries based on classification score
+                for panoptic segmentation inference
+            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
+            metadata: dataset meta, gives `thing` and `stuff` category names for panoptic
+                segmentation inference
+            size_divisibility: Some backbones require the input height and width to be divisible by a
+                specific integer. We can use this to override such a requirement.
+            sem_seg_postprocess_before_inference: whether to resize the prediction back
+                to the original input size before semantic segmentation inference or after.
+                For a high-resolution dataset like Mapillary, resizing predictions before
+                inference will cause an OOM error.
+            pixel_mean, pixel_std: list or tuple with #channels elements, representing
+                the per-channel mean and std used to normalize the input image
+        """
+        super().__init__()
+        self.crop_not_upsample = crop_not_upsample
+        self.backbone = backbone
+        self.sem_seg_head = sem_seg_head
+        self.criterion = criterion
+        self.num_queries = num_queries
+        self.overlap_threshold = overlap_threshold
+        self.panoptic_on = panoptic_on
+        self.object_mask_threshold = object_mask_threshold
+        self.metadata = metadata
+        if size_divisibility < 0:
+            # use backbone size_divisibility if not set
+            size_divisibility = self.backbone.size_divisibility
+        self.size_divisibility = size_divisibility
+        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+    @classmethod
+    def from_config(cls, cfg):
+        backbone = build_backbone(cfg)
+        out_shape = backbone.output_shape()
+        if len(cfg.GWM.SAMPLE_KEYS) > 1:
+            for k, v in out_shape.items():
+                out_shape[k] = v._replace(channels=v.channels * len(cfg.GWM.SAMPLE_KEYS))
+        sem_seg_head = build_sem_seg_head(cfg, out_shape)
+
+        # Loss parameters:
+        deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
+        no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
+        dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
+        mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
+
+        # building criterion
+        matcher = HungarianMatcher(
+            cost_class=1,
+            cost_mask=mask_weight,
+            cost_dice=dice_weight,
+        )
+
+        weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight}
+        if deep_supervision:
+            dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
+            aux_weight_dict = {}
+            for i in range(dec_layers - 1):
+                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
+            weight_dict.update(aux_weight_dict)
+
+        losses = ["labels", "masks"]
+
+        criterion = SetCriterion(
+            sem_seg_head.num_classes,
+            matcher=matcher,
+            weight_dict=weight_dict,
+            eos_coef=no_object_weight,
+            losses=losses,
+        )
+
+        return {
+            "backbone": backbone,
+            "sem_seg_head": sem_seg_head,
+            "criterion": criterion,
+            "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
+            "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
+            "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
+            "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
+            "metadata": None,  # MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
+            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
+            "sem_seg_postprocess_before_inference": (
+                cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
+                or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
+            ),
+            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
+            "pixel_std": cfg.MODEL.PIXEL_STD,
+            'crop_not_upsample': cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME != 'BasePixelDecoder'
+        }
+
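For orientation, with deep supervision the loss weights in `from_config` are replicated once per intermediate decoder layer. A sketch with hypothetical weight values and `DEC_LAYERS = 3`:

base = {"loss_ce": 1, "loss_mask": 20.0, "loss_dice": 1.0}  # example values, not the configured ones
weight_dict = dict(base)
for i in range(3 - 1):  # dec_layers - 1 intermediate outputs
    weight_dict.update({k + f"_{i}": v for k, v in base.items()})
assert sorted(weight_dict) == sorted(
    ["loss_ce", "loss_mask", "loss_dice",
     "loss_ce_0", "loss_mask_0", "loss_dice_0",
     "loss_ce_1", "loss_mask_1", "loss_dice_1"])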
+    @property
+    def device(self):
+        return self.pixel_mean.device
+
+    def forward(self, batched_inputs):
+        """
+        Args:
+            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
+                Each item in the list contains the inputs for one image.
+                For now, each item in the list is a dict that contains:
+                   * "image": Tensor, image in (C, H, W) format.
+                   * "instances": per-region ground truth
+                   * Other information that's included in the original dicts, such as:
+                     "height", "width" (int): the output resolution of the model (may be different
+                     from input resolution), used in inference.
+        Returns:
+            list[dict]:
+                each dict has the results for one image. The dict contains the following keys:
+
+                * "sem_seg":
+                    A Tensor that represents the
+                    per-pixel segmentation predicted by the head.
+                    The prediction has shape KxHxW that represents the logits of
+                    each class for each pixel.
+                * "panoptic_seg":
+                    A tuple that represents the panoptic output
+                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
+                    segments_info (list[dict]): Describes each segment in `panoptic_seg`.
+                        Each dict contains keys "id", "category_id", "isthing".
+        """
+        return self.forward_base(batched_inputs, keys=["image"], get_train=not self.training,
+                                 get_eval=not self.training)
+
+    def forward_base(self, batched_inputs, keys, get_train=False, get_eval=False, raw_sem_seg=False):
+        for i, key in enumerate(keys):
+            images = [x[key].to(self.device) for x in batched_inputs]
+            images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+            images = ImageList.from_tensors(images, self.size_divisibility)
+            logger.debug_once(f"Maskformer input {key} shape: {images.tensor.shape}")
+            out = self.backbone(images.tensor)
+            if i == 0:
+                features = out
+            else:
+                features = {k: torch.cat([features[k], v], 1) for k, v in out.items()}
+        outputs = self.sem_seg_head(features)
+
+        if get_train:
+            # mask classification target
+            if "instances" in batched_inputs[0]:
+                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
+                targets = self.prepare_targets(gt_instances, images)
+            else:
+                targets = None
+
+            # bipartite matching-based loss
+            losses = self.criterion(outputs, targets)
+
+            for k in list(losses.keys()):
+                if k in self.criterion.weight_dict:
+                    losses[k] *= self.criterion.weight_dict[k]
+                else:
+                    # remove this loss if not specified in `weight_dict`
+                    losses.pop(k)
+            if not get_eval:
+                return losses
+
+        if get_eval:
+            # mask_cls_results = outputs["pred_logits"]
+            mask_pred_results = outputs["pred_masks"]
+            mask_cls_results = mask_pred_results
+            logger.debug_once(f"Maskformer mask_pred_results shape: {mask_pred_results.shape}")
+            # upsample masks
+            # mask_pred_results = interpolate_or_crop(
+            #     mask_pred_results,
+            #     size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+            #     mode="bilinear",
+            #     align_corners=False,
+            # )
+
+            processed_results = []
+            for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
+                    mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
+            ):
+
+                if raw_sem_seg:
+                    processed_results.append({"sem_seg": mask_pred_result})
+                    continue
+
+                height = input_per_image.get("height", image_size[0])
+                width = input_per_image.get("width", image_size[1])
+                logger.debug_once(f"Maskformer mask_pred_results target HW: {height, width}")
+                r = interpolate_or_crop(mask_pred_result[None], size=(height, width), mode="bilinear", align_corners=False)[0]
+
+                processed_results.append({"sem_seg": r})
+
+                # panoptic segmentation inference
+                # if self.panoptic_on:
+                #     panoptic_r = self.panoptic_inference(mask_cls_result, mask_pred_result)
+                #     processed_results[-1]["panoptic_seg"] = panoptic_r
+
+            # if 'features' in outputs:
+            #     features = outputs['features']
+            #     features = interpolate_or_crop(
+            #         features,
+            #         size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+            #         mode="bilinear",
+            #         align_corners=False,
+            #     )
+            #     for res, f in zip(processed_results, features):
+            #         res['features'] = f
+            del outputs
+
+            if not get_train:
+                return processed_results
+
+            return losses, processed_results
+
+    def prepare_targets(self, targets, images):
+        h, w = images.tensor.shape[-2:]
+        new_targets = []
+        for targets_per_image in targets:
+            # pad gt
+            gt_masks = targets_per_image.gt_masks
+            padded_masks = torch.zeros((gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device)
+            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
+            new_targets.append(
+                {
+                    "labels": targets_per_image.gt_classes,
+                    "masks": padded_masks,
+                }
+            )
+        return new_targets
+
+    def semantic_inference(self, mask_cls, mask_pred):
+        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
+        mask_pred = mask_pred.sigmoid()
+        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
+        return semseg
+
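A worked shape example for the einsum in `semantic_inference` above (hypothetical sizes): Q query masks, weighted by their per-class probabilities, collapse into per-class score maps.

import torch

Q, C, H, W = 100, 171, 32, 32
mask_cls = torch.softmax(torch.randn(Q, C + 1), dim=-1)[..., :-1]  # drop the "no object" slot
mask_pred = torch.randn(Q, H, W).sigmoid()
semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
assert semseg.shape == (C, H, W)  # per-pixel score for each class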
+    def panoptic_inference(self, mask_cls, mask_pred):
+        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
+        mask_pred = mask_pred.sigmoid()
+
+        keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
+        cur_scores = scores[keep]
+        cur_classes = labels[keep]
+        cur_masks = mask_pred[keep]
+        cur_mask_cls = mask_cls[keep]
+        cur_mask_cls = cur_mask_cls[:, :-1]
+
+        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
+
+        h, w = cur_masks.shape[-2:]
+        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
+        segments_info = []
+
+        current_segment_id = 0
+
+        if cur_masks.shape[0] == 0:
+            # We didn't detect any mask :(
+            return panoptic_seg, segments_info
+        else:
+            # take argmax
+            cur_mask_ids = cur_prob_masks.argmax(0)
+            stuff_memory_list = {}
+            for k in range(cur_classes.shape[0]):
+                pred_class = cur_classes[k].item()
+                isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
+                mask = cur_mask_ids == k
+                mask_area = mask.sum().item()
+                original_area = (cur_masks[k] >= 0.5).sum().item()
+
+                if mask_area > 0 and original_area > 0:
+                    if mask_area / original_area < self.overlap_threshold:
+                        continue
+
+                    # merge stuff regions
+                    if not isthing:
+                        if int(pred_class) in stuff_memory_list.keys():
+                            panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
+                            continue
+                        else:
+                            stuff_memory_list[int(pred_class)] = current_segment_id + 1
+
+                    current_segment_id += 1
+                    panoptic_seg[mask] = current_segment_id
+
+                    segments_info.append(
+                        {
+                            "id": current_segment_id,
+                            "isthing": bool(isthing),
+                            "category_id": int(pred_class),
+                        }
+                    )
+
+            return panoptic_seg, segments_info
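A toy illustration of the overlap filter in `panoptic_inference` (hypothetical numbers): a query survives only if the region it wins in the argmax over queries keeps at least `overlap_threshold` of its original binarized (>= 0.5) mask area.

overlap_threshold = 0.8  # example value
original_area = 3  # pixels with mask probability >= 0.5
mask_area = 2      # pixels actually won in the argmax
assert mask_area / original_area < overlap_threshold  # 0.67 < 0.8 -> segment dropped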
mask_former/modeling/__init__.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .backbone.swin import D2SwinTransformer
+from .backbone.vit import D2ViTTransformer
+from .heads.mask_former_head import MaskFormerHead
+from .heads.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead
+from .heads.pixel_decoder import BasePixelDecoder
+from .heads.big_pixel_decoder import BigPixelDecoder
+from .heads.mega_big_pixel_decoder import MegaBigPixelDecoder
+from .heads.mask_former_head_baseline import MaskFormerBaselineHead
mask_former/modeling/backbone/__init__.py
ADDED
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
mask_former/modeling/backbone/swin.py
ADDED
@@ -0,0 +1,772 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu, Yutong Lin, Yixuan Wei
+# --------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
+import logging
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+logger = logging.getLogger('gwm')
+
+class Mlp(nn.Module):
+    """Multilayer perceptron."""
+
+    def __init__(
+        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
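A quick round-trip sketch (hypothetical shapes): `window_partition` and `window_reverse` are exact inverses whenever H and W are multiples of the window size.

import torch

B, H, W, C, ws = 2, 14, 14, 96, 7
x = torch.randn(B, H, W, C)
windows = window_partition(x, ws)  # (8, 7, 7, 96): four 7x7 windows per image
assert windows.shape == (B * (H // ws) * (W // ws), ws, ws, C)
assert torch.equal(window_reverse(windows, ws, H, W), x)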
+class WindowAttention(nn.Module):
+    """Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(
+        self,
+        dim,
+        window_size,
+        num_heads,
+        qkv_bias=True,
+        qk_scale=None,
+        attn_drop=0.0,
+        proj_drop=0.0,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+        )  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=0.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """Forward function.
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
+
+        relative_position_bias = self.relative_position_bias_table[
+            self.relative_position_index.view(-1)
+        ].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+        )  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(
+            2, 0, 1
+        ).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
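To make the bias indexing above concrete: for a 7x7 window there are (2*7-1)^2 = 169 distinct relative offsets, and `relative_position_index` maps each of the 49x49 token pairs onto one of them. A small check that mirrors the buffer computation (illustration only):

import torch

Wh = Ww = 7
coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)], indexing='ij'))
flat = torch.flatten(coords, 1)
rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()
rel[:, :, 0] += Wh - 1
rel[:, :, 1] += Ww - 1
rel[:, :, 0] *= 2 * Ww - 1
index = rel.sum(-1)  # (49, 49)
assert index.min().item() == 0
assert index.max().item() == (2 * Wh - 1) * (2 * Ww - 1) - 1  # 168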
+class SwinTransformerBlock(nn.Module):
+    """Swin Transformer Block.
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        window_size=7,
+        shift_size=0,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim,
+            window_size=to_2tuple(self.window_size),
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
+        )
+
+        self.H = None
+        self.W = None
+
+    def forward(self, x, mask_matrix):
+        """Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+            mask_matrix: Attention mask for cyclic shift.
+        """
+        B, L, C = x.shape
+        H, W = self.H, self.W
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(
+            shifted_x, self.window_size
+        )  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(
+            -1, self.window_size * self.window_size, C
+        )  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :].contiguous()
+
+        x = x.view(B, H * W, C)
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x
+
+
|
300 |
+
"""Patch Merging Layer
|
301 |
+
Args:
|
302 |
+
dim (int): Number of input channels.
|
303 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
304 |
+
"""
|
305 |
+
|
306 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
307 |
+
super().__init__()
|
308 |
+
self.dim = dim
|
309 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
310 |
+
self.norm = norm_layer(4 * dim)
|
311 |
+
|
312 |
+
def forward(self, x, H, W):
|
313 |
+
"""Forward function.
|
314 |
+
Args:
|
315 |
+
x: Input feature, tensor size (B, H*W, C).
|
316 |
+
H, W: Spatial resolution of the input feature.
|
317 |
+
"""
|
318 |
+
B, L, C = x.shape
|
319 |
+
assert L == H * W, "input feature has wrong size"
|
320 |
+
|
321 |
+
x = x.view(B, H, W, C)
|
322 |
+
|
323 |
+
# padding
|
324 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
325 |
+
if pad_input:
|
326 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
327 |
+
|
328 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
329 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
330 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
331 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
332 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
333 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
334 |
+
|
335 |
+
x = self.norm(x)
|
336 |
+
x = self.reduction(x)
|
337 |
+
|
338 |
+
return x
|
339 |
+
|
340 |
+
|
341 |
+
class BasicLayer(nn.Module):
|
342 |
+
"""A basic Swin Transformer layer for one stage.
|
343 |
+
Args:
|
344 |
+
dim (int): Number of feature channels
|
345 |
+
depth (int): Depths of this stage.
|
346 |
+
num_heads (int): Number of attention head.
|
347 |
+
window_size (int): Local window size. Default: 7.
|
348 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
349 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
350 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
351 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
352 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
353 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
354 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
355 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
356 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
357 |
+
"""
|
358 |
+
|
359 |
+
def __init__(
|
360 |
+
self,
|
361 |
+
dim,
|
362 |
+
depth,
|
363 |
+
num_heads,
|
364 |
+
window_size=7,
|
365 |
+
mlp_ratio=4.0,
|
366 |
+
qkv_bias=True,
|
367 |
+
qk_scale=None,
|
368 |
+
drop=0.0,
|
369 |
+
attn_drop=0.0,
|
370 |
+
drop_path=0.0,
|
371 |
+
norm_layer=nn.LayerNorm,
|
372 |
+
downsample=None,
|
373 |
+
use_checkpoint=False,
|
374 |
+
):
|
375 |
+
super().__init__()
|
376 |
+
self.window_size = window_size
|
377 |
+
self.shift_size = window_size // 2
|
378 |
+
self.depth = depth
|
379 |
+
self.use_checkpoint = use_checkpoint
|
380 |
+
|
381 |
+
# build blocks
|
382 |
+
self.blocks = nn.ModuleList(
|
383 |
+
[
|
384 |
+
SwinTransformerBlock(
|
385 |
+
dim=dim,
|
386 |
+
num_heads=num_heads,
|
387 |
+
window_size=window_size,
|
388 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
389 |
+
mlp_ratio=mlp_ratio,
|
390 |
+
qkv_bias=qkv_bias,
|
391 |
+
qk_scale=qk_scale,
|
392 |
+
drop=drop,
|
393 |
+
attn_drop=attn_drop,
|
394 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
395 |
+
norm_layer=norm_layer,
|
396 |
+
)
|
397 |
+
for i in range(depth)
|
398 |
+
]
|
399 |
+
)
|
400 |
+
|
401 |
+
# patch merging layer
|
402 |
+
if downsample is not None:
|
403 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
404 |
+
else:
|
405 |
+
self.downsample = None
|
406 |
+
|
407 |
+
def forward(self, x, H, W):
|
408 |
+
"""Forward function.
|
409 |
+
Args:
|
410 |
+
x: Input feature, tensor size (B, H*W, C).
|
411 |
+
H, W: Spatial resolution of the input feature.
|
412 |
+
"""
|
413 |
+
|
414 |
+
# calculate attention mask for SW-MSA
|
415 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
416 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
417 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
418 |
+
h_slices = (
|
419 |
+
slice(0, -self.window_size),
|
420 |
+
slice(-self.window_size, -self.shift_size),
|
421 |
+
slice(-self.shift_size, None),
|
422 |
+
)
|
423 |
+
w_slices = (
|
424 |
+
slice(0, -self.window_size),
|
425 |
+
slice(-self.window_size, -self.shift_size),
|
426 |
+
slice(-self.shift_size, None),
|
427 |
+
)
|
428 |
+
cnt = 0
|
429 |
+
for h in h_slices:
|
430 |
+
for w in w_slices:
|
431 |
+
img_mask[:, h, w, :] = cnt
|
432 |
+
cnt += 1
|
433 |
+
|
434 |
+
mask_windows = window_partition(
|
435 |
+
img_mask, self.window_size
|
436 |
+
) # nW, window_size, window_size, 1
|
437 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
438 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
439 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
|
440 |
+
attn_mask == 0, float(0.0)
|
441 |
+
)
|
442 |
+
|
443 |
+
for blk in self.blocks:
|
444 |
+
blk.H, blk.W = H, W
|
445 |
+
if self.use_checkpoint:
|
446 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
447 |
+
else:
|
448 |
+
x = blk(x, attn_mask)
|
449 |
+
if self.downsample is not None:
|
450 |
+
x_down = self.downsample(x, H, W)
|
451 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
452 |
+
return x, H, W, x_down, Wh, Ww
|
453 |
+
else:
|
454 |
+
return x, H, W, x, H, W
|
455 |
+
|
456 |
+
|
457 |
+
class PatchEmbed(nn.Module):
|
458 |
+
"""Image to Patch Embedding
|
459 |
+
Args:
|
460 |
+
patch_size (int): Patch token size. Default: 4.
|
461 |
+
in_chans (int): Number of input image channels. Default: 3.
|
462 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
463 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
464 |
+
"""
|
465 |
+
|
466 |
+
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
467 |
+
super().__init__()
|
468 |
+
patch_size = to_2tuple(patch_size)
|
469 |
+
self.patch_size = patch_size
|
470 |
+
|
471 |
+
self.in_chans = in_chans
|
472 |
+
self.embed_dim = embed_dim
|
473 |
+
|
474 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
475 |
+
if norm_layer is not None:
|
476 |
+
self.norm = norm_layer(embed_dim)
|
477 |
+
else:
|
478 |
+
self.norm = None
|
479 |
+
|
480 |
+
def forward(self, x):
|
481 |
+
"""Forward function."""
|
482 |
+
# padding
|
483 |
+
_, _, H, W = x.size()
|
484 |
+
if W % self.patch_size[1] != 0:
|
485 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
486 |
+
if H % self.patch_size[0] != 0:
|
487 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
488 |
+
|
489 |
+
x = self.proj(x) # B C Wh Ww
|
490 |
+
if self.norm is not None:
|
491 |
+
Wh, Ww = x.size(2), x.size(3)
|
492 |
+
x = x.flatten(2).transpose(1, 2)
|
493 |
+
x = self.norm(x)
|
494 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
495 |
+
|
496 |
+
return x
|
497 |
+
|
498 |
+
|
499 |
+
class SwinTransformer(nn.Module):
|
500 |
+
"""Swin Transformer backbone.
|
501 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
502 |
+
https://arxiv.org/pdf/2103.14030
|
503 |
+
Args:
|
504 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
505 |
+
used in absolute postion embedding. Default 224.
|
506 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
507 |
+
in_chans (int): Number of input image channels. Default: 3.
|
508 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
509 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
510 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
511 |
+
window_size (int): Window size. Default: 7.
|
512 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
513 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
514 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
515 |
+
drop_rate (float): Dropout rate.
|
516 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
517 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
518 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
519 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
520 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
521 |
+
out_indices (Sequence[int]): Output from which stages.
|
522 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
523 |
+
-1 means not freezing any parameters.
|
524 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
525 |
+
"""
|
526 |
+
|
527 |
+
def __init__(
|
528 |
+
self,
|
529 |
+
pretrain_img_size=224,
|
530 |
+
patch_size=4,
|
531 |
+
in_chans=3,
|
532 |
+
embed_dim=96,
|
533 |
+
depths=[2, 2, 6, 2],
|
534 |
+
num_heads=[3, 6, 12, 24],
|
535 |
+
window_size=7,
|
536 |
+
mlp_ratio=4.0,
|
537 |
+
qkv_bias=True,
|
538 |
+
qk_scale=None,
|
539 |
+
drop_rate=0.0,
|
540 |
+
attn_drop_rate=0.0,
|
541 |
+
drop_path_rate=0.2,
|
542 |
+
norm_layer=nn.LayerNorm,
|
543 |
+
ape=False,
|
544 |
+
patch_norm=True,
|
545 |
+
out_indices=(0, 1, 2, 3),
|
546 |
+
frozen_stages=-1,
|
547 |
+
use_checkpoint=False,
|
548 |
+
):
|
549 |
+
super().__init__()
|
550 |
+
|
551 |
+
self.pretrain_img_size = pretrain_img_size
|
552 |
+
self.num_layers = len(depths)
|
553 |
+
self.embed_dim = embed_dim
|
554 |
+
self.ape = ape
|
555 |
+
self.patch_norm = patch_norm
|
556 |
+
self.out_indices = out_indices
|
557 |
+
self.frozen_stages = frozen_stages
|
558 |
+
|
559 |
+
# split image into non-overlapping patches
|
560 |
+
self.patch_embed = PatchEmbed(
|
561 |
+
patch_size=patch_size,
|
562 |
+
in_chans=in_chans,
|
563 |
+
embed_dim=embed_dim,
|
564 |
+
norm_layer=norm_layer if self.patch_norm else None,
|
565 |
+
)
|
566 |
+
|
567 |
+
# absolute position embedding
|
568 |
+
if self.ape:
|
569 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
570 |
+
patch_size = to_2tuple(patch_size)
|
571 |
+
patches_resolution = [
|
572 |
+
pretrain_img_size[0] // patch_size[0],
|
573 |
+
pretrain_img_size[1] // patch_size[1],
|
574 |
+
]
|
575 |
+
|
576 |
+
self.absolute_pos_embed = nn.Parameter(
|
577 |
+
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
|
578 |
+
)
|
579 |
+
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
580 |
+
|
581 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
582 |
+
|
583 |
+
# stochastic depth
|
584 |
+
dpr = [
|
585 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
586 |
+
] # stochastic depth decay rule
|
587 |
+
|
588 |
+
# build layers
|
589 |
+
self.layers = nn.ModuleList()
|
590 |
+
for i_layer in range(self.num_layers):
|
591 |
+
layer = BasicLayer(
|
592 |
+
dim=int(embed_dim * 2 ** i_layer),
|
593 |
+
depth=depths[i_layer],
|
594 |
+
num_heads=num_heads[i_layer],
|
595 |
+
window_size=window_size,
|
596 |
+
mlp_ratio=mlp_ratio,
|
597 |
+
qkv_bias=qkv_bias,
|
598 |
+
qk_scale=qk_scale,
|
599 |
+
drop=drop_rate,
|
600 |
+
attn_drop=attn_drop_rate,
|
601 |
+
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
602 |
+
norm_layer=norm_layer,
|
603 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
604 |
+
use_checkpoint=use_checkpoint,
|
605 |
+
)
|
606 |
+
self.layers.append(layer)
|
607 |
+
|
608 |
+
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
609 |
+
self.num_features = num_features
|
610 |
+
|
611 |
+
# add a norm layer for each output
|
612 |
+
for i_layer in out_indices:
|
613 |
+
layer = norm_layer(num_features[i_layer])
|
614 |
+
layer_name = f"norm{i_layer}"
|
615 |
+
self.add_module(layer_name, layer)
|
616 |
+
|
617 |
+
self._freeze_stages()
|
618 |
+
logger.info(f"Freezing {self.frozen_stages} Layers")
|
619 |
+
|
620 |
+
def _freeze_stages(self):
|
621 |
+
if self.frozen_stages >= 0:
|
622 |
+
self.patch_embed.eval()
|
623 |
+
for param in self.patch_embed.parameters():
|
624 |
+
param.requires_grad = False
|
625 |
+
|
626 |
+
if self.frozen_stages >= 1 and self.ape:
|
627 |
+
self.absolute_pos_embed.requires_grad = False
|
628 |
+
|
629 |
+
if self.frozen_stages >= 2:
|
630 |
+
self.pos_drop.eval()
|
631 |
+
for i in range(0, self.frozen_stages - 1):
|
632 |
+
m = self.layers[i]
|
633 |
+
m.eval()
|
634 |
+
for param in m.parameters():
|
635 |
+
param.requires_grad = False
|
636 |
+
|
637 |
+
def init_weights(self, pretrained=None):
|
638 |
+
"""Initialize the weights in backbone.
|
639 |
+
Args:
|
640 |
+
pretrained (str, optional): Path to pre-trained weights.
|
641 |
+
Defaults to None.
|
642 |
+
"""
|
643 |
+
|
644 |
+
def _init_weights(m):
|
645 |
+
if isinstance(m, nn.Linear):
|
646 |
+
trunc_normal_(m.weight, std=0.02)
|
647 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
648 |
+
nn.init.constant_(m.bias, 0)
|
649 |
+
elif isinstance(m, nn.LayerNorm):
|
650 |
+
nn.init.constant_(m.bias, 0)
|
651 |
+
nn.init.constant_(m.weight, 1.0)
|
652 |
+
|
653 |
+
def forward(self, x):
|
654 |
+
"""Forward function."""
|
655 |
+
x = self.patch_embed(x)
|
656 |
+
|
657 |
+
Wh, Ww = x.size(2), x.size(3)
|
658 |
+
if self.ape:
|
659 |
+
# interpolate the position embedding to the corresponding size
|
660 |
+
absolute_pos_embed = F.interpolate(
|
661 |
+
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
|
662 |
+
)
|
663 |
+
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
664 |
+
else:
|
665 |
+
x = x.flatten(2).transpose(1, 2)
|
666 |
+
x = self.pos_drop(x)
|
667 |
+
|
668 |
+
outs = {}
|
669 |
+
for i in range(self.num_layers):
|
670 |
+
layer = self.layers[i]
|
671 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
672 |
+
|
673 |
+
if i in self.out_indices:
|
674 |
+
norm_layer = getattr(self, f"norm{i}")
|
675 |
+
x_out = norm_layer(x_out)
|
676 |
+
|
677 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
678 |
+
outs["res{}".format(i + 2)] = out
|
679 |
+
|
680 |
+
return outs
|
681 |
+
|
682 |
+
def train(self, mode=True):
|
683 |
+
"""Convert the model into training mode while keep layers freezed."""
|
684 |
+
super(SwinTransformer, self).train(mode)
|
685 |
+
self._freeze_stages()
|
686 |
+
|
687 |
+
|
688 |
+
@BACKBONE_REGISTRY.register()
|
689 |
+
class D2SwinTransformer(SwinTransformer, Backbone):
|
690 |
+
def __init__(self, cfg, input_shape):
|
691 |
+
|
692 |
+
pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
|
693 |
+
patch_size = cfg.MODEL.SWIN.PATCH_SIZE
|
694 |
+
in_chans = 3
|
695 |
+
embed_dim = cfg.MODEL.SWIN.EMBED_DIM
|
696 |
+
depths = cfg.MODEL.SWIN.DEPTHS
|
697 |
+
num_heads = cfg.MODEL.SWIN.NUM_HEADS
|
698 |
+
window_size = cfg.MODEL.SWIN.WINDOW_SIZE
|
699 |
+
mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
|
700 |
+
qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
|
701 |
+
qk_scale = cfg.MODEL.SWIN.QK_SCALE
|
702 |
+
drop_rate = cfg.MODEL.SWIN.DROP_RATE
|
703 |
+
attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
|
704 |
+
drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
|
705 |
+
norm_layer = nn.LayerNorm
|
706 |
+
ape = cfg.MODEL.SWIN.APE
|
707 |
+
patch_norm = cfg.MODEL.SWIN.PATCH_NORM
|
708 |
+
frozen_stages = cfg.MODEL.BACKBONE.FREEZE_AT
|
709 |
+
|
710 |
+
super().__init__(
|
711 |
+
pretrain_img_size,
|
712 |
+
patch_size,
|
713 |
+
in_chans,
|
714 |
+
embed_dim,
|
715 |
+
depths,
|
716 |
+
num_heads,
|
717 |
+
window_size,
|
718 |
+
mlp_ratio,
|
719 |
+
qkv_bias,
|
720 |
+
qk_scale,
|
721 |
+
drop_rate,
|
722 |
+
attn_drop_rate,
|
723 |
+
drop_path_rate,
|
724 |
+
norm_layer,
|
725 |
+
ape,
|
726 |
+
patch_norm,
|
727 |
+
frozen_stages=frozen_stages,
|
728 |
+
)
|
729 |
+
|
730 |
+
self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
|
731 |
+
|
732 |
+
self._out_feature_strides = {
|
733 |
+
"res2": 4,
|
734 |
+
"res3": 8,
|
735 |
+
"res4": 16,
|
736 |
+
"res5": 32,
|
737 |
+
}
|
738 |
+
self._out_feature_channels = {
|
739 |
+
"res2": self.num_features[0],
|
740 |
+
"res3": self.num_features[1],
|
741 |
+
"res4": self.num_features[2],
|
742 |
+
"res5": self.num_features[3],
|
743 |
+
}
|
744 |
+
|
745 |
+
def forward(self, x):
|
746 |
+
"""
|
747 |
+
Args:
|
748 |
+
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
|
749 |
+
Returns:
|
750 |
+
dict[str->Tensor]: names and the corresponding features
|
751 |
+
"""
|
752 |
+
assert (
|
753 |
+
x.dim() == 4
|
754 |
+
), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
|
755 |
+
outputs = {}
|
756 |
+
y = super().forward(x)
|
757 |
+
for k in y.keys():
|
758 |
+
if k in self._out_features:
|
759 |
+
outputs[k] = y[k]
|
760 |
+
return outputs
|
761 |
+
|
762 |
+
def output_shape(self):
|
763 |
+
return {
|
764 |
+
name: ShapeSpec(
|
765 |
+
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
|
766 |
+
)
|
767 |
+
for name in self._out_features
|
768 |
+
}
|
769 |
+
|
770 |
+
@property
|
771 |
+
def size_divisibility(self):
|
772 |
+
return 32
|
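Aside (not part of the commit): the shifted-window attention mask built in `BasicLayer.forward` above is easier to see in isolation. The sketch below re-derives it for a padded 14x14 feature map with `window_size=7` and `shift_size=3`, using a local copy of the standard Swin `window_partition` helper (the repo defines the same helper earlier in swin.py, outside this excerpt); the sizes are illustrative.

import torch

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows*B, window_size, window_size, C), as in Swin
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

Hp = Wp = 14
window_size, shift_size = 7, 3
img_mask = torch.zeros((1, Hp, Wp, 1))  # label each of the 9 shift regions with its own id
region_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in region_slices:
    for w in region_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)                              # torch.Size([4, 49, 49])
print((attn_mask == -100.0).flatten(1).any(dim=1))  # tensor([False,  True,  True,  True])

Only the top-left window contains a single region, so only the other three windows receive -100 entries; adding this mask before the softmax stops tokens that were made adjacent by `torch.roll` from attending to each other.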
mask_former/modeling/backbone/vit.py
ADDED
@@ -0,0 +1,441 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu, Yutong Lin, Yixuan Wei
+# --------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
+import logging
+import math
+import types
+from pathlib import Path
+from typing import Union, List, Tuple
+
+import einops
+import numpy as np
+import timm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.modules.utils as nn_utils
+import torch.utils.checkpoint as checkpoint
+from PIL import Image
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from torchvision import transforms
+
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+
+logger = logging.getLogger('gwm')
+
+
+class ViTExtractor:
+    """This class facilitates extraction of features, descriptors, and saliency maps from a ViT.
+
+    We use the following notation in the documentation of the module's methods:
+    B - batch size
+    h - number of heads. usually takes place of the channel dimension in pytorch's convention BxCxHxW
+    p - patch size of the ViT. either 8 or 16.
+    t - number of tokens. equals the number of patches + 1, e.g. HW / p**2 + 1. Where H and W are the height and width
+    of the input image.
+    d - the embedding dimension in the ViT.
+    """
+
+    def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Module = None, device: str = 'cuda'):
+        """
+        :param model_type: A string specifying the type of model to extract from.
+                           [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 |
+                           vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]
+        :param stride: stride of the first convolution layer. A small stride -> higher resolution.
+        :param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor.
+                      Should be compatible with model_type.
+        """
+        self.model_type = model_type
+        self.device = device
+        if model is not None:
+            self.model = model
+        else:
+            self.model = ViTExtractor.create_model(model_type)
+
+        self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride)
+        # self.model.eval()
+        self.model.to(self.device)
+        self.p = self.model.patch_embed.patch_size
+        self.stride = self.model.patch_embed.proj.stride
+
+        self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5)
+        self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5)
+
+        self._feats = []
+        self.hook_handlers = []
+        self.load_size = None
+        self.num_patches = None
+
+    @staticmethod
+    def create_model(model_type: str) -> nn.Module:
+        """
+        :param model_type: a string specifying which model to load. [dino_vits8 | dino_vits16 | dino_vitb8 |
+                           dino_vitb16 | vit_small_patch8_224 | vit_small_patch16_224 | vit_base_patch8_224 |
+                           vit_base_patch16_224]
+        :return: the model
+        """
+        if 'dino' in model_type:
+            model = torch.hub.load('facebookresearch/dino:main', model_type)
+        else:  # model from timm -- load weights from timm into the dino model (enables working on arbitrary size images).
+            temp_model = timm.create_model(model_type, pretrained=True)
+            model_type_dict = {
+                'vit_small_patch16_224': 'dino_vits16',
+                'vit_small_patch8_224': 'dino_vits8',
+                'vit_base_patch16_224': 'dino_vitb16',
+                'vit_base_patch8_224': 'dino_vitb8'
+            }
+            model = torch.hub.load('facebookresearch/dino:main', model_type_dict[model_type])
+            temp_state_dict = temp_model.state_dict()
+            del temp_state_dict['head.weight']
+            del temp_state_dict['head.bias']
+            model.load_state_dict(temp_state_dict)
+        return model
+
+    @staticmethod
+    def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]):
+        """
+        Creates a method for position encoding interpolation.
+        :param patch_size: patch size of the model.
+        :param stride_hw: A tuple containing the new height and width stride respectively.
+        :return: the interpolation method
+        """
+        def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
+            npatch = x.shape[1] - 1
+            N = self.pos_embed.shape[1] - 1
+            if npatch == N and w == h:
+                return self.pos_embed
+            class_pos_embed = self.pos_embed[:, 0]
+            patch_pos_embed = self.pos_embed[:, 1:]
+            dim = x.shape[-1]
+            # compute number of tokens taking stride into account
+            w0 = 1 + (w - patch_size) // stride_hw[1]
+            h0 = 1 + (h - patch_size) // stride_hw[0]
+            assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and
+                                            stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}"""
+            # we add a small number to avoid floating point error in the interpolation
+            # see discussion at https://github.com/facebookresearch/dino/issues/8
+            w0, h0 = w0 + 0.1, h0 + 0.1
+            patch_pos_embed = nn.functional.interpolate(
+                patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+                scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+                mode='bicubic',
+                align_corners=False, recompute_scale_factor=False
+            )
+            assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+            patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return interpolate_pos_encoding
+
+    @staticmethod
+    def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module:
+        """
+        change resolution of model output by changing the stride of the patch extraction.
+        :param model: the model to change resolution for.
+        :param stride: the new stride parameter.
+        :return: the adjusted model
+        """
+        patch_size = model.patch_embed.patch_size
+        if stride == patch_size:  # nothing to do
+            return model
+
+        stride = nn_utils._pair(stride)
+        assert all([(patch_size // s_) * s_ == patch_size for s_ in
+                    stride]), f'stride {stride} should divide patch_size {patch_size}'
+
+        # fix the stride
+        model.patch_embed.proj.stride = stride
+        # fix the positional encoding code
+        model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model)
+        return model
+
+    def preprocess(self, image_path: Union[str, Path],
+                   load_size: Union[int, Tuple[int, int]] = None) -> Tuple[torch.Tensor, Image.Image]:
+        """
+        Preprocesses an image before extraction.
+        :param image_path: path to image to be extracted.
+        :param load_size: optional. Size to resize image before the rest of preprocessing.
+        :return: a tuple containing:
+                    (1) the preprocessed image as a tensor to insert into the model, of shape BxCxHxW.
+                    (2) the pil image in relevant dimensions
+        """
+        pil_image = Image.open(image_path).convert('RGB')
+        if load_size is not None:
+            pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image)
+        prep = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=self.mean, std=self.std)
+        ])
+        prep_img = prep(pil_image)[None, ...]
+        return prep_img, pil_image
+
+    def _get_hook(self, facet: str):
+        """
+        generate a hook method for a specific block and facet.
+        """
+        if facet in ['attn', 'token']:
+            def _hook(model, input, output):
+                self._feats.append(output)
+            return _hook
+
+        if facet == 'query':
+            facet_idx = 0
+        elif facet == 'key':
+            facet_idx = 1
+        elif facet == 'value':
+            facet_idx = 2
+        else:
+            raise TypeError(f"{facet} is not a supported facet.")
+
+        def _inner_hook(module, input, output):
+            input = input[0]
+            B, N, C = input.shape
+            qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
+            self._feats.append(qkv[facet_idx])  # Bxhxtxd
+        return _inner_hook
+
+    def _register_hooks(self, layers: List[int], facet: str) -> None:
+        """
+        register hooks to extract features.
+        :param layers: layers from which to extract features.
+        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
+        """
+        for block_idx, block in enumerate(self.model.blocks):
+            if block_idx in layers:
+                if facet == 'token':
+                    self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet)))
+                elif facet == 'attn':
+                    self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet)))
+                elif facet in ['key', 'query', 'value']:
+                    self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet)))
+                else:
+                    raise TypeError(f"{facet} is not a supported facet.")
+
+    def _unregister_hooks(self) -> None:
+        """
+        unregisters the hooks. should be called after feature extraction.
+        """
+        for handle in self.hook_handlers:
+            handle.remove()
+        self.hook_handlers = []
+
+    def _extract_features(self, batch: torch.Tensor, layers: List[int] = (11,), facet: str = 'key') -> List[torch.Tensor]:
+        """
+        extract features from the model
+        :param batch: batch to extract features for. Has shape BxCxHxW.
+        :param layers: layers to extract from. Numbers between 0 and 11.
+        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
+        :return: list of feature tensors, one per requested layer.
+                 if facet is 'key' | 'query' | 'value' each has shape Bxhxtxd
+                 if facet is 'attn' each has shape Bxhxtxt
+                 if facet is 'token' each has shape Bxtxd
+        """
+        B, C, H, W = batch.shape
+        self._feats = []
+        self._register_hooks(layers, facet)
+        with torch.no_grad():
+            _ = self.model(batch)
+        self._unregister_hooks()
+        self.load_size = (H, W)
+        self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1])
+        return self._feats
+
+    def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor:
+        """
+        create a log-binned descriptor.
+        :param x: tensor of features. Has shape Bxhxtxd.
+        :param hierarchy: how many bin hierarchies to use.
+        """
+        B = x.shape[0]
+        num_bins = 1 + 8 * hierarchy
+
+        bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1)  # Bx(t-1)x(dxh)
+        bin_x = bin_x.permute(0, 2, 1)
+        bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1])
+        # Bx(dxh)xnum_patches[0]xnum_patches[1]
+        sub_desc_dim = bin_x.shape[1]
+
+        avg_pools = []
+        # compute bins of all sizes for all spatial locations.
+        for k in range(0, hierarchy):
+            # avg pooling with kernel 3**k x 3**k
+            win_size = 3 ** k
+            avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False)
+            avg_pools.append(avg_pool(bin_x))
+
+        bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device)
+        for y in range(self.num_patches[0]):
+            for x in range(self.num_patches[1]):
+                part_idx = 0
+                # fill all bins for a spatial location (y, x)
+                for k in range(0, hierarchy):
+                    kernel_size = 3 ** k
+                    for i in range(y - kernel_size, y + kernel_size + 1, kernel_size):
+                        for j in range(x - kernel_size, x + kernel_size + 1, kernel_size):
+                            if i == y and j == x and k != 0:
+                                continue
+                            if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]:
+                                bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
+                                    :, :, i, j]
+                            else:  # handle padding in a more delicate way than zero padding
+                                temp_i = max(0, min(i, self.num_patches[0] - 1))
+                                temp_j = max(0, min(j, self.num_patches[1] - 1))
+                                bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][
+                                    :, :, temp_i, temp_j]
+                            part_idx += 1
+        bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1)
+        # Bx1x(t-1)x(dxh)
+        return bin_x
+
+    def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key',
+                            bin: bool = False, include_cls: bool = False) -> torch.Tensor:
+        """
+        extract descriptors from the model
+        :param batch: batch to extract descriptors for. Has shape BxCxHxW.
+        :param layer: layer to extract from. A number between 0 and 11.
+        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token']
+        :param bin: apply log binning to the descriptor. default is False.
+        :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors.
+        """
+        assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors.
+                                                                choose from ['key' | 'query' | 'value' | 'token'] """
+        self._extract_features(batch, [layer], facet)
+        x = self._feats[0]
+        if facet == 'token':
+            x.unsqueeze_(dim=1)  # Bx1xtxd
+        if not include_cls:
+            x = x[:, :, 1:, :]  # remove cls token
+        else:
+            assert not bin, "bin = True and include_cls = True are not supported together, set one of them False."
+        if not bin:
+            desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
+        else:
+            desc = self._log_bin(x)
+        return desc
+
+    def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor:
+        """
+        extract saliency maps. The saliency maps are extracted by averaging several CLS-token attention heads
+        from the last layer. All values are then normalized to range between 0 and 1.
+        :param batch: batch to extract saliency maps for. Has shape BxCxHxW.
+        :return: a tensor of saliency maps. has shape Bx(t-1)
+        """
+        assert self.model_type == "dino_vits8", "saliency maps are supported only for dino_vits model_type."
+        self._extract_features(batch, [11], 'attn')
+        head_idxs = [0, 2, 4, 5]
+        curr_feats = self._feats[0]  # Bxhxtxt
+        cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1)  # Bx(t-1)
+        temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0]
+        cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins)  # normalize to range [0,1]
+        return cls_attn_maps
+
+
+@BACKBONE_REGISTRY.register()
+class D2ViTTransformer(Backbone):
+    def __init__(self, cfg, input_shape):
+
+        pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
+        patch_size = cfg.MODEL.SWIN.PATCH_SIZE
+        in_chans = 3
+        embed_dim = cfg.MODEL.SWIN.EMBED_DIM
+        depths = cfg.MODEL.SWIN.DEPTHS
+        num_heads = cfg.MODEL.SWIN.NUM_HEADS
+        window_size = cfg.MODEL.SWIN.WINDOW_SIZE
+        mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
+        qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
+        qk_scale = cfg.MODEL.SWIN.QK_SCALE
+        drop_rate = cfg.MODEL.SWIN.DROP_RATE
+        attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
+        drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
+        norm_layer = nn.LayerNorm
+        ape = cfg.MODEL.SWIN.APE
+        patch_norm = cfg.MODEL.SWIN.PATCH_NORM
+        frozen_stages = cfg.MODEL.BACKBONE.FREEZE_AT
+
+        super().__init__()
+        self.num_layers = 12
+        num_features = [int(embed_dim) for _ in range(self.num_layers)]
+        self.num_features = num_features
+        self.frozen_stages = frozen_stages
+        self.extractor = ViTExtractor(model_type='dino_vitb8', stride=4, model=None, device=cfg.MODEL.DEVICE)
+        if self.frozen_stages >= 0:
+            for block_idx, block in enumerate(self.extractor.model.blocks):
+                if block_idx <= self.frozen_stages:
+                    block.eval()
+                    for p in block.parameters():
+                        p.requires_grad = False
+
+        for block_idx, block in enumerate(self.extractor.model.blocks):
+            if all(not p.requires_grad for p in block.parameters()):
+                print(f"DINO {block_idx} frozen")
+
+        self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
+
+        self._out_feature_strides = {
+            "res2": 4,
+            "res3": 8,
+            "res4": 16,
+            "res5": 32,
+        }
+        self._out_feature_channels = {
+            "res2": self.num_features[0],
+            "res3": self.num_features[1],
+            "res4": self.num_features[2],
+            "res5": self.num_features[3],
+        }
+
+    def forward(self, x):
+        facet = 'key'
+        self.extractor._extract_features(x, [5, 7, 9, 11], facet=facet)
+        res2 = self.extractor._feats[0].unsqueeze_(dim=1)  # Bx1xtxd
+        res3 = self.extractor._feats[1].unsqueeze_(dim=1)  # Bx1xtxd
+        res4 = self.extractor._feats[2].unsqueeze_(dim=1)  # Bx1xtxd
+        res5 = self.extractor._feats[3].unsqueeze_(dim=1)  # Bx1xtxd
+        if facet == 'key':
+            res2 = einops.rearrange(res2, 'b c h t d -> b c t (d h)')  # Bx1xtxd
+            res3 = einops.rearrange(res3, 'b c h t d -> b c t (d h)')  # Bx1xtxd
+            res4 = einops.rearrange(res4, 'b c h t d -> b c t (d h)')  # Bx1xtxd
+            res5 = einops.rearrange(res5, 'b c h t d -> b c t (d h)')  # Bx1xtxd
+
+        res2 = res2.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
+        res3 = res3.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
+        res4 = res4.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
+        res5 = res5.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
+
+        res2 = res2[:, :, 1:, :]  # remove cls token
+        res3 = res3[:, :, 1:, :]  # remove cls token
+        res4 = res4[:, :, 1:, :]  # remove cls token
+        res5 = res5[:, :, 1:, :]  # remove cls token
+
+        res2 = res2.reshape(res2.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
+        res3 = res3.reshape(res3.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
+        res4 = res4.reshape(res4.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
+        res5 = res5.reshape(res5.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
+
+        return {
+            "res2": res2,
+            "res3": res3,
+            "res4": res4,
+            "res5": res5,
+        }
+
+    def output_shape(self):
+        return {
+            name: ShapeSpec(
+                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+            )
+            for name in self._out_features
+        }
+
+    @property
+    def size_divisibility(self):
+        return 32
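Aside (not part of the commit): a minimal usage sketch of `ViTExtractor`, mirroring what `D2ViTTransformer.forward` does internally. The import path below assumes this file is importable as `mask_former.modeling.backbone.vit` (as laid out in this commit); the first call downloads the DINO checkpoint via torch.hub, and `device='cpu'` keeps it runnable without a GPU.

import torch
from mask_former.modeling.backbone.vit import ViTExtractor  # path assumed from this commit's layout

extractor = ViTExtractor(model_type='dino_vits8', stride=4, device='cpu')
batch = torch.randn(1, 3, 128, 128)  # stand-in for a normalized RGB image
feats = extractor._extract_features(batch, layers=[11], facet='key')
keys = feats[0]                      # B x h x t x d, with t = 31*31 patches + 1 CLS token
print(keys.shape, extractor.num_patches)  # torch.Size([1, 6, 962, 64]) (31, 31)

With patch size 8 and patch-embedding stride 4, a 128x128 input yields 1 + (128 - 8) // 4 = 31 tokens per side, which is why this backbone can emit feature maps denser than the native ViT grid.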
mask_former/modeling/criterion.py
ADDED
@@ -0,0 +1,187 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+"""
+MaskFormer criterion.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from detectron2.utils.comm import get_world_size
+
+from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list
+
+
+def dice_loss(inputs, targets, num_masks):
+    """
+    Compute the DICE loss, similar to generalized IOU for masks
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                 (0 for the negative class and 1 for the positive class).
+    """
+    inputs = inputs.sigmoid()
+    inputs = inputs.flatten(1)
+    numerator = 2 * (inputs * targets).sum(-1)
+    denominator = inputs.sum(-1) + targets.sum(-1)
+    loss = 1 - (numerator + 1) / (denominator + 1)
+    return loss.sum() / num_masks
+
+
+def sigmoid_focal_loss(inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2):
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                 (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0,1) to balance
+               positive vs negative examples. Default: 0.25; a negative value disables weighting.
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+    Returns:
+        Loss tensor
+    """
+    prob = inputs.sigmoid()
+    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+    p_t = prob * targets + (1 - prob) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    return loss.mean(1).sum() / num_masks
+
+
+class SetCriterion(nn.Module):
+    """This class computes the loss for DETR.
+    The process happens in two steps:
+        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+    """
+
+    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
+        """Create the criterion.
+        Parameters:
+            num_classes: number of object categories, omitting the special no-object category
+            matcher: module able to compute a matching between targets and proposals
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+            eos_coef: relative classification weight applied to the no-object category
+            losses: list of all the losses to be applied. See get_loss for list of available losses.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.matcher = matcher
+        self.weight_dict = weight_dict
+        self.eos_coef = eos_coef
+        self.losses = losses
+        empty_weight = torch.ones(self.num_classes + 1)
+        empty_weight[-1] = self.eos_coef
+        self.register_buffer("empty_weight", empty_weight)
+
+    def loss_labels(self, outputs, targets, indices, num_masks):
+        """Classification loss (NLL)
+        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+        """
+        assert "pred_logits" in outputs
+        src_logits = outputs["pred_logits"]
+
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+        target_classes = torch.full(
+            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+        )
+        target_classes[idx] = target_classes_o
+
+        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+        losses = {"loss_ce": loss_ce}
+        return losses
+
+    def loss_masks(self, outputs, targets, indices, num_masks):
+        """Compute the losses related to the masks: the focal loss and the dice loss.
+        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+        """
+        assert "pred_masks" in outputs
+
+        src_idx = self._get_src_permutation_idx(indices)
+        tgt_idx = self._get_tgt_permutation_idx(indices)
+        src_masks = outputs["pred_masks"]
+        src_masks = src_masks[src_idx]
+        masks = [t["masks"] for t in targets]
+        # TODO use valid to mask invalid areas due to padding in loss
+        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+        target_masks = target_masks.to(src_masks)
+        target_masks = target_masks[tgt_idx]
+
+        # upsample predictions to the target size
+        src_masks = F.interpolate(
+            src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+        )
+        src_masks = src_masks[:, 0].flatten(1)
+
+        target_masks = target_masks.flatten(1)
+        target_masks = target_masks.view(src_masks.shape)
+        losses = {
+            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
+            "loss_dice": dice_loss(src_masks, target_masks, num_masks),
+        }
+        return losses
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+    def get_loss(self, loss, outputs, targets, indices, num_masks):
+        loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
+        assert loss in loss_map, f"do you really want to compute {loss} loss?"
+        return loss_map[loss](outputs, targets, indices, num_masks)
+
+    def forward(self, outputs, targets):
+        """This performs the loss computation.
+        Parameters:
+            outputs: dict of tensors, see the output specification of the model for the format
+            targets: list of dicts, such that len(targets) == batch_size.
+                     The expected keys in each dict depend on the losses applied, see each loss' doc
+        """
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+        num_masks = sum(len(t["labels"]) for t in targets)
+        num_masks = torch.as_tensor(
+            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
+        )
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_masks)
+        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if "aux_outputs" in outputs:
+            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
+                indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
+                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        return losses
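Aside (not part of the commit): a quick numeric check of the `dice_loss` formula above, restated standalone with toy values so the +1 smoothing and the 0-to-1 loss range are visible.

import torch

def dice_loss(inputs, targets, num_masks):
    # restatement of the criterion's dice_loss
    inputs = inputs.sigmoid().flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    return (1 - (numerator + 1) / (denominator + 1)).sum() / num_masks

logits = torch.full((1, 4), 50.0)                  # sigmoid ~= 1 everywhere
target = torch.ones(1, 4)
print(dice_loss(logits, target, num_masks=1))      # ~0.0: perfect overlap
print(dice_loss(logits, 1 - target, num_masks=1))  # 0.8 = 1 - (0+1)/(4+1): no overlap

The +1 in numerator and denominator keeps the loss finite for empty masks, and `num_masks` (all-reduced across GPUs in `SetCriterion.forward`) normalizes per mask rather than per batch element.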
mask_former/modeling/heads/__init__.py
ADDED
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
mask_former/modeling/heads/big_pixel_decoder.py
ADDED
@@ -0,0 +1,228 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from typing import Callable, Dict, Optional, Union
+
+import fvcore.nn.weight_init as weight_init
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+from torch import nn
+from torch.nn import functional as F
+
+from ..transformer.position_encoding import PositionEmbeddingSine
+from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class BigPixelDecoder(nn.Module):
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        conv_dim: int,
+        mask_dim: int,
+        norm: Optional[Union[str, Callable]] = None,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            conv_dim: number of output channels for the intermediate conv layers.
+            mask_dim: number of output channels for the final conv layer.
+            norm (str or callable): normalization for all conv layers
+        """
+        super().__init__()
+
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
+        feature_channels = [v.channels for k, v in input_shape]
+
+        lateral_convs = []
+        output_convs = []
+
+        use_bias = norm == ""
+        for idx, in_channels in enumerate(feature_channels):
+            if idx == len(self.in_features) - 1:
+                output_norm = get_norm(norm, conv_dim)
+                output_conv = Conv2d(
+                    in_channels,
+                    conv_dim,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=use_bias,
+                    norm=output_norm,
+                    activation=F.relu,
+                )
+                weight_init.c2_xavier_fill(output_conv)
+                self.add_module("layer_{}".format(idx + 1), output_conv)
+
+                lateral_convs.append(None)
+                output_convs.append(output_conv)
+            else:
+                lateral_norm = get_norm(norm, conv_dim)
+                output_norm = get_norm(norm, conv_dim)
+
+                lateral_conv = Conv2d(
+                    in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
+                )
+                output_conv = Conv2d(
+                    conv_dim,
+                    conv_dim,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=use_bias,
+                    norm=output_norm,
+                    activation=F.relu,
+                )
+                weight_init.c2_xavier_fill(lateral_conv)
+                weight_init.c2_xavier_fill(output_conv)
+                self.add_module("adapter_{}".format(idx + 1), lateral_conv)
+                self.add_module("layer_{}".format(idx + 1), output_conv)
+
+                lateral_convs.append(lateral_conv)
+                output_convs.append(output_conv)
+        # Place convs into top-down order (from low to high resolution)
+        # to make the top-down computation in forward clearer.
+        self.lateral_convs = lateral_convs[::-1]
+        self.output_convs = output_convs[::-1]
+
+        self.mask_dim = mask_dim
+        # self.mask_features = Conv2d(
+        #     conv_dim,
+        #     mask_dim,
+        #     kernel_size=3,
+        #     stride=1,
+        #     padding=1,
+        # )
+
+        # weight_init.c2_xavier_fill(self.mask_features)
+
+        self.mask_features = nn.Sequential(
+            Conv2d(
+                conv_dim,
+                conv_dim,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=use_bias,
+                norm=output_norm,
+                activation=F.relu,
+            ),
+            nn.UpsamplingNearest2d(scale_factor=2),
+            Conv2d(
+                conv_dim,
+                conv_dim,
+                kernel_size=1,
+                stride=1,
+                padding=1,
+                bias=use_bias,
+                norm=output_norm,
+                activation=F.relu,
+            ),
+            Conv2d(
+                conv_dim,
+                conv_dim,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=use_bias,
+                norm=output_norm,
+                activation=F.relu,
+            ),
+            nn.UpsamplingNearest2d(scale_factor=2),
+            Conv2d(
+                conv_dim,
+                conv_dim,
+                kernel_size=1,
+                stride=1,
+                padding=1,
+                bias=use_bias,
+                norm=output_norm,
+                activation=F.relu,
+            ),
+            Conv2d(
+                conv_dim,
+                mask_dim,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
+        )
+
+        # NOTE: match on module type; named_modules() yields numeric names ("0", "1", ...)
+        # inside nn.Sequential, so a string check like `'Conv2d' in name` would never fire.
+        for module in self.mask_features.modules():
+            if isinstance(module, Conv2d):
+                weight_init.c2_xavier_fill(module)
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        ret = {}
+        ret["input_shape"] = {
+            k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+        }
+        ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+        ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
+        return ret
+
+    def forward_features(self, features):
+        # Reverse feature maps into top-down order (from low to high resolution)
+        for idx, f in enumerate(self.in_features[::-1]):
+            x = features[f]
+            lateral_conv = self.lateral_convs[idx]
+            output_conv = self.output_convs[idx]
+            if lateral_conv is None:
+                y = output_conv(x)
+            else:
+                cur_fpn = lateral_conv(x)
+                # Following FPN implementation, we use nearest upsampling here
+                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
+                y = output_conv(y)
+        return self.mask_features(y), None
+
+    def forward(self, features, targets=None):
+        logger = logging.getLogger(__name__)
+        logger.warning("Calling forward() may cause unpredictable behavior of PixelDecoder module.")
+        return self.forward_features(features)
+
+
+class TransformerEncoderOnly(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_encoder_layers=6,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+        self._reset_parameters()
+
+        self.d_model = d_model
+        self.nhead = nhead
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, src, mask, pos_embed):
+        # flatten NxCxHxW to HWxNxC
+        bs, c, h, w = src.shape
+        src = src.flatten(2).permute(2, 0, 1)
+        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+        if mask is not None:
+            mask = mask.flatten(1)
+
+        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+        return memory.permute(1, 2, 0).view(bs, c, h, w)
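Aside (not part of the commit): `BigPixelDecoder.forward_features` follows the standard FPN recipe -- start from the coarsest map, then at each finer level add a 1x1 lateral projection to the nearest-upsampled running map before a 3x3 refinement conv. Below is a dependency-free sketch of that fusion pattern with hypothetical channel counts (plain torch.nn, not the detectron2-configurable module above):

import torch
import torch.nn as nn
import torch.nn.functional as F

chans = {"res2": 96, "res3": 192, "res4": 384, "res5": 768}  # hypothetical backbone widths
conv_dim = 256
lateral = nn.ModuleDict({k: nn.Conv2d(c, conv_dim, 1) for k, c in chans.items() if k != "res5"})
output = nn.ModuleDict({k: nn.Conv2d(chans[k] if k == "res5" else conv_dim, conv_dim, 3, padding=1)
                        for k in chans})

feats = {k: torch.randn(1, c, 256 // s, 256 // s)
         for (k, c), s in zip(chans.items(), (4, 8, 16, 32))}

y = output["res5"](feats["res5"])   # coarsest level: no lateral conv (the None slot above)
for k in ("res4", "res3", "res2"):  # top-down, low to high resolution
    y = lateral[k](feats[k]) + F.interpolate(y, size=feats[k].shape[-2:], mode="nearest")
    y = output[k](y)
print(y.shape)  # torch.Size([1, 256, 64, 64]): the stride-4 map fed to mask_features

The decoder above then passes this stride-4 map through `mask_features`, whose two UpsamplingNearest2d stages push the final mask embedding to a higher resolution than a standard MaskFormer pixel decoder (presumably the "Big" in the name).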
mask_former/modeling/heads/mask_former_head.py
ADDED
@@ -0,0 +1,120 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from copy import deepcopy
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer.transformer_predictor import TransformerPredictor
+from .pixel_decoder import build_pixel_decoder
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class MaskFormerHead(nn.Module):
+
+    _version = 2
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        version = local_metadata.get("version", None)
+        if version is None or version < 2:
+            # Do not warn if training from scratch
+            scratch = True
+            logger = logging.getLogger(__name__)
+            for k in list(state_dict.keys()):
+                newk = k
+                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
+                    newk = k.replace(prefix, prefix + "pixel_decoder.")
+                    # logger.debug(f"{k} ==> {newk}")
+                if newk != k:
+                    state_dict[newk] = state_dict[k]
+                    del state_dict[k]
+                    scratch = False
+
+            if not scratch:
+                logger.warning(
+                    f"Weight format of {self.__class__.__name__} has changed! "
+                    "Please upgrade your models. Applying automatic conversion now ..."
+                )
+
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        num_classes: int,
+        pixel_decoder: nn.Module,
+        loss_weight: float = 1.0,
+        ignore_value: int = -1,
+        # extra parameters
+        transformer_predictor: nn.Module,
+        transformer_in_feature: str,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            num_classes: number of classes to predict
+            pixel_decoder: the pixel decoder module
+            loss_weight: loss weight
+            ignore_value: category id to be ignored during training.
+            transformer_predictor: the transformer decoder that makes predictions
+            transformer_in_feature: input feature name to the transformer_predictor
+        """
+        super().__init__()
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]
+        feature_strides = [v.stride for k, v in input_shape]
+        feature_channels = [v.channels for k, v in input_shape]
+
+        self.ignore_value = ignore_value
+        self.common_stride = 4
+        self.loss_weight = loss_weight
+
+        self.pixel_decoder = pixel_decoder
+        self.predictor = transformer_predictor
+        self.transformer_in_feature = transformer_in_feature
+
+        self.num_classes = num_classes
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        return {
+            "input_shape": {
+                k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+            },
+            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
+            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
+            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
+            "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
+            "transformer_predictor": TransformerPredictor(
+                cfg,
+                cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+                if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
+                else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
+                mask_classification=True,
+            ),
+        }
+
+    def forward(self, features):
+        return self.layers(features)
+
+    def layers(self, features):
+        mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
+        if self.transformer_in_feature == "transformer_encoder":
+            assert (
+                transformer_encoder_features is not None
+            ), "Please use the TransformerEncoderPixelDecoder."
+            predictions = self.predictor(transformer_encoder_features, mask_features)
+        else:
+            predictions = self.predictor(features[self.transformer_in_feature], mask_features)
+        # predictions['features'] = mask_features
+        return predictions
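
The _version = 2 migration above renames old-style checkpoint keys in place: every key under the head prefix that does not belong to the predictor is moved under pixel_decoder. A self-contained sketch of the same renaming (not part of the commit; the keys are invented for illustration):

prefix = "sem_seg_head."
state_dict = {
    "sem_seg_head.adapter_1.weight": "...",              # hypothetical v1 pixel-decoder key
    "sem_seg_head.predictor.query_embed.weight": "...",  # predictor keys stay put
}
for k in list(state_dict.keys()):
    if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
        state_dict[k.replace(prefix, prefix + "pixel_decoder.")] = state_dict.pop(k)
print(sorted(state_dict))
# ['sem_seg_head.pixel_decoder.adapter_1.weight',
#  'sem_seg_head.predictor.query_embed.weight']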
mask_former/modeling/heads/mask_former_head_baseline.py
ADDED
@@ -0,0 +1,123 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from copy import deepcopy
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+import torch
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer.transformer_predictor import TransformerPredictor
+from .pixel_decoder import build_pixel_decoder
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class MaskFormerBaselineHead(nn.Module):
+
+    _version = 2
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        version = local_metadata.get("version", None)
+        if version is None or version < 2:
+            # Do not warn if training from scratch
+            scratch = True
+            logger = logging.getLogger(__name__)
+            for k in list(state_dict.keys()):
+                newk = k
+                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
+                    newk = k.replace(prefix, prefix + "pixel_decoder.")
+                    # logger.debug(f"{k} ==> {newk}")
+                if newk != k:
+                    state_dict[newk] = state_dict[k]
+                    del state_dict[k]
+                    scratch = False
+
+            if not scratch:
+                logger.warning(
+                    f"Weight format of {self.__class__.__name__} has changed! "
+                    "Please upgrade your models. Applying automatic conversion now ..."
+                )
+
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        num_classes: int,
+        pixel_decoder: nn.Module,
+        loss_weight: float = 1.0,
+        ignore_value: int = -1,
+        # extra parameters
+        transformer_predictor: nn.Module,
+        transformer_in_feature: str,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            num_classes: number of classes to predict
+            pixel_decoder: the pixel decoder module
+            loss_weight: loss weight
+            ignore_value: category id to be ignored during training.
+            transformer_predictor: the transformer decoder that makes predictions
+            transformer_in_feature: input feature name to the transformer_predictor
+        """
+        super().__init__()
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]
+        feature_strides = [v.stride for k, v in input_shape]
+        feature_channels = [v.channels for k, v in input_shape]
+
+        self.ignore_value = ignore_value
+        self.common_stride = 4
+        self.loss_weight = loss_weight
+
+        self.pixel_decoder = pixel_decoder
+        self.predictor = transformer_predictor
+        self.transformer_in_feature = transformer_in_feature
+        inc = 256
+        self.out_layers = nn.Sequential(nn.Conv2d(inc, inc, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(inc, 1))
+        self.num_classes = num_classes
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        return {
+            "input_shape": {
+                k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+            },
+            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
+            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
+            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
+            "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
+            "transformer_predictor": TransformerPredictor(
+                cfg,
+                cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+                if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
+                else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
+                mask_classification=True,
+            ),
+        }
+
+    def forward(self, features):
+        f = self.layers(features)
+
+        return self.out_layers(f).squeeze(-1)
+
+    def layers(self, features):
+        mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
+        # if self.transformer_in_feature == "transformer_encoder":
+        #     assert (
+        #         transformer_encoder_features is not None
+        #     ), "Please use the TransformerEncoderPixelDecoder."
+        #     predictions = self.predictor(transformer_encoder_features, mask_features)
+        # else:
+        #     predictions = self.predictor(features[self.transformer_in_feature], mask_features)
+        return mask_features
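
Unlike MaskFormerHead, this baseline head bypasses the transformer predictor (left commented out) and reduces the pixel decoder's mask features to a single logit per image through out_layers. A shape-level sketch (not part of the commit; the input resolution is an assumption):

import torch
from torch import nn

inc = 256  # matches the hard-coded channel count in the commit
out_layers = nn.Sequential(
    nn.Conv2d(inc, inc, kernel_size=3, padding=1),  # local mixing
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),                   # (N, 256, H, W) -> (N, 256, 1, 1)
    nn.Flatten(),                                   # -> (N, 256)
    nn.Linear(inc, 1),                              # -> (N, 1)
)
feats = torch.randn(4, inc, 56, 56)                 # assumed mask-feature resolution
print(out_layers(feats).squeeze(-1).shape)          # torch.Size([4])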