zhangbo2008 and JeffLiang committed
Commit 7e8c559 (0 parents)

Duplicate from facebook/ov-seg

Co-authored-by: Jeff Liang <JeffLiang@users.noreply.huggingface.co>

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +36 -0
  2. .idea/vcs.xml +6 -0
  3. README.md +14 -0
  4. app.py +96 -0
  5. open_vocab_seg/.DS_Store +0 -0
  6. open_vocab_seg/__init__.py +9 -0
  7. open_vocab_seg/config.py +133 -0
  8. open_vocab_seg/data/.DS_Store +0 -0
  9. open_vocab_seg/data/__init__.py +9 -0
  10. open_vocab_seg/data/augmentations.py +202 -0
  11. open_vocab_seg/data/build.py +344 -0
  12. open_vocab_seg/data/dataset_mappers/__init__.py +4 -0
  13. open_vocab_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py +208 -0
  14. open_vocab_seg/data/datasets/__init__.py +5 -0
  15. open_vocab_seg/data/datasets/csv_data.py +459 -0
  16. open_vocab_seg/data/datasets/register_ade20k_full.py +995 -0
  17. open_vocab_seg/data/datasets/register_cc3m.py +457 -0
  18. open_vocab_seg/data/datasets/register_coco_stuff.py +250 -0
  19. open_vocab_seg/data/datasets/register_pascal_context.py +588 -0
  20. open_vocab_seg/data/datasets/register_voc_seg.py +62 -0
  21. open_vocab_seg/evaluation/__init__.py +4 -0
  22. open_vocab_seg/evaluation/generalized_sem_seg_evaluation.py +159 -0
  23. open_vocab_seg/mask_former_model.py +254 -0
  24. open_vocab_seg/modeling/.DS_Store +0 -0
  25. open_vocab_seg/modeling/__init__.py +8 -0
  26. open_vocab_seg/modeling/backbone/__init__.py +2 -0
  27. open_vocab_seg/modeling/backbone/clip_resnet.py +206 -0
  28. open_vocab_seg/modeling/backbone/swin.py +832 -0
  29. open_vocab_seg/modeling/clip_adapter/__init__.py +25 -0
  30. open_vocab_seg/modeling/clip_adapter/adapter.py +206 -0
  31. open_vocab_seg/modeling/clip_adapter/clip/__init__.py +1 -0
  32. open_vocab_seg/modeling/clip_adapter/clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  33. open_vocab_seg/modeling/clip_adapter/clip/clip.py +285 -0
  34. open_vocab_seg/modeling/clip_adapter/clip/model.py +613 -0
  35. open_vocab_seg/modeling/clip_adapter/clip/simple_tokenizer.py +150 -0
  36. open_vocab_seg/modeling/clip_adapter/text_template.py +156 -0
  37. open_vocab_seg/modeling/clip_adapter/utils.py +81 -0
  38. open_vocab_seg/modeling/criterion.py +229 -0
  39. open_vocab_seg/modeling/heads/__init__.py +2 -0
  40. open_vocab_seg/modeling/heads/mask_former_head.py +135 -0
  41. open_vocab_seg/modeling/heads/open_vocab_mask_former_head.py +145 -0
  42. open_vocab_seg/modeling/heads/pixel_decoder.py +308 -0
  43. open_vocab_seg/modeling/matcher.py +187 -0
  44. open_vocab_seg/modeling/transformer/__init__.py +2 -0
  45. open_vocab_seg/modeling/transformer/open_vocab_transformer_predictor.py +84 -0
  46. open_vocab_seg/modeling/transformer/position_encoding.py +58 -0
  47. open_vocab_seg/modeling/transformer/transformer.py +380 -0
  48. open_vocab_seg/modeling/transformer/transformer_predictor.py +179 -0
  49. open_vocab_seg/ovseg_model.py +460 -0
  50. open_vocab_seg/test_time_augmentation.py +217 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="$PROJECT_DIR$" vcs="Git" />
+   </component>
+ </project>
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Ov Seg
+ emoji: 📊
+ colorFrom: red
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 3.8.2
+ app_file: app.py
+ pinned: false
+ license: cc-by-nc-4.0
+ duplicated_from: facebook/ov-seg
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,96 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+ import multiprocessing as mp
+
+ import numpy as np
+ from PIL import Image
+
+
+ try:
+     import detectron2
+ except:
+     import os
+     os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
+
+ from detectron2.config import get_cfg
+
+ from detectron2.projects.deeplab import add_deeplab_config
+ from detectron2.data.detection_utils import read_image
+ from open_vocab_seg import add_ovseg_config
+ from open_vocab_seg.utils import VisualizationDemo, SAMVisualizationDemo
+
+ import gradio as gr
+
+ import gdown
+
+ # ckpt_url = 'https://drive.google.com/uc?id=1cn-ohxgXDrDfkzC1QdO-fi8IjbjXmgKy'
+ # output = './ovseg_swinbase_vitL14_ft_mpt.pth'
+ # gdown.download(ckpt_url, output, quiet=False)
+
+ def setup_cfg(config_file):
+     # load config from file and command-line arguments
+     cfg = get_cfg()
+     add_deeplab_config(cfg)
+     add_ovseg_config(cfg)
+     cfg.merge_from_file(config_file)
+     cfg.freeze()
+     return cfg
+
+
+ def inference(class_names, proposal_gen, granularity, input_img):
+     mp.set_start_method("spawn", force=True)
+     config_file = './ovseg_swinB_vitL_demo.yaml'
+     cfg = setup_cfg(config_file)
+     if proposal_gen == 'MaskFormer':
+         demo = VisualizationDemo(cfg)
+     elif proposal_gen == 'Segment_Anything':
+         demo = SAMVisualizationDemo(cfg, granularity, './sam_vit_l_0b3195.pth', './ovseg_clip_l_9a1909.pth')
+     class_names = class_names.split(',')
+     img = read_image(input_img, format="BGR")
+     _, visualized_output = demo.run_on_image(img, class_names)
+
+     return Image.fromarray(np.uint8(visualized_output.get_image())).convert('RGB')
+
+
+ examples = [['Saturn V, toys, desk, wall, sunflowers, white roses, chrysanthemums, carnations, green dianthus', 'Segment_Anything', 0.8, './resources/demo_samples/sample_01.jpeg'],
+             ['red bench, yellow bench, blue bench, brown bench, green bench, blue chair, yellow chair, green chair, brown chair, yellow square painting, barrel, buddha statue', 'Segment_Anything', 0.8, './resources/demo_samples/sample_04.png'],
+             ['pillow, pipe, sweater, shirt, jeans jacket, shoes, cabinet, handbag, photo frame', 'Segment_Anything', 0.8, './resources/demo_samples/sample_05.png'],
+             ['Saturn V, toys, blossom', 'MaskFormer', 1.0, './resources/demo_samples/sample_01.jpeg'],
+             ['Oculus, Ukulele', 'MaskFormer', 1.0, './resources/demo_samples/sample_03.jpeg'],
+             ['Golden gate, yacht', 'MaskFormer', 1.0, './resources/demo_samples/sample_02.jpeg'],]
+ output_labels = ['segmentation map']
+
+ title = 'OVSeg (+ Segment_Anything)'
+
+ description = """
+ [NEW!] We incorporate OVSeg CLIP w/ Segment_Anything, enabling SAM's text prompts.
+ Gradio Demo for Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP. \n
+ OVSeg can perform open-vocabulary segmentation; you may input more classes (separated by commas). You may click one of the examples or upload your own image. \n
+ It might take some time to process. Cheers!
+ <p>(Colab only supports the MaskFormer proposal generator) Don't want to wait in queue? <a href="https://colab.research.google.com/drive/1O4Ain5uFZNcQYUmDTG92DpEGCatga8K5?usp=sharing"><img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://camo.githubusercontent.com/84f0493939e0c4de4e6dbe113251b4bfb5353e57134ffd9fcab6b8714514d4d1/68747470733a2f2f636f6c61622e72657365617263682e676f6f676c652e636f6d2f6173736574732f636f6c61622d62616467652e737667"></a></p>
+ """
+
+ article = """
+ <p style='text-align: center'>
+ <a href='https://arxiv.org/abs/2210.04150' target='_blank'>
+ Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP
+ </a>
+ |
+ <a href='https://github.com/facebookresearch/ov-seg' target='_blank'>Github Repo</a></p>
+ """
+
+ gr.Interface(
+     inference,
+     inputs=[
+         gr.Textbox(
+             lines=1, placeholder=None, default='', label='class names'),
+         gr.Radio(["Segment_Anything", "MaskFormer"], label="Proposal generator", default="Segment_Anything"),
+         gr.Slider(0, 1.0, 0.8, label="For Segment_Anything only, granularity of masks from 0 (most coarse) to 1 (most precise)"),
+         gr.Image(type='filepath'),
+     ],
+     outputs=gr.outputs.Image(label='segmentation map'),
+     title=title,
+     description=description,
+     article=article,
+     examples=examples).launch(enable_queue=True)
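
For readers who want to run the model outside the Gradio UI, here is a minimal sketch that mirrors the flow of app.py above. It is not part of the committed code and assumes the same local files referenced there (the ovseg_swinB_vitL_demo.yaml config and the demo sample images) plus installed detectron2 and this repository's open_vocab_seg package.

# Minimal sketch, assuming the config and sample images referenced in app.py are present locally.
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.data.detection_utils import read_image
from open_vocab_seg import add_ovseg_config
from open_vocab_seg.utils import VisualizationDemo

cfg = get_cfg()
add_deeplab_config(cfg)
add_ovseg_config(cfg)
cfg.merge_from_file('./ovseg_swinB_vitL_demo.yaml')
cfg.freeze()

demo = VisualizationDemo(cfg)
img = read_image('./resources/demo_samples/sample_01.jpeg', format="BGR")
_, vis = demo.run_on_image(img, ['Saturn V', 'toys', 'blossom'])
result = vis.get_image()  # H x W x 3 uint8 array with the predicted masks drawn on top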
open_vocab_seg/.DS_Store ADDED
Binary file (6.15 kB)
 
open_vocab_seg/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+ from . import data
+ from . import modeling
+ from .config import add_ovseg_config
+
+ from .test_time_augmentation import SemanticSegmentorWithTTA
+ from .ovseg_model import OVSeg, OVSegDEMO
open_vocab_seg/config.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ from detectron2.config import CfgNode as CN
5
+
6
+
7
+ def add_mask_former_default_config(cfg):
8
+ # data config
9
+ # select the dataset mapper
10
+ cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
11
+ # Color augmentation
12
+ cfg.INPUT.COLOR_AUG_SSD = False
13
+ # We retry random cropping until no single category in semantic segmentation GT occupies more
14
+ # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
15
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
16
+ # Pad image and segmentation GT in dataset mapper.
17
+ cfg.INPUT.SIZE_DIVISIBILITY = -1
18
+
19
+ # solver config
20
+ # test batch size
21
+ cfg.SOLVER.TEST_IMS_PER_BATCH = 1
22
+ # weight decay on embedding
23
+ cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
24
+ # optimizer
25
+ cfg.SOLVER.OPTIMIZER = "ADAMW"
26
+ cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
27
+
28
+ # mask_former model config
29
+ cfg.MODEL.MASK_FORMER = CN()
30
+
31
+ # loss
32
+ cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
33
+ cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1
34
+ cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0
35
+ cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0
36
+
37
+ # transformer config
38
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
39
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.1
40
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
41
+ cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0
42
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6
43
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
44
+
45
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
46
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100
47
+
48
+ cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5"
49
+ cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False
50
+
51
+ # mask_former inference config
52
+ cfg.MODEL.MASK_FORMER.TEST = CN()
53
+ cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False
54
+ cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0
55
+ cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0
56
+ cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
57
+
58
+ # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
59
+ # you can use this config to override
60
+ cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
61
+
62
+ # pixel decoder config
63
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
64
+ # adding transformer in pixel decoder
65
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
66
+ # pixel decoder
67
+ cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
68
+
69
+ # swin transformer backbone
70
+ cfg.MODEL.SWIN = CN()
71
+ cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
72
+ cfg.MODEL.SWIN.PATCH_SIZE = 4
73
+ cfg.MODEL.SWIN.EMBED_DIM = 96
74
+ cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
75
+ cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
76
+ cfg.MODEL.SWIN.WINDOW_SIZE = 7
77
+ cfg.MODEL.SWIN.MLP_RATIO = 4.0
78
+ cfg.MODEL.SWIN.QKV_BIAS = True
79
+ cfg.MODEL.SWIN.QK_SCALE = None
80
+ cfg.MODEL.SWIN.NORM_INDICES = None
81
+ cfg.MODEL.SWIN.PROJECTION = False
82
+ cfg.MODEL.SWIN.PROJECT_DIM = 256
83
+ cfg.MODEL.SWIN.DROP_RATE = 0.0
84
+ cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
85
+ cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
86
+ cfg.MODEL.SWIN.APE = False
87
+ cfg.MODEL.SWIN.PATCH_NORM = True
88
+ cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
89
+
90
+
91
+ def add_our_config(cfg):
92
+ cfg.TEST.SLIDING_WINDOW = False
93
+ cfg.TEST.SLIDING_TILE_SIZE = 224
94
+ cfg.TEST.SLIDING_OVERLAP = 2 / 3.0
95
+ # whether to use dense crf
96
+ cfg.TEST.DENSE_CRF = False
97
+ cfg.DATASETS.SAMPLE_PER_CLASS = -1
98
+ cfg.DATASETS.SAMPLE_SEED = 0
99
+ # embedding head
100
+ cfg.MODEL.SEM_SEG_HEAD.EMBEDDING_DIM = 512
101
+ cfg.MODEL.SEM_SEG_HEAD.EMBED_HIDDEN_DIM = 1024
102
+ cfg.MODEL.SEM_SEG_HEAD.EMBED_LAYERS = 2
103
+ # clip_adapter
104
+ cfg.MODEL.CLIP_ADAPTER = CN()
105
+ cfg.MODEL.CLIP_ADAPTER.TEXT_TEMPLATES = "vild"
106
+ # for predefined
107
+ cfg.MODEL.CLIP_ADAPTER.PREDEFINED_PROMPT_TEMPLATES = ["a photo of a {}."]
108
+ # for learnable prompt
109
+ cfg.MODEL.CLIP_ADAPTER.PROMPT_CHECKPOINT = ""
110
+ cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME = "ViT-B/16"
111
+ cfg.MODEL.CLIP_ADAPTER.MASK_FILL = "mean"
112
+ cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO = 1.0
113
+ cfg.MODEL.CLIP_ADAPTER.MASK_THR = 0.4
114
+ cfg.MODEL.CLIP_ADAPTER.MASK_MATTING = False
115
+ cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED = True
116
+ cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE = True
117
+ cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT = 0.7
118
+ # for mask prompt
119
+ cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH = 3
120
+ cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD = False
121
+
122
+ # wandb
123
+ cfg.WANDB = CN()
124
+ cfg.WANDB.PROJECT = "open_vocab_seg"
125
+ cfg.WANDB.NAME = None
126
+
127
+
128
+ def add_ovseg_config(cfg):
129
+ """
130
+ Add config for open_vocab_seg.
131
+ """
132
+ add_mask_former_default_config(cfg)
133
+ add_our_config(cfg)
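
As a quick check of the defaults registered above, add_ovseg_config can be applied to a fresh Detectron2 config; a minimal sketch (not part of the commit), assuming detectron2 and this package are importable:

from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from open_vocab_seg.config import add_ovseg_config

cfg = get_cfg()
add_deeplab_config(cfg)   # same call order as app.py uses before merging the demo yaml
add_ovseg_config(cfg)     # adds the MASK_FORMER, SWIN and CLIP_ADAPTER defaults defined in this file

print(cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES)     # 100
print(cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME)       # ViT-B/16
print(cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT)  # 0.7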
open_vocab_seg/data/.DS_Store ADDED
Binary file (6.15 kB)
 
open_vocab_seg/data/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+ from .dataset_mappers import *
+ from . import datasets
+ from .build import (
+     build_detection_train_loader,
+     build_detection_test_loader,
+ )
open_vocab_seg/data/augmentations.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import math
5
+ import numbers
6
+ import numpy as np
7
+ from detectron2.data.transforms.augmentation import Augmentation
8
+ from detectron2.data.transforms.transform import (
9
+ CropTransform,
10
+ ResizeTransform,
11
+ TransformList,
12
+ )
13
+ from PIL import Image
14
+ from fvcore.transforms.transform import PadTransform
15
+
16
+
17
+ def mask2box(mask: np.ndarray):
18
+ # use naive way
19
+ row = np.nonzero(mask.sum(axis=0))[0]
20
+ if len(row) == 0:
21
+ return None
22
+ x1 = row.min()
23
+ x2 = row.max()
24
+ col = np.nonzero(mask.sum(axis=1))[0]
25
+ y1 = col.min()
26
+ y2 = col.max()
27
+ return x1, y1, x2 + 1 - x1, y2 + 1 - y1
28
+
29
+
30
+ def expand_box(x, y, w, h, expand_ratio=1.0, max_h=None, max_w=None):
31
+ cx = x + 0.5 * w
32
+ cy = y + 0.5 * h
33
+ w = w * expand_ratio
34
+ h = h * expand_ratio
35
+ box = [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h]
36
+ if max_h is not None:
37
+ box[1] = max(0, box[1])
38
+ box[3] = min(max_h - 1, box[3])
39
+ if max_w is not None:
40
+ box[0] = max(0, box[0])
41
+ box[2] = min(max_w - 1, box[2])
42
+ box[2] = box[2] - box[0]
43
+ box[3] = box[3] - box[1]
44
+
45
+ return [int(b) for b in box]
46
+
47
+
48
+ class CropImageWithMask(Augmentation):
49
+ def __init__(self, expand_ratio=1.0, mode="choice"):
50
+ if isinstance(expand_ratio, numbers.Number):
51
+ expand_ratio = (expand_ratio, expand_ratio)
52
+ self.mode = mode
53
+ self.expand_ratio = expand_ratio
54
+ if self.mode == "range":
55
+ assert len(expand_ratio) == 2 and expand_ratio[0] < expand_ratio[1]
56
+
57
+ def get_transform(self, image, sem_seg, category_id):
58
+ input_size = image.shape[:2]
59
+ bin_mask = sem_seg == category_id
60
+ x, y, w, h = mask2box(bin_mask)
61
+ if self.mode == "choice":
62
+ expand_ratio = np.random.choice(self.expand_ratio)
63
+ else:
64
+ expand_ratio = np.random.uniform(self.expand_ratio[0], self.expand_ratio[1])
65
+ x, y, w, h = expand_box(x, y, w, h, expand_ratio, *input_size)
66
+ w = max(w, 1)
67
+ h = max(h, 1)
68
+ return CropTransform(x, y, w, h, input_size[1], input_size[0])
69
+
70
+
71
+ class CropImageWithBox(Augmentation):
72
+ def __init__(self, expand_ratio=1.0, mode="choice"):
73
+ if isinstance(expand_ratio, numbers.Number):
74
+ expand_ratio = (expand_ratio, expand_ratio)
75
+ self.mode = mode
76
+ self.expand_ratio = expand_ratio
77
+ if self.mode == "range":
78
+ assert len(expand_ratio) == 2 and expand_ratio[0] < expand_ratio[1]
79
+
80
+ def get_transform(self, image, boxes):
81
+ input_size = image.shape[:2]
82
+ x, y, x2, y2 = boxes[0]
83
+ w = x2 - x + 1
84
+ h = y2 - y + 1
85
+ if self.mode == "choice":
86
+ expand_ratio = np.random.choice(self.expand_ratio)
87
+ else:
88
+ expand_ratio = np.random.uniform(self.expand_ratio[0], self.expand_ratio[1])
89
+ x, y, w, h = expand_box(x, y, w, h, expand_ratio, *input_size)
90
+ w = max(w, 1)
91
+ h = max(h, 1)
92
+ return CropTransform(x, y, w, h, input_size[1], input_size[0])
93
+
94
+
95
+ class RandomResizedCrop(Augmentation):
96
+ def __init__(
97
+ self,
98
+ size,
99
+ scale=(0.08, 1.0),
100
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
101
+ interpolation=Image.BILINEAR,
102
+ ):
103
+ if isinstance(size, int):
104
+ size = (size, size)
105
+ else:
106
+ assert isinstance(size, (tuple, list)) and len(size) == 2
107
+
108
+ self.size = size
109
+
110
+ self.scale = scale
111
+ self.ratio = ratio
112
+ self.interpolation = interpolation
113
+
114
+ def get_transform(self, image):
115
+ height, width = image.shape[:2]
116
+ area = height * width
117
+
118
+ log_ratio = np.log(np.array(self.ratio))
119
+ is_success = False
120
+ for _ in range(10):
121
+ target_area = area * np.random.uniform(self.scale[0], self.scale[1])
122
+ aspect_ratio = np.exp(np.random.uniform(log_ratio[0], log_ratio[1]))
123
+
124
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
125
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
126
+
127
+ if 0 < w <= width and 0 < h <= height:
128
+ i = np.random.randint(0, width - w + 1)
129
+ j = np.random.randint(0, height - h + 1)
130
+
131
+ is_success = True
132
+ break
133
+
134
+ if not is_success:
135
+ # Fallback to central crop
136
+ in_ratio = float(width) / float(height)
137
+ if in_ratio < min(self.ratio):
138
+ w = width
139
+ h = int(round(w / min(self.ratio)))
140
+ elif in_ratio > max(self.ratio):
141
+ h = height
142
+ w = int(round(h * max(self.ratio)))
143
+ else: # whole image
144
+ w = width
145
+ h = height
146
+ i = (width - w) // 2
147
+ j = (height - h) // 2
148
+ return TransformList(
149
+ [
150
+ CropTransform(i, j, w, h, width, height),
151
+ ResizeTransform(
152
+ h, w, self.size[1], self.size[0], interp=self.interpolation
153
+ ),
154
+ ]
155
+ )
156
+
157
+
158
+ class CenterCrop(Augmentation):
159
+ def __init__(self, size, seg_ignore_label):
160
+ if isinstance(size, numbers.Number):
161
+ size = (int(size), int(size))
162
+ elif isinstance(size, (tuple, list)) and len(size) == 1:
163
+ size = (size[0], size[0])
164
+ self.size = size
165
+ self.seg_ignore_label = seg_ignore_label
166
+
167
+ def get_transform(self, image):
168
+
169
+ image_height, image_width = image.shape[:2]
170
+ crop_height, crop_width = self.size
171
+
172
+ transforms = []
173
+ if crop_width > image_width or crop_height > image_height:
174
+ padding_ltrb = [
175
+ (crop_width - image_width) // 2 if crop_width > image_width else 0,
176
+ (crop_height - image_height) // 2 if crop_height > image_height else 0,
177
+ (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
178
+ (crop_height - image_height + 1) // 2
179
+ if crop_height > image_height
180
+ else 0,
181
+ ]
182
+ transforms.append(
183
+ PadTransform(
184
+ *padding_ltrb,
185
+ orig_w=image_width,
186
+ orig_h=image_height,
187
+ seg_pad_value=self.seg_ignore_label
188
+ )
189
+ )
190
+ image_width, image_height = (
191
+ image_width + padding_ltrb[0] + padding_ltrb[2],
192
+ image_height + padding_ltrb[1] + padding_ltrb[3],
193
+ )
194
+
195
+ crop_top = int(round((image_height - crop_height) / 2.0))
196
+ crop_left = int(round((image_width - crop_width) / 2.0))
197
+ transforms.append(
198
+ CropTransform(
199
+ crop_left, crop_top, crop_width, crop_height, image_width, image_height
200
+ )
201
+ )
202
+ return TransformList(transforms)
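
To make the two helpers at the top of this file concrete, here is a small self-contained example (illustrative only, not part of the commit): a toy boolean mask is converted to a tight (x, y, w, h) box and then grown by 50% around its centre.

import numpy as np
from open_vocab_seg.data.augmentations import mask2box, expand_box

mask = np.zeros((10, 10), dtype=bool)
mask[5:8, 2:6] = True                 # a 3x4 foreground block, top-left at (x=2, y=5)

x, y, w, h = mask2box(mask)           # -> (2, 5, 4, 3)
x, y, w, h = expand_box(x, y, w, h, expand_ratio=1.5, max_h=10, max_w=10)
# -> (1, 4, 6, 4): expanded around the centre and clipped to the 10x10 image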
open_vocab_seg/data/build.py ADDED
@@ -0,0 +1,344 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import itertools
5
+ import logging
6
+ import numpy as np
7
+ from collections import Counter
8
+ import torch.utils.data
9
+ from tabulate import tabulate
10
+ from termcolor import colored
11
+
12
+ from detectron2.utils.logger import _log_api_usage, log_first_n
13
+ from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
14
+ import torch.utils.data
15
+ from detectron2.config import configurable
16
+ from detectron2.data.build import (
17
+ build_batch_data_loader,
18
+ trivial_batch_collator,
19
+ load_proposals_into_dataset,
20
+ filter_images_with_only_crowd_annotations,
21
+ filter_images_with_few_keypoints,
22
+ print_instances_class_histogram,
23
+ )
24
+
25
+ from detectron2.data.common import DatasetFromList, MapDataset
26
+ from detectron2.data.dataset_mapper import DatasetMapper
27
+ from detectron2.data.detection_utils import check_metadata_consistency
28
+ from detectron2.data.samplers import (
29
+ InferenceSampler,
30
+ RandomSubsetTrainingSampler,
31
+ RepeatFactorTrainingSampler,
32
+ TrainingSampler,
33
+ )
34
+
35
+ """
36
+ This file contains the default logic to build a dataloader for training or testing.
37
+ """
38
+
39
+ __all__ = [
40
+ "build_detection_train_loader",
41
+ "build_detection_test_loader",
42
+ ]
43
+
44
+
45
+ def print_classification_instances_class_histogram(dataset_dicts, class_names):
46
+ """
47
+ Args:
48
+ dataset_dicts (list[dict]): list of dataset dicts.
49
+ class_names (list[str]): list of class names (zero-indexed).
50
+ """
51
+ num_classes = len(class_names)
52
+ hist_bins = np.arange(num_classes + 1)
53
+ histogram = np.zeros((num_classes,), dtype=int)
54
+ for entry in dataset_dicts:
55
+ classes = np.asarray([entry["category_id"]], dtype=int)
56
+ if len(classes):
57
+ assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
58
+ assert (
59
+ classes.max() < num_classes
60
+ ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
61
+ histogram += np.histogram(classes, bins=hist_bins)[0]
62
+
63
+ N_COLS = min(6, len(class_names) * 2)
64
+
65
+ def short_name(x):
66
+ # make long class names shorter. useful for lvis
67
+ if len(x) > 13:
68
+ return x[:11] + ".."
69
+ return x
70
+
71
+ data = list(
72
+ itertools.chain(
73
+ *[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)]
74
+ )
75
+ )
76
+ total_num_instances = sum(data[1::2])
77
+ data.extend([None] * (N_COLS - (len(data) % N_COLS)))
78
+ if num_classes > 1:
79
+ data.extend(["total", total_num_instances])
80
+ data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
81
+ table = tabulate(
82
+ data,
83
+ headers=["category", "#instances"] * (N_COLS // 2),
84
+ tablefmt="pipe",
85
+ numalign="left",
86
+ stralign="center",
87
+ )
88
+ log_first_n(
89
+ logging.INFO,
90
+ "Distribution of instances among all {} categories:\n".format(num_classes)
91
+ + colored(table, "cyan"),
92
+ key="message",
93
+ )
94
+
95
+
96
+ def wrap_metas(dataset_dict, **kwargs):
97
+ def _assign_attr(data_dict: dict, **kwargs):
98
+ assert not any(
99
+ [key in data_dict for key in kwargs]
100
+ ), "Assigned attributes should not exist in the original sample."
101
+ data_dict.update(kwargs)
102
+ return data_dict
103
+
104
+ return [_assign_attr(sample, meta=kwargs) for sample in dataset_dict]
105
+
106
+
107
+ def get_detection_dataset_dicts(
108
+ names, filter_empty=True, min_keypoints=0, proposal_files=None
109
+ ):
110
+ """
111
+ Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
112
+
113
+ Args:
114
+ names (str or list[str]): a dataset name or a list of dataset names
115
+ filter_empty (bool): whether to filter out images without instance annotations
116
+ min_keypoints (int): filter out images with fewer keypoints than
117
+ `min_keypoints`. Set to 0 to do nothing.
118
+ proposal_files (list[str]): if given, a list of object proposal files
119
+ that match each dataset in `names`.
120
+
121
+ Returns:
122
+ list[dict]: a list of dicts following the standard dataset dict format.
123
+ """
124
+ if isinstance(names, str):
125
+ names = [names]
126
+ assert len(names), names
127
+ dataset_dicts = [
128
+ wrap_metas(DatasetCatalog.get(dataset_name), dataset_name=dataset_name)
129
+ for dataset_name in names
130
+ ]
131
+ for dataset_name, dicts in zip(names, dataset_dicts):
132
+ assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
133
+
134
+ if proposal_files is not None:
135
+ assert len(names) == len(proposal_files)
136
+ # load precomputed proposals from proposal files
137
+ dataset_dicts = [
138
+ load_proposals_into_dataset(dataset_i_dicts, proposal_file)
139
+ for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
140
+ ]
141
+
142
+ dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
143
+
144
+ has_instances = "annotations" in dataset_dicts[0]
145
+ if filter_empty and has_instances:
146
+ dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
147
+ if min_keypoints > 0 and has_instances:
148
+ dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
149
+
150
+ if has_instances:
151
+ try:
152
+ class_names = MetadataCatalog.get(names[0]).thing_classes
153
+ check_metadata_consistency("thing_classes", names)
154
+ print_instances_class_histogram(dataset_dicts, class_names)
155
+ except AttributeError: # class names are not available for this dataset
156
+ pass
157
+
158
+ assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
159
+ return dataset_dicts
160
+
161
+
162
+ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
163
+ if dataset is None:
164
+ dataset = get_detection_dataset_dicts(
165
+ cfg.DATASETS.TRAIN,
166
+ filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
167
+ min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
168
+ if cfg.MODEL.KEYPOINT_ON
169
+ else 0,
170
+ proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
171
+ if cfg.MODEL.LOAD_PROPOSALS
172
+ else None,
173
+ )
174
+ _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])
175
+
176
+ if mapper is None:
177
+ mapper = DatasetMapper(cfg, True)
178
+
179
+ if sampler is None:
180
+ sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
181
+ logger = logging.getLogger(__name__)
182
+ logger.info("Using training sampler {}".format(sampler_name))
183
+ if sampler_name == "TrainingSampler":
184
+ sampler = TrainingSampler(len(dataset))
185
+ elif sampler_name == "RepeatFactorTrainingSampler":
186
+ repeat_factors = (
187
+ RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
188
+ dataset, cfg.DATALOADER.REPEAT_THRESHOLD
189
+ )
190
+ )
191
+ sampler = RepeatFactorTrainingSampler(repeat_factors)
192
+ elif sampler_name == "RandomSubsetTrainingSampler":
193
+ sampler = RandomSubsetTrainingSampler(
194
+ len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
195
+ )
196
+ else:
197
+ raise ValueError("Unknown training sampler: {}".format(sampler_name))
198
+
199
+ return {
200
+ "dataset": dataset,
201
+ "sampler": sampler,
202
+ "mapper": mapper,
203
+ "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
204
+ "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
205
+ "num_workers": cfg.DATALOADER.NUM_WORKERS,
206
+ }
207
+
208
+
209
+ # TODO can allow dataset as an iterable or IterableDataset to make this function more general
210
+ @configurable(from_config=_train_loader_from_config)
211
+ def build_detection_train_loader(
212
+ dataset,
213
+ *,
214
+ mapper,
215
+ sampler=None,
216
+ total_batch_size,
217
+ aspect_ratio_grouping=True,
218
+ num_workers=0,
219
+ ):
220
+ """
221
+ Build a dataloader for object detection with some default features.
222
+ This interface is experimental.
223
+
224
+ Args:
225
+ dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
226
+ or a map-style pytorch dataset. They can be obtained by using
227
+ :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
228
+ mapper (callable): a callable which takes a sample (dict) from dataset and
229
+ returns the format to be consumed by the model.
230
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
231
+ sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
232
+ indices to be applied on ``dataset``. Default to :class:`TrainingSampler`,
233
+ which coordinates an infinite random shuffle sequence across all workers.
234
+ total_batch_size (int): total batch size across all workers. Batching
235
+ simply puts data into a list.
236
+ aspect_ratio_grouping (bool): whether to group images with similar
237
+ aspect ratio for efficiency. When enabled, it requires each
238
+ element in dataset be a dict with keys "width" and "height".
239
+ num_workers (int): number of parallel data loading workers
240
+
241
+ Returns:
242
+ torch.utils.data.DataLoader:
243
+ a dataloader. Each output from it is a ``list[mapped_element]`` of length
244
+ ``total_batch_size / num_workers``, where ``mapped_element`` is produced
245
+ by the ``mapper``.
246
+ """
247
+ if isinstance(dataset, list):
248
+ dataset = DatasetFromList(dataset, copy=False)
249
+ if mapper is not None:
250
+ dataset = MapDataset(dataset, mapper)
251
+ if sampler is None:
252
+ sampler = TrainingSampler(len(dataset))
253
+ assert isinstance(sampler, torch.utils.data.sampler.Sampler)
254
+ return build_batch_data_loader(
255
+ dataset,
256
+ sampler,
257
+ total_batch_size,
258
+ aspect_ratio_grouping=aspect_ratio_grouping,
259
+ num_workers=num_workers,
260
+ )
261
+
262
+
263
+ def _test_loader_from_config(cfg, dataset_name, mapper=None):
264
+ """
265
+ Uses the given `dataset_name` argument (instead of the names in cfg), because the
266
+ standard practice is to evaluate each test set individually (not combining them).
267
+ """
268
+ if isinstance(dataset_name, str):
269
+ dataset_name = [dataset_name]
270
+
271
+ dataset = get_detection_dataset_dicts(
272
+ dataset_name,
273
+ filter_empty=False,
274
+ proposal_files=[
275
+ cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)]
276
+ for x in dataset_name
277
+ ]
278
+ if cfg.MODEL.LOAD_PROPOSALS
279
+ else None,
280
+ )
281
+ if mapper is None:
282
+ mapper = DatasetMapper(cfg, False)
283
+ return {
284
+ "dataset": dataset,
285
+ "mapper": mapper,
286
+ "num_workers": 0,
287
+ "samples_per_gpu": cfg.SOLVER.TEST_IMS_PER_BATCH,
288
+ }
289
+
290
+
291
+ @configurable(from_config=_test_loader_from_config)
292
+ def build_detection_test_loader(
293
+ dataset, *, mapper, sampler=None, num_workers=0, samples_per_gpu=1
294
+ ):
295
+ """
296
+ Similar to `build_detection_train_loader`, but uses a batch size of 1,
297
+ and :class:`InferenceSampler`. This sampler coordinates all workers to
298
+ produce the exact set of all samples.
299
+ This interface is experimental.
300
+
301
+ Args:
302
+ dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
303
+ or a map-style pytorch dataset. They can be obtained by using
304
+ :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
305
+ mapper (callable): a callable which takes a sample (dict) from dataset
306
+ and returns the format to be consumed by the model.
307
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
308
+ sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
309
+ indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
310
+ which splits the dataset across all workers.
311
+ num_workers (int): number of parallel data loading workers
312
+
313
+ Returns:
314
+ DataLoader: a torch DataLoader, that loads the given detection
315
+ dataset, with test-time transformation and batching.
316
+
317
+ Examples:
318
+ ::
319
+ data_loader = build_detection_test_loader(
320
+ DatasetRegistry.get("my_test"),
321
+ mapper=DatasetMapper(...))
322
+
323
+ # or, instantiate with a CfgNode:
324
+ data_loader = build_detection_test_loader(cfg, "my_test")
325
+ """
326
+ if isinstance(dataset, list):
327
+ dataset = DatasetFromList(dataset, copy=False)
328
+ if mapper is not None:
329
+ dataset = MapDataset(dataset, mapper)
330
+ if sampler is None:
331
+ sampler = InferenceSampler(len(dataset))
332
+ # Always use 1 image per worker during inference since this is the
333
+ # standard when reporting inference time in papers.
334
+ batch_sampler = torch.utils.data.sampler.BatchSampler(
335
+ sampler, samples_per_gpu, drop_last=False
336
+ )
337
+ data_loader = torch.utils.data.DataLoader(
338
+ dataset,
339
+ num_workers=num_workers,
340
+ batch_sampler=batch_sampler,
341
+ collate_fn=trivial_batch_collator,
342
+ )
343
+ return data_loader
344
+
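
The main local addition over the stock Detectron2 builders in this file is wrap_metas, which tags every dataset dict with the dataset it came from before the per-dataset lists are merged. A tiny illustrative example (assumes detectron2 is installed so the module imports cleanly; not part of the commit):

from open_vocab_seg.data.build import wrap_metas

dicts = [{"file_name": "a.jpg"}, {"file_name": "b.jpg"}]
tagged = wrap_metas(dicts, dataset_name="my_toy_dataset")
print(tagged[0])
# {'file_name': 'a.jpg', 'meta': {'dataset_name': 'my_toy_dataset'}}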
open_vocab_seg/data/dataset_mappers/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+ from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper
open_vocab_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py ADDED
@@ -0,0 +1,208 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import copy
5
+ import logging
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.nn import functional as F
10
+
11
+ from detectron2.config import configurable
12
+ from detectron2.data import MetadataCatalog
13
+ from detectron2.data import detection_utils as utils
14
+ from detectron2.data import transforms as T
15
+ from detectron2.projects.point_rend import ColorAugSSDTransform
16
+ from detectron2.structures import BitMasks, Instances
17
+
18
+ __all__ = ["MaskFormerSemanticDatasetMapper"]
19
+
20
+
21
+ class MaskFormerSemanticDatasetMapper:
22
+ """
23
+ A callable which takes a dataset dict in Detectron2 Dataset format,
24
+ and maps it into a format used by MaskFormer for semantic segmentation.
25
+
26
+ The callable currently does the following:
27
+
28
+ 1. Reads the image from "file_name"
29
+ 2. Applies geometric transforms to the image and annotation
30
+ 3. Finds and applies suitable cropping to the image and annotation
31
+ 4. Converts the image and annotation to Tensors
32
+ """
33
+
34
+ @configurable
35
+ def __init__(
36
+ self,
37
+ is_train=True,
38
+ *,
39
+ augmentations,
40
+ image_format,
41
+ ignore_label,
42
+ size_divisibility,
43
+ ):
44
+ """
45
+ NOTE: this interface is experimental.
46
+ Args:
47
+ is_train: for training or inference
48
+ augmentations: a list of augmentations or deterministic transforms to apply
49
+ image_format: an image format supported by :func:`detection_utils.read_image`.
50
+ ignore_label: the label that is ignored during evaluation
51
+ size_divisibility: pad image size to be divisible by this value
52
+ """
53
+ self.is_train = is_train
54
+ self.tfm_gens = augmentations
55
+ self.img_format = image_format
56
+ self.ignore_label = ignore_label
57
+ self.size_divisibility = size_divisibility
58
+
59
+ logger = logging.getLogger(__name__)
60
+ mode = "training" if is_train else "inference"
61
+ logger.info(
62
+ f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}"
63
+ )
64
+
65
+ @classmethod
66
+ def from_config(cls, cfg, is_train=True):
67
+ # Build augmentation
68
+ if is_train:
69
+ augs = [
70
+ T.ResizeShortestEdge(
71
+ cfg.INPUT.MIN_SIZE_TRAIN,
72
+ cfg.INPUT.MAX_SIZE_TRAIN,
73
+ cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
74
+ )
75
+ ]
76
+ if cfg.INPUT.CROP.ENABLED:
77
+ augs.append(
78
+ T.RandomCrop_CategoryAreaConstraint(
79
+ cfg.INPUT.CROP.TYPE,
80
+ cfg.INPUT.CROP.SIZE,
81
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
82
+ cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
83
+ )
84
+ )
85
+ if cfg.INPUT.COLOR_AUG_SSD:
86
+ augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
87
+ augs.append(T.RandomFlip())
88
+
89
+ # Assume always applies to the training set.
90
+ dataset_names = cfg.DATASETS.TRAIN
91
+ else:
92
+ min_size = cfg.INPUT.MIN_SIZE_TEST
93
+ max_size = cfg.INPUT.MAX_SIZE_TEST
94
+ sample_style = "choice"
95
+ augs = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
96
+ dataset_names = cfg.DATASETS.TEST
97
+ meta = MetadataCatalog.get(dataset_names[0])
98
+ ignore_label = meta.ignore_label
99
+
100
+ ret = {
101
+ "is_train": is_train,
102
+ "augmentations": augs,
103
+ "image_format": cfg.INPUT.FORMAT,
104
+ "ignore_label": ignore_label,
105
+ "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY if is_train else -1,
106
+ }
107
+ return ret
108
+
109
+ def __call__(self, dataset_dict):
110
+ """
111
+ Args:
112
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
113
+
114
+ Returns:
115
+ dict: a format that builtin models in detectron2 accept
116
+ """
117
+ # assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!"
118
+
119
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
120
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
121
+ utils.check_image_size(dataset_dict, image)
122
+
123
+ if "sem_seg_file_name" in dataset_dict:
124
+ # PyTorch transformation not implemented for uint16, so converting it to double first
125
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype(
126
+ "double"
127
+ )
128
+ else:
129
+ sem_seg_gt = None
130
+
131
+ if sem_seg_gt is None:
132
+ raise ValueError(
133
+ "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
134
+ dataset_dict["file_name"]
135
+ )
136
+ )
137
+
138
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
139
+ aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
140
+ image = aug_input.image
141
+ sem_seg_gt = aug_input.sem_seg
142
+
143
+ # Pad image and segmentation label here!
144
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
145
+ if sem_seg_gt is not None:
146
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
147
+
148
+ if self.size_divisibility > 0:
149
+ image_size = (image.shape[-2], image.shape[-1])
150
+ padding_size = [
151
+ 0,
152
+ self.size_divisibility - image_size[1],
153
+ 0,
154
+ self.size_divisibility - image_size[0],
155
+ ]
156
+ image = F.pad(image, padding_size, value=128).contiguous()
157
+ if sem_seg_gt is not None:
158
+ sem_seg_gt = F.pad(
159
+ sem_seg_gt, padding_size, value=self.ignore_label
160
+ ).contiguous()
161
+
162
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
163
+
164
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
165
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
166
+ # Therefore it's important to use torch.Tensor.
167
+ dataset_dict["image"] = image
168
+
169
+ if sem_seg_gt is not None:
170
+ dataset_dict["sem_seg"] = sem_seg_gt.long()
171
+
172
+ if "annotations" in dataset_dict:
173
+ raise ValueError(
174
+ "Semantic segmentation dataset should not have 'annotations'."
175
+ )
176
+
177
+ # Prepare per-category binary masks
178
+ if sem_seg_gt is not None:
179
+ sem_seg_gt = sem_seg_gt.numpy()
180
+ instances = Instances(image_shape)
181
+ classes = np.unique(sem_seg_gt)
182
+ # remove ignored region
183
+ classes = classes[classes != self.ignore_label]
184
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
185
+
186
+ masks = []
187
+ for class_id in classes:
188
+ masks.append(sem_seg_gt == class_id)
189
+
190
+ if len(masks) == 0:
191
+ # Some image does not have annotation (all ignored)
192
+ instances.gt_masks = torch.zeros(
193
+ (0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1])
194
+ )
195
+ else:
196
+ masks = BitMasks(
197
+ torch.stack(
198
+ [
199
+ torch.from_numpy(np.ascontiguousarray(x.copy()))
200
+ for x in masks
201
+ ]
202
+ )
203
+ )
204
+ instances.gt_masks = masks.tensor
205
+
206
+ dataset_dict["instances"] = instances
207
+
208
+ return dataset_dict
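
The size-divisibility padding near the end of __call__ is easy to check in isolation. The sketch below (illustrative, not part of the commit) applies the same F.pad rule to a toy image/label pair; in the mapper, size_divisibility comes from cfg.INPUT.SIZE_DIVISIBILITY (-1 disables padding), and 128/ignore_label are the pad values used above.

import torch
import torch.nn.functional as F

size_divisibility = 128            # illustrative target size
ignore_label = 255

image = torch.zeros(3, 100, 90)    # C x H x W, e.g. after cropping
sem_seg = torch.zeros(100, 90).long()

# F.pad takes (left, right, top, bottom) for the last two dims,
# matching padding_size = [0, div - W, 0, div - H] in the mapper.
padding_size = [0, size_divisibility - image.shape[-1],
                0, size_divisibility - image.shape[-2]]
image = F.pad(image, padding_size, value=128)
sem_seg = F.pad(sem_seg, padding_size, value=ignore_label)
print(image.shape, sem_seg.shape)  # torch.Size([3, 128, 128]) torch.Size([128, 128])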
open_vocab_seg/data/datasets/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from . import register_coco_stuff, register_voc_seg
+ from . import register_cc3m
+ from . import register_ade20k_full
+ from . import register_pascal_context
open_vocab_seg/data/datasets/csv_data.py ADDED
@@ -0,0 +1,459 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import ast
3
+ import json
4
+ import logging
5
+ import math
6
+ import os
7
+ import random
8
+ import sys
9
+ import time
10
+ from dataclasses import dataclass
11
+ from multiprocessing import Value
12
+
13
+ import braceexpand
14
+ import numpy as np
15
+ import pandas as pd
16
+ import torch
17
+ import torchvision.datasets as datasets
18
+ import webdataset as wds
19
+ from PIL import Image
20
+ from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
21
+ from torch.utils.data.distributed import DistributedSampler
22
+ from webdataset.filters import _shuffle
23
+ from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
24
+
25
+ try:
26
+ import horovod.torch as hvd
27
+ except ImportError:
28
+ hvd = None
29
+
30
+ from clip import tokenize
31
+
32
+
33
+ class CsvDataset(Dataset):
34
+ def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
35
+ logging.debug(f'Loading csv data from {input_filename}.')
36
+ df = pd.read_csv(input_filename, sep=sep)
37
+
38
+ self.images = df[img_key].tolist()
39
+ self.captions = df[caption_key].tolist()
40
+ self.transforms = transforms
41
+ logging.debug('Done loading data.')
42
+
43
+ def __len__(self):
44
+ return len(self.captions)
45
+
46
+ def __getitem__(self, idx):
47
+ images = self.transforms(Image.open(str(self.images[idx])))
48
+ texts = tokenize([str(self.captions[idx])])[0]
49
+ return images, texts
50
+
51
+
52
+ class SharedEpoch:
53
+ def __init__(self, epoch: int = 0):
54
+ self.shared_epoch = Value('i', epoch)
55
+
56
+ def set_value(self, epoch):
57
+ self.shared_epoch.value = epoch
58
+
59
+ def get_value(self):
60
+ return self.shared_epoch.value
61
+
62
+
63
+ @dataclass
64
+ class DataInfo:
65
+ dataloader: DataLoader
66
+ sampler: DistributedSampler = None
67
+ shared_epoch: SharedEpoch = None
68
+
69
+ def set_epoch(self, epoch):
70
+ if self.shared_epoch is not None:
71
+ self.shared_epoch.set_value(epoch)
72
+ if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
73
+ self.sampler.set_epoch(epoch)
74
+
75
+
76
+ def preprocess_txt(text):
77
+ return tokenize([str(text)])[0]
78
+
79
+
80
+ def get_dataset_size(shards):
81
+ shards_list = list(braceexpand.braceexpand(shards))
82
+ dir_path = os.path.dirname(shards)
83
+ sizes_filename = os.path.join(dir_path, 'sizes.json')
84
+ len_filename = os.path.join(dir_path, '__len__')
85
+ if os.path.exists(sizes_filename):
86
+ sizes = json.load(open(sizes_filename, 'r'))
87
+ total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
88
+ elif os.path.exists(len_filename):
89
+ # FIXME this used to be eval(open(...)) but that seemed rather unsafe
90
+ total_size = ast.literal_eval(open(len_filename, 'r').read())
91
+ else:
92
+ total_size = None # num samples undefined
93
+ # some common dataset sizes (at time of authors last download)
94
+ # CC3M (train): 2905954
95
+ # CC12M: 10968539
96
+ # LAION-400M: 407332084
97
+ # LAION-2B (english): 2170337258
98
+ num_shards = len(shards_list)
99
+ return total_size, num_shards
100
+
101
+
102
+ def get_imagenet(args, preprocess_fns, split):
103
+ assert split in ["train", "val", "v2"]
104
+ is_train = split == "train"
105
+ preprocess_train, preprocess_val = preprocess_fns
106
+
107
+ if split == "v2":
108
+ from imagenetv2_pytorch import ImageNetV2Dataset
109
+ dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
110
+ else:
111
+ if is_train:
112
+ data_path = args.imagenet_train
113
+ preprocess_fn = preprocess_train
114
+ else:
115
+ data_path = args.imagenet_val
116
+ preprocess_fn = preprocess_val
117
+ assert data_path
118
+
119
+ dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
120
+
121
+ if is_train:
122
+ idxs = np.zeros(len(dataset.targets))
123
+ target_array = np.array(dataset.targets)
124
+ k = 50
125
+ for c in range(1000):
126
+ m = target_array == c
127
+ n = len(idxs[m])
128
+ arr = np.zeros(n)
129
+ arr[:k] = 1
130
+ np.random.shuffle(arr)
131
+ idxs[m] = arr
132
+
133
+ idxs = idxs.astype('int')
134
+ sampler = SubsetRandomSampler(np.where(idxs)[0])
135
+ else:
136
+ sampler = None
137
+
138
+ dataloader = torch.utils.data.DataLoader(
139
+ dataset,
140
+ batch_size=args.batch_size,
141
+ num_workers=args.workers,
142
+ sampler=sampler,
143
+ )
144
+
145
+ return DataInfo(dataloader=dataloader, sampler=sampler)
146
+
147
+
148
+ def count_samples(dataloader):
149
+ os.environ["WDS_EPOCH"] = "0"
150
+ n_elements, n_batches = 0, 0
151
+ for images, texts in dataloader:
152
+ n_batches += 1
153
+ n_elements += len(images)
154
+ assert len(images) == len(texts)
155
+ return n_elements, n_batches
156
+
157
+
158
+ def filter_no_caption(sample):
159
+ return 'txt' in sample
160
+
161
+
162
+ def log_and_continue(exn):
163
+ """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
164
+ logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
165
+ return True
166
+
167
+
168
+ def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
169
+ """Return function over iterator that groups key, value pairs into samples.
170
+
171
+ :param keys: function that splits the key into key and extension (base_plus_ext)
172
+ :param lcase: convert suffixes to lower case (Default value = True)
173
+ """
174
+ current_sample = None
175
+ for filesample in data:
176
+ assert isinstance(filesample, dict)
177
+ fname, value = filesample["fname"], filesample["data"]
178
+ prefix, suffix = keys(fname)
179
+ if prefix is None:
180
+ continue
181
+ if lcase:
182
+ suffix = suffix.lower()
183
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
184
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
185
+ # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
186
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
187
+ if valid_sample(current_sample):
188
+ yield current_sample
189
+ current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
190
+ if suffixes is None or suffix in suffixes:
191
+ current_sample[suffix] = value
192
+ if valid_sample(current_sample):
193
+ yield current_sample
194
+
195
+
196
+ def tarfile_to_samples_nothrow(src, handler=log_and_continue):
197
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
198
+ streams = url_opener(src, handler=handler)
199
+ files = tar_file_expander(streams, handler=handler)
200
+ samples = group_by_keys_nothrow(files, handler=handler)
201
+ return samples
202
+
203
+
204
+ def pytorch_worker_seed():
205
+ """get dataloader worker seed from pytorch"""
206
+ worker_info = get_worker_info()
207
+ if worker_info is not None:
208
+ # favour the seed already created for pytorch dataloader workers if it exists
209
+ return worker_info.seed
210
+ # fallback to wds rank based seed
211
+ return wds.utils.pytorch_worker_seed()
212
+
213
+
214
+ _SHARD_SHUFFLE_SIZE = 2000
215
+ _SHARD_SHUFFLE_INITIAL = 500
216
+ _SAMPLE_SHUFFLE_SIZE = 5000
217
+ _SAMPLE_SHUFFLE_INITIAL = 1000
218
+
219
+
220
+ class detshuffle2(wds.PipelineStage):
221
+ def __init__(
222
+ self,
223
+ bufsize=1000,
224
+ initial=100,
225
+ seed=0,
226
+ epoch=-1,
227
+ ):
228
+ self.bufsize = bufsize
229
+ self.initial = initial
230
+ self.seed = seed
231
+ self.epoch = epoch
232
+
233
+ def run(self, src):
234
+ if isinstance(self.epoch, SharedEpoch):
235
+ epoch = self.epoch.get_value()
236
+ else:
237
+ # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
238
+ # situation as different workers may wrap at different times (or not at all).
239
+ self.epoch += 1
240
+ epoch = self.epoch
241
+ rng = random.Random()
242
+ if self.seed < 0:
243
+ seed = pytorch_worker_seed() + epoch
244
+ else:
245
+ seed = self.seed + epoch
246
+ rng.seed(seed)
247
+ return _shuffle(src, self.bufsize, self.initial, rng)
248
+
249
+
250
+ class ResampledShards2(IterableDataset):
251
+ """An iterable dataset yielding a list of urls."""
252
+
253
+ def __init__(
254
+ self,
255
+ urls,
256
+ nshards=sys.maxsize,
257
+ worker_seed=None,
258
+ deterministic=False,
259
+ epoch=-1,
260
+ ):
261
+ """Sample shards from the shard list with replacement.
262
+
263
+ :param urls: a list of URLs as a Python list or brace notation string
264
+ """
265
+ super().__init__()
266
+ urls = wds.shardlists.expand_urls(urls)
267
+ self.urls = urls
268
+ assert isinstance(self.urls[0], str)
269
+ self.nshards = nshards
270
+ self.rng = random.Random()
271
+ self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
272
+ self.deterministic = deterministic
273
+ self.epoch = epoch
274
+
275
+ def __iter__(self):
276
+ """Return an iterator over the shards."""
277
+ if isinstance(self.epoch, SharedEpoch):
278
+ epoch = self.epoch.get_value()
279
+ else:
280
+ # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
281
+ # situation as different workers may wrap at different times (or not at all).
282
+ self.epoch += 1
283
+ epoch = self.epoch
284
+ if self.deterministic:
285
+ # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed
286
+ self.rng.seed(self.worker_seed() + epoch)
287
+ for _ in range(self.nshards):
288
+ yield dict(url=self.rng.choice(self.urls))
289
+
290
+
291
+ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False):
292
+ input_shards = args.train_data if is_train else args.val_data
293
+ assert input_shards is not None
294
+ resampled = getattr(args, 'dataset_resampled', False) and is_train
295
+
296
+ num_samples, num_shards = get_dataset_size(input_shards)
297
+ if not num_samples:
298
+ if is_train:
299
+ num_samples = args.train_num_samples
300
+ if not num_samples:
301
+ raise RuntimeError(
302
+ 'Currently, number of dataset samples must be specified for training dataset. '
303
+ 'Please specify via `--train-num-samples` if no dataset length info present.')
304
+ else:
305
+ num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified
306
+
307
+ shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc
308
+ if resampled:
309
+ pipeline = [ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)]
310
+ else:
311
+ pipeline = [wds.SimpleShardList(input_shards)]
312
+
313
+ # at this point we have an iterator over all the shards
314
+ if is_train:
315
+ if not resampled:
316
+ pipeline.extend([
317
+ detshuffle2(
318
+ bufsize=_SHARD_SHUFFLE_SIZE,
319
+ initial=_SHARD_SHUFFLE_INITIAL,
320
+ seed=args.seed,
321
+ epoch=shared_epoch,
322
+ ),
323
+ wds.split_by_node,
324
+ wds.split_by_worker,
325
+ ])
326
+ pipeline.extend([
327
+ # at this point, we have an iterator over the shards assigned to each worker at each node
328
+ tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue),
329
+ wds.shuffle(
330
+ bufsize=_SAMPLE_SHUFFLE_SIZE,
331
+ initial=_SAMPLE_SHUFFLE_INITIAL,
332
+ ),
333
+ ])
334
+ else:
335
+ pipeline.extend([
336
+ wds.split_by_worker,
337
+ # at this point, we have an iterator over the shards assigned to each worker
338
+ wds.tarfile_to_samples(handler=log_and_continue),
339
+ ])
340
+ pipeline.extend([
341
+ wds.select(filter_no_caption),
342
+ wds.decode("pilrgb", handler=log_and_continue),
343
+ wds.rename(image="jpg;png", text="txt"),
344
+ wds.map_dict(image=preprocess_img, text=preprocess_txt),
345
+ wds.to_tuple("image", "text"),
346
+ wds.batched(args.batch_size, partial=not is_train),
347
+ ])
348
+
349
+ dataset = wds.DataPipeline(*pipeline)
350
+ if is_train:
351
+ if not resampled:
352
+ assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
353
+ # roll over and repeat a few samples to get same number of full batches on each node
354
+ round_fn = math.floor if floor else math.ceil
355
+ global_batch_size = args.batch_size * args.world_size
356
+ num_batches = round_fn(num_samples / global_batch_size)
357
+ num_workers = max(1, args.workers)
358
+ num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
359
+ num_batches = num_worker_batches * num_workers
360
+ num_samples = num_batches * global_batch_size
361
+ dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this
362
+ else:
363
+ # last batches are partial, eval is done on single (master) node
364
+ num_batches = math.ceil(num_samples / args.batch_size)
365
+
366
+ dataloader = wds.WebLoader(
367
+ dataset,
368
+ batch_size=None,
369
+ shuffle=False,
370
+ num_workers=args.workers,
371
+ persistent_workers=True,
372
+ )
373
+
374
+ # FIXME not clear which approach is better, with_epoch before vs after dataloader?
375
+ # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
376
+ # if is_train:
377
+ # # roll over and repeat a few samples to get same number of full batches on each node
378
+ # global_batch_size = args.batch_size * args.world_size
379
+ # num_batches = math.ceil(num_samples / global_batch_size)
380
+ # num_workers = max(1, args.workers)
381
+ # num_batches = math.ceil(num_batches / num_workers) * num_workers
382
+ # num_samples = num_batches * global_batch_size
383
+ # dataloader = dataloader.with_epoch(num_batches)
384
+ # else:
385
+ # # last batches are partial, eval is done on single (master) node
386
+ # num_batches = math.ceil(num_samples / args.batch_size)
387
+
388
+ # add meta-data to dataloader instance for convenience
389
+ dataloader.num_batches = num_batches
390
+ dataloader.num_samples = num_samples
391
+
392
+ return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
393
+
394
+
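Because a WebDataset pipeline has no intrinsic length, get_wds_dataset above fixes the epoch size by rounding the sample count up to whole per-worker batches and handing that to `with_epoch`. A small sketch of just that arithmetic (hypothetical helper, not in the repo):

    import math

    def worker_epoch_size(num_samples, batch_size, world_size, workers, floor=False):
        # Mirrors the rounding in get_wds_dataset: global batches first,
        # then batches per dataloader worker, then the padded sample count.
        round_fn = math.floor if floor else math.ceil
        global_batch_size = batch_size * world_size
        num_batches = round_fn(num_samples / global_batch_size)
        num_workers = max(1, workers)
        num_worker_batches = round_fn(num_batches / num_workers)
        num_batches = num_worker_batches * num_workers
        return num_worker_batches, num_batches, num_batches * global_batch_size

    # e.g. 10_000 samples, batch 64, 2 GPUs, 4 workers -> 20 batches per worker,
    # 80 global batches, 10_240 (slightly padded) samples per epoch.
    print(worker_epoch_size(10_000, 64, 2, 4))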
395
+ def get_csv_dataset(args, preprocess_fn, is_train, epoch=0):
396
+ input_filename = args.train_data if is_train else args.val_data
397
+ assert input_filename
398
+ dataset = CsvDataset(
399
+ input_filename,
400
+ preprocess_fn,
401
+ img_key=args.csv_img_key,
402
+ caption_key=args.csv_caption_key,
403
+ sep=args.csv_separator)
404
+ num_samples = len(dataset)
405
+ sampler = DistributedSampler(dataset) if args.distributed and is_train else None
406
+ shuffle = is_train and sampler is None
407
+
408
+ dataloader = DataLoader(
409
+ dataset,
410
+ batch_size=args.batch_size,
411
+ shuffle=shuffle,
412
+ num_workers=args.workers,
413
+ pin_memory=True,
414
+ sampler=sampler,
415
+ drop_last=is_train,
416
+ )
417
+ dataloader.num_samples = num_samples
418
+ dataloader.num_batches = len(dataloader)
419
+
420
+ return DataInfo(dataloader, sampler)
421
+
422
+
423
+ def get_dataset_fn(data_path, dataset_type):
424
+ if dataset_type == "webdataset":
425
+ return get_wds_dataset
426
+ elif dataset_type == "csv":
427
+ return get_csv_dataset
428
+ elif dataset_type == "auto":
429
+ ext = data_path.split('.')[-1]
430
+ if ext in ['csv', 'tsv']:
431
+ return get_csv_dataset
432
+ elif ext in ['tar']:
433
+ return get_wds_dataset
434
+ else:
435
+ raise ValueError(
436
+ f"Tried to figure out dataset type, but failed for extention {ext}.")
437
+ else:
438
+ raise ValueError(f"Unsupported dataset type: {dataset_type}")
439
+
440
+
441
+ def get_data(args, preprocess_fns, epoch=0):
442
+ preprocess_train, preprocess_val = preprocess_fns
443
+ data = {}
444
+
445
+ if args.train_data:
446
+ data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
447
+ args, preprocess_train, is_train=True, epoch=epoch)
448
+
449
+ if args.val_data:
450
+ data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
451
+ args, preprocess_val, is_train=False)
452
+
453
+ if args.imagenet_val is not None:
454
+ data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")
455
+
456
+ if args.imagenet_v2 is not None:
457
+ data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")
458
+
459
+ return data
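get_dataset_fn dispatches on the dataset-type argument (or on the file extension when it is "auto"), and get_data wires the chosen factory to the preprocessing transforms. A rough usage sketch, assuming this module is importable as `open_vocab_seg.data.datasets.csv_data`; the path, column names, and no-op transform below are placeholders, not values from the repo:

    from types import SimpleNamespace
    # Assumed import path for this file; adjust to wherever csv_data.py lives.
    from open_vocab_seg.data.datasets.csv_data import get_data

    identity = lambda x: x  # stand-in for the real CLIP image transform
    args = SimpleNamespace(
        train_data="train_captions.csv",      # placeholder CSV path
        val_data=None,
        dataset_type="auto",                  # ".csv" extension -> get_csv_dataset
        csv_img_key="filepath",               # column names are illustrative
        csv_caption_key="title",
        csv_separator="\t",
        batch_size=64, workers=4, distributed=False,
        imagenet_val=None, imagenet_v2=None,
    )
    data = get_data(args, (identity, identity), epoch=0)
    print(data["train"].dataloader.num_batches)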
open_vocab_seg/data/datasets/register_ade20k_full.py ADDED
@@ -0,0 +1,995 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import os
3
+
4
+ from detectron2.data import DatasetCatalog, MetadataCatalog
5
+ from detectron2.data.datasets import load_sem_seg
6
+
7
+ ADE20K_SEM_SEG_FULL_CATEGORIES = [
8
+ {"name": "wall", "id": 2978, "trainId": 0},
9
+ {"name": "building, edifice", "id": 312, "trainId": 1},
10
+ {"name": "sky", "id": 2420, "trainId": 2},
11
+ {"name": "tree", "id": 2855, "trainId": 3},
12
+ {"name": "road, route", "id": 2131, "trainId": 4},
13
+ {"name": "floor, flooring", "id": 976, "trainId": 5},
14
+ {"name": "ceiling", "id": 447, "trainId": 6},
15
+ {"name": "bed", "id": 165, "trainId": 7},
16
+ {"name": "sidewalk, pavement", "id": 2377, "trainId": 8},
17
+ {"name": "earth, ground", "id": 838, "trainId": 9},
18
+ {"name": "cabinet", "id": 350, "trainId": 10},
19
+ {
20
+ "name": "person, individual, someone, somebody, mortal, soul",
21
+ "id": 1831,
22
+ "trainId": 11,
23
+ },
24
+ {"name": "grass", "id": 1125, "trainId": 12},
25
+ {"name": "windowpane, window", "id": 3055, "trainId": 13},
26
+ {"name": "car, auto, automobile, machine, motorcar", "id": 401, "trainId": 14},
27
+ {"name": "mountain, mount", "id": 1610, "trainId": 15},
28
+ {"name": "plant, flora, plant life", "id": 1910, "trainId": 16},
29
+ {"name": "table", "id": 2684, "trainId": 17},
30
+ {"name": "chair", "id": 471, "trainId": 18},
31
+ {"name": "curtain, drape, drapery, mantle, pall", "id": 687, "trainId": 19},
32
+ {"name": "door", "id": 774, "trainId": 20},
33
+ {"name": "sofa, couch, lounge", "id": 2473, "trainId": 21},
34
+ {"name": "sea", "id": 2264, "trainId": 22},
35
+ {"name": "painting, picture", "id": 1735, "trainId": 23},
36
+ {"name": "water", "id": 2994, "trainId": 24},
37
+ {"name": "mirror", "id": 1564, "trainId": 25},
38
+ {"name": "house", "id": 1276, "trainId": 26},
39
+ {"name": "rug, carpet, carpeting", "id": 2178, "trainId": 27},
40
+ {"name": "shelf", "id": 2329, "trainId": 28},
41
+ {"name": "armchair", "id": 57, "trainId": 29},
42
+ {"name": "fence, fencing", "id": 907, "trainId": 30},
43
+ {"name": "field", "id": 913, "trainId": 31},
44
+ {"name": "lamp", "id": 1395, "trainId": 32},
45
+ {"name": "rock, stone", "id": 2138, "trainId": 33},
46
+ {"name": "seat", "id": 2272, "trainId": 34},
47
+ {"name": "river", "id": 2128, "trainId": 35},
48
+ {"name": "desk", "id": 724, "trainId": 36},
49
+ {"name": "bathtub, bathing tub, bath, tub", "id": 155, "trainId": 37},
50
+ {"name": "railing, rail", "id": 2053, "trainId": 38},
51
+ {"name": "signboard, sign", "id": 2380, "trainId": 39},
52
+ {"name": "cushion", "id": 689, "trainId": 40},
53
+ {"name": "path", "id": 1788, "trainId": 41},
54
+ {"name": "work surface", "id": 3087, "trainId": 42},
55
+ {"name": "stairs, steps", "id": 2530, "trainId": 43},
56
+ {"name": "column, pillar", "id": 581, "trainId": 44},
57
+ {"name": "sink", "id": 2388, "trainId": 45},
58
+ {"name": "wardrobe, closet, press", "id": 2985, "trainId": 46},
59
+ {"name": "snow", "id": 2454, "trainId": 47},
60
+ {"name": "refrigerator, icebox", "id": 2096, "trainId": 48},
61
+ {"name": "base, pedestal, stand", "id": 137, "trainId": 49},
62
+ {"name": "bridge, span", "id": 294, "trainId": 50},
63
+ {"name": "blind, screen", "id": 212, "trainId": 51},
64
+ {"name": "runway", "id": 2185, "trainId": 52},
65
+ {"name": "cliff, drop, drop-off", "id": 524, "trainId": 53},
66
+ {"name": "sand", "id": 2212, "trainId": 54},
67
+ {"name": "fireplace, hearth, open fireplace", "id": 943, "trainId": 55},
68
+ {"name": "pillow", "id": 1869, "trainId": 56},
69
+ {"name": "screen door, screen", "id": 2251, "trainId": 57},
70
+ {
71
+ "name": "toilet, can, commode, crapper, pot, potty, stool, throne",
72
+ "id": 2793,
73
+ "trainId": 58,
74
+ },
75
+ {"name": "skyscraper", "id": 2423, "trainId": 59},
76
+ {"name": "grandstand, covered stand", "id": 1121, "trainId": 60},
77
+ {"name": "box", "id": 266, "trainId": 61},
78
+ {"name": "pool table, billiard table, snooker table", "id": 1948, "trainId": 62},
79
+ {"name": "palm, palm tree", "id": 1744, "trainId": 63},
80
+ {"name": "double door", "id": 783, "trainId": 64},
81
+ {"name": "coffee table, cocktail table", "id": 571, "trainId": 65},
82
+ {"name": "counter", "id": 627, "trainId": 66},
83
+ {"name": "countertop", "id": 629, "trainId": 67},
84
+ {"name": "chest of drawers, chest, bureau, dresser", "id": 491, "trainId": 68},
85
+ {"name": "kitchen island", "id": 1374, "trainId": 69},
86
+ {"name": "boat", "id": 223, "trainId": 70},
87
+ {"name": "waterfall, falls", "id": 3016, "trainId": 71},
88
+ {
89
+ "name": "stove, kitchen stove, range, kitchen range, cooking stove",
90
+ "id": 2598,
91
+ "trainId": 72,
92
+ },
93
+ {"name": "flower", "id": 978, "trainId": 73},
94
+ {"name": "bookcase", "id": 239, "trainId": 74},
95
+ {"name": "controls", "id": 608, "trainId": 75},
96
+ {"name": "book", "id": 236, "trainId": 76},
97
+ {"name": "stairway, staircase", "id": 2531, "trainId": 77},
98
+ {"name": "streetlight, street lamp", "id": 2616, "trainId": 78},
99
+ {
100
+ "name": "computer, computing machine, computing device, data processor, electronic computer, information processing system",
101
+ "id": 591,
102
+ "trainId": 79,
103
+ },
104
+ {
105
+ "name": "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle",
106
+ "id": 327,
107
+ "trainId": 80,
108
+ },
109
+ {"name": "swivel chair", "id": 2679, "trainId": 81},
110
+ {"name": "light, light source", "id": 1451, "trainId": 82},
111
+ {"name": "bench", "id": 181, "trainId": 83},
112
+ {"name": "case, display case, showcase, vitrine", "id": 420, "trainId": 84},
113
+ {"name": "towel", "id": 2821, "trainId": 85},
114
+ {"name": "fountain", "id": 1023, "trainId": 86},
115
+ {"name": "embankment", "id": 855, "trainId": 87},
116
+ {
117
+ "name": "television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box",
118
+ "id": 2733,
119
+ "trainId": 88,
120
+ },
121
+ {"name": "van", "id": 2928, "trainId": 89},
122
+ {"name": "hill", "id": 1240, "trainId": 90},
123
+ {"name": "awning, sunshade, sunblind", "id": 77, "trainId": 91},
124
+ {"name": "poster, posting, placard, notice, bill, card", "id": 1969, "trainId": 92},
125
+ {"name": "truck, motortruck", "id": 2880, "trainId": 93},
126
+ {"name": "airplane, aeroplane, plane", "id": 14, "trainId": 94},
127
+ {"name": "pole", "id": 1936, "trainId": 95},
128
+ {"name": "tower", "id": 2828, "trainId": 96},
129
+ {"name": "court", "id": 631, "trainId": 97},
130
+ {"name": "ball", "id": 103, "trainId": 98},
131
+ {
132
+ "name": "aircraft carrier, carrier, flattop, attack aircraft carrier",
133
+ "id": 3144,
134
+ "trainId": 99,
135
+ },
136
+ {"name": "buffet, counter, sideboard", "id": 308, "trainId": 100},
137
+ {"name": "hovel, hut, hutch, shack, shanty", "id": 1282, "trainId": 101},
138
+ {"name": "apparel, wearing apparel, dress, clothes", "id": 38, "trainId": 102},
139
+ {"name": "minibike, motorbike", "id": 1563, "trainId": 103},
140
+ {
141
+ "name": "animal, animate being, beast, brute, creature, fauna",
142
+ "id": 29,
143
+ "trainId": 104,
144
+ },
145
+ {"name": "chandelier, pendant, pendent", "id": 480, "trainId": 105},
146
+ {"name": "step, stair", "id": 2569, "trainId": 106},
147
+ {"name": "booth, cubicle, stall, kiosk", "id": 247, "trainId": 107},
148
+ {"name": "bicycle, bike, wheel, cycle", "id": 187, "trainId": 108},
149
+ {"name": "doorframe, doorcase", "id": 778, "trainId": 109},
150
+ {"name": "sconce", "id": 2243, "trainId": 110},
151
+ {"name": "pond", "id": 1941, "trainId": 111},
152
+ {"name": "trade name, brand name, brand, marque", "id": 2833, "trainId": 112},
153
+ {
154
+ "name": "bannister, banister, balustrade, balusters, handrail",
155
+ "id": 120,
156
+ "trainId": 113,
157
+ },
158
+ {"name": "bag", "id": 95, "trainId": 114},
159
+ {"name": "traffic light, traffic signal, stoplight", "id": 2836, "trainId": 115},
160
+ {"name": "gazebo", "id": 1087, "trainId": 116},
161
+ {"name": "escalator, moving staircase, moving stairway", "id": 868, "trainId": 117},
162
+ {"name": "land, ground, soil", "id": 1401, "trainId": 118},
163
+ {"name": "board, plank", "id": 220, "trainId": 119},
164
+ {"name": "arcade machine", "id": 47, "trainId": 120},
165
+ {"name": "eiderdown, duvet, continental quilt", "id": 843, "trainId": 121},
166
+ {"name": "bar", "id": 123, "trainId": 122},
167
+ {"name": "stall, stand, sales booth", "id": 2537, "trainId": 123},
168
+ {"name": "playground", "id": 1927, "trainId": 124},
169
+ {"name": "ship", "id": 2337, "trainId": 125},
170
+ {"name": "ottoman, pouf, pouffe, puff, hassock", "id": 1702, "trainId": 126},
171
+ {
172
+ "name": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
173
+ "id": 64,
174
+ "trainId": 127,
175
+ },
176
+ {"name": "bottle", "id": 249, "trainId": 128},
177
+ {"name": "cradle", "id": 642, "trainId": 129},
178
+ {"name": "pot, flowerpot", "id": 1981, "trainId": 130},
179
+ {
180
+ "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
181
+ "id": 609,
182
+ "trainId": 131,
183
+ },
184
+ {"name": "train, railroad train", "id": 2840, "trainId": 132},
185
+ {"name": "stool", "id": 2586, "trainId": 133},
186
+ {"name": "lake", "id": 1393, "trainId": 134},
187
+ {"name": "tank, storage tank", "id": 2704, "trainId": 135},
188
+ {"name": "ice, water ice", "id": 1304, "trainId": 136},
189
+ {"name": "basket, handbasket", "id": 146, "trainId": 137},
190
+ {"name": "manhole", "id": 1494, "trainId": 138},
191
+ {"name": "tent, collapsible shelter", "id": 2739, "trainId": 139},
192
+ {"name": "canopy", "id": 389, "trainId": 140},
193
+ {"name": "microwave, microwave oven", "id": 1551, "trainId": 141},
194
+ {"name": "barrel, cask", "id": 131, "trainId": 142},
195
+ {"name": "dirt track", "id": 738, "trainId": 143},
196
+ {"name": "beam", "id": 161, "trainId": 144},
197
+ {"name": "dishwasher, dish washer, dishwashing machine", "id": 747, "trainId": 145},
198
+ {"name": "plate", "id": 1919, "trainId": 146},
199
+ {"name": "screen, crt screen", "id": 3109, "trainId": 147},
200
+ {"name": "ruins", "id": 2179, "trainId": 148},
201
+ {"name": "washer, automatic washer, washing machine", "id": 2989, "trainId": 149},
202
+ {"name": "blanket, cover", "id": 206, "trainId": 150},
203
+ {"name": "plaything, toy", "id": 1930, "trainId": 151},
204
+ {"name": "food, solid food", "id": 1002, "trainId": 152},
205
+ {"name": "screen, silver screen, projection screen", "id": 2254, "trainId": 153},
206
+ {"name": "oven", "id": 1708, "trainId": 154},
207
+ {"name": "stage", "id": 2526, "trainId": 155},
208
+ {"name": "beacon, lighthouse, beacon light, pharos", "id": 160, "trainId": 156},
209
+ {"name": "umbrella", "id": 2901, "trainId": 157},
210
+ {"name": "sculpture", "id": 2262, "trainId": 158},
211
+ {"name": "aqueduct", "id": 44, "trainId": 159},
212
+ {"name": "container", "id": 597, "trainId": 160},
213
+ {"name": "scaffolding, staging", "id": 2235, "trainId": 161},
214
+ {"name": "hood, exhaust hood", "id": 1260, "trainId": 162},
215
+ {"name": "curb, curbing, kerb", "id": 682, "trainId": 163},
216
+ {"name": "roller coaster", "id": 2151, "trainId": 164},
217
+ {"name": "horse, equus caballus", "id": 3107, "trainId": 165},
218
+ {"name": "catwalk", "id": 432, "trainId": 166},
219
+ {"name": "glass, drinking glass", "id": 1098, "trainId": 167},
220
+ {"name": "vase", "id": 2932, "trainId": 168},
221
+ {"name": "central reservation", "id": 461, "trainId": 169},
222
+ {"name": "carousel", "id": 410, "trainId": 170},
223
+ {"name": "radiator", "id": 2046, "trainId": 171},
224
+ {"name": "closet", "id": 533, "trainId": 172},
225
+ {"name": "machine", "id": 1481, "trainId": 173},
226
+ {"name": "pier, wharf, wharfage, dock", "id": 1858, "trainId": 174},
227
+ {"name": "fan", "id": 894, "trainId": 175},
228
+ {"name": "inflatable bounce game", "id": 1322, "trainId": 176},
229
+ {"name": "pitch", "id": 1891, "trainId": 177},
230
+ {"name": "paper", "id": 1756, "trainId": 178},
231
+ {"name": "arcade, colonnade", "id": 49, "trainId": 179},
232
+ {"name": "hot tub", "id": 1272, "trainId": 180},
233
+ {"name": "helicopter", "id": 1229, "trainId": 181},
234
+ {"name": "tray", "id": 2850, "trainId": 182},
235
+ {"name": "partition, divider", "id": 1784, "trainId": 183},
236
+ {"name": "vineyard", "id": 2962, "trainId": 184},
237
+ {"name": "bowl", "id": 259, "trainId": 185},
238
+ {"name": "bullring", "id": 319, "trainId": 186},
239
+ {"name": "flag", "id": 954, "trainId": 187},
240
+ {"name": "pot", "id": 1974, "trainId": 188},
241
+ {"name": "footbridge, overcrossing, pedestrian bridge", "id": 1013, "trainId": 189},
242
+ {"name": "shower", "id": 2356, "trainId": 190},
243
+ {
244
+ "name": "bag, traveling bag, travelling bag, grip, suitcase",
245
+ "id": 97,
246
+ "trainId": 191,
247
+ },
248
+ {"name": "bulletin board, notice board", "id": 318, "trainId": 192},
249
+ {"name": "confessional booth", "id": 592, "trainId": 193},
250
+ {"name": "trunk, tree trunk, bole", "id": 2885, "trainId": 194},
251
+ {"name": "forest", "id": 1017, "trainId": 195},
252
+ {"name": "elevator door", "id": 851, "trainId": 196},
253
+ {"name": "laptop, laptop computer", "id": 1407, "trainId": 197},
254
+ {"name": "instrument panel", "id": 1332, "trainId": 198},
255
+ {"name": "bucket, pail", "id": 303, "trainId": 199},
256
+ {"name": "tapestry, tapis", "id": 2714, "trainId": 200},
257
+ {"name": "platform", "id": 1924, "trainId": 201},
258
+ {"name": "jacket", "id": 1346, "trainId": 202},
259
+ {"name": "gate", "id": 1081, "trainId": 203},
260
+ {"name": "monitor, monitoring device", "id": 1583, "trainId": 204},
261
+ {
262
+ "name": "telephone booth, phone booth, call box, telephone box, telephone kiosk",
263
+ "id": 2727,
264
+ "trainId": 205,
265
+ },
266
+ {"name": "spotlight, spot", "id": 2509, "trainId": 206},
267
+ {"name": "ring", "id": 2123, "trainId": 207},
268
+ {"name": "control panel", "id": 602, "trainId": 208},
269
+ {"name": "blackboard, chalkboard", "id": 202, "trainId": 209},
270
+ {"name": "air conditioner, air conditioning", "id": 10, "trainId": 210},
271
+ {"name": "chest", "id": 490, "trainId": 211},
272
+ {"name": "clock", "id": 530, "trainId": 212},
273
+ {"name": "sand dune", "id": 2213, "trainId": 213},
274
+ {"name": "pipe, pipage, piping", "id": 1884, "trainId": 214},
275
+ {"name": "vault", "id": 2934, "trainId": 215},
276
+ {"name": "table football", "id": 2687, "trainId": 216},
277
+ {"name": "cannon", "id": 387, "trainId": 217},
278
+ {"name": "swimming pool, swimming bath, natatorium", "id": 2668, "trainId": 218},
279
+ {"name": "fluorescent, fluorescent fixture", "id": 982, "trainId": 219},
280
+ {"name": "statue", "id": 2547, "trainId": 220},
281
+ {
282
+ "name": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
283
+ "id": 1474,
284
+ "trainId": 221,
285
+ },
286
+ {"name": "exhibitor", "id": 877, "trainId": 222},
287
+ {"name": "ladder", "id": 1391, "trainId": 223},
288
+ {"name": "carport", "id": 414, "trainId": 224},
289
+ {"name": "dam", "id": 698, "trainId": 225},
290
+ {"name": "pulpit", "id": 2019, "trainId": 226},
291
+ {"name": "skylight, fanlight", "id": 2422, "trainId": 227},
292
+ {"name": "water tower", "id": 3010, "trainId": 228},
293
+ {"name": "grill, grille, grillwork", "id": 1139, "trainId": 229},
294
+ {"name": "display board", "id": 753, "trainId": 230},
295
+ {"name": "pane, pane of glass, window glass", "id": 1747, "trainId": 231},
296
+ {"name": "rubbish, trash, scrap", "id": 2175, "trainId": 232},
297
+ {"name": "ice rink", "id": 1301, "trainId": 233},
298
+ {"name": "fruit", "id": 1033, "trainId": 234},
299
+ {"name": "patio", "id": 1789, "trainId": 235},
300
+ {"name": "vending machine", "id": 2939, "trainId": 236},
301
+ {"name": "telephone, phone, telephone set", "id": 2730, "trainId": 237},
302
+ {"name": "net", "id": 1652, "trainId": 238},
303
+ {
304
+ "name": "backpack, back pack, knapsack, packsack, rucksack, haversack",
305
+ "id": 90,
306
+ "trainId": 239,
307
+ },
308
+ {"name": "jar", "id": 1349, "trainId": 240},
309
+ {"name": "track", "id": 2830, "trainId": 241},
310
+ {"name": "magazine", "id": 1485, "trainId": 242},
311
+ {"name": "shutter", "id": 2370, "trainId": 243},
312
+ {"name": "roof", "id": 2155, "trainId": 244},
313
+ {"name": "banner, streamer", "id": 118, "trainId": 245},
314
+ {"name": "landfill", "id": 1402, "trainId": 246},
315
+ {"name": "post", "id": 1957, "trainId": 247},
316
+ {"name": "altarpiece, reredos", "id": 3130, "trainId": 248},
317
+ {"name": "hat, chapeau, lid", "id": 1197, "trainId": 249},
318
+ {"name": "arch, archway", "id": 52, "trainId": 250},
319
+ {"name": "table game", "id": 2688, "trainId": 251},
320
+ {"name": "bag, handbag, pocketbook, purse", "id": 96, "trainId": 252},
321
+ {"name": "document, written document, papers", "id": 762, "trainId": 253},
322
+ {"name": "dome", "id": 772, "trainId": 254},
323
+ {"name": "pier", "id": 1857, "trainId": 255},
324
+ {"name": "shanties", "id": 2315, "trainId": 256},
325
+ {"name": "forecourt", "id": 1016, "trainId": 257},
326
+ {"name": "crane", "id": 643, "trainId": 258},
327
+ {"name": "dog, domestic dog, canis familiaris", "id": 3105, "trainId": 259},
328
+ {"name": "piano, pianoforte, forte-piano", "id": 1849, "trainId": 260},
329
+ {"name": "drawing", "id": 791, "trainId": 261},
330
+ {"name": "cabin", "id": 349, "trainId": 262},
331
+ {
332
+ "name": "ad, advertisement, advertizement, advertising, advertizing, advert",
333
+ "id": 6,
334
+ "trainId": 263,
335
+ },
336
+ {"name": "amphitheater, amphitheatre, coliseum", "id": 3114, "trainId": 264},
337
+ {"name": "monument", "id": 1587, "trainId": 265},
338
+ {"name": "henhouse", "id": 1233, "trainId": 266},
339
+ {"name": "cockpit", "id": 559, "trainId": 267},
340
+ {"name": "heater, warmer", "id": 1223, "trainId": 268},
341
+ {"name": "windmill, aerogenerator, wind generator", "id": 3049, "trainId": 269},
342
+ {"name": "pool", "id": 1943, "trainId": 270},
343
+ {"name": "elevator, lift", "id": 853, "trainId": 271},
344
+ {"name": "decoration, ornament, ornamentation", "id": 709, "trainId": 272},
345
+ {"name": "labyrinth", "id": 1390, "trainId": 273},
346
+ {"name": "text, textual matter", "id": 2748, "trainId": 274},
347
+ {"name": "printer", "id": 2007, "trainId": 275},
348
+ {"name": "mezzanine, first balcony", "id": 1546, "trainId": 276},
349
+ {"name": "mattress", "id": 1513, "trainId": 277},
350
+ {"name": "straw", "id": 2600, "trainId": 278},
351
+ {"name": "stalls", "id": 2538, "trainId": 279},
352
+ {"name": "patio, terrace", "id": 1790, "trainId": 280},
353
+ {"name": "billboard, hoarding", "id": 194, "trainId": 281},
354
+ {"name": "bus stop", "id": 326, "trainId": 282},
355
+ {"name": "trouser, pant", "id": 2877, "trainId": 283},
356
+ {"name": "console table, console", "id": 594, "trainId": 284},
357
+ {"name": "rack", "id": 2036, "trainId": 285},
358
+ {"name": "notebook", "id": 1662, "trainId": 286},
359
+ {"name": "shrine", "id": 2366, "trainId": 287},
360
+ {"name": "pantry", "id": 1754, "trainId": 288},
361
+ {"name": "cart", "id": 418, "trainId": 289},
362
+ {"name": "steam shovel", "id": 2553, "trainId": 290},
363
+ {"name": "porch", "id": 1951, "trainId": 291},
364
+ {"name": "postbox, mailbox, letter box", "id": 1963, "trainId": 292},
365
+ {"name": "figurine, statuette", "id": 918, "trainId": 293},
366
+ {"name": "recycling bin", "id": 2086, "trainId": 294},
367
+ {"name": "folding screen", "id": 997, "trainId": 295},
368
+ {"name": "telescope", "id": 2731, "trainId": 296},
369
+ {"name": "deck chair, beach chair", "id": 704, "trainId": 297},
370
+ {"name": "kennel", "id": 1365, "trainId": 298},
371
+ {"name": "coffee maker", "id": 569, "trainId": 299},
372
+ {"name": "altar, communion table, lord's table", "id": 3108, "trainId": 300},
373
+ {"name": "fish", "id": 948, "trainId": 301},
374
+ {"name": "easel", "id": 839, "trainId": 302},
375
+ {"name": "artificial golf green", "id": 63, "trainId": 303},
376
+ {"name": "iceberg", "id": 1305, "trainId": 304},
377
+ {"name": "candlestick, candle holder", "id": 378, "trainId": 305},
378
+ {"name": "shower stall, shower bath", "id": 2362, "trainId": 306},
379
+ {"name": "television stand", "id": 2734, "trainId": 307},
380
+ {
381
+ "name": "wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle",
382
+ "id": 2982,
383
+ "trainId": 308,
384
+ },
385
+ {"name": "skeleton", "id": 2398, "trainId": 309},
386
+ {"name": "grand piano, grand", "id": 1119, "trainId": 310},
387
+ {"name": "candy, confect", "id": 382, "trainId": 311},
388
+ {"name": "grille door", "id": 1141, "trainId": 312},
389
+ {"name": "pedestal, plinth, footstall", "id": 1805, "trainId": 313},
390
+ {"name": "jersey, t-shirt, tee shirt", "id": 3102, "trainId": 314},
391
+ {"name": "shoe", "id": 2341, "trainId": 315},
392
+ {"name": "gravestone, headstone, tombstone", "id": 1131, "trainId": 316},
393
+ {"name": "shanty", "id": 2316, "trainId": 317},
394
+ {"name": "structure", "id": 2626, "trainId": 318},
395
+ {"name": "rocking chair, rocker", "id": 3104, "trainId": 319},
396
+ {"name": "bird", "id": 198, "trainId": 320},
397
+ {"name": "place mat", "id": 1896, "trainId": 321},
398
+ {"name": "tomb", "id": 2800, "trainId": 322},
399
+ {"name": "big top", "id": 190, "trainId": 323},
400
+ {
401
+ "name": "gas pump, gasoline pump, petrol pump, island dispenser",
402
+ "id": 3131,
403
+ "trainId": 324,
404
+ },
405
+ {"name": "lockers", "id": 1463, "trainId": 325},
406
+ {"name": "cage", "id": 357, "trainId": 326},
407
+ {"name": "finger", "id": 929, "trainId": 327},
408
+ {"name": "bleachers", "id": 209, "trainId": 328},
409
+ {"name": "ferris wheel", "id": 912, "trainId": 329},
410
+ {"name": "hairdresser chair", "id": 1164, "trainId": 330},
411
+ {"name": "mat", "id": 1509, "trainId": 331},
412
+ {"name": "stands", "id": 2539, "trainId": 332},
413
+ {"name": "aquarium, fish tank, marine museum", "id": 3116, "trainId": 333},
414
+ {
415
+ "name": "streetcar, tram, tramcar, trolley, trolley car",
416
+ "id": 2615,
417
+ "trainId": 334,
418
+ },
419
+ {"name": "napkin, table napkin, serviette", "id": 1644, "trainId": 335},
420
+ {"name": "dummy", "id": 818, "trainId": 336},
421
+ {"name": "booklet, brochure, folder, leaflet, pamphlet", "id": 242, "trainId": 337},
422
+ {"name": "sand trap", "id": 2217, "trainId": 338},
423
+ {"name": "shop, store", "id": 2347, "trainId": 339},
424
+ {"name": "table cloth", "id": 2686, "trainId": 340},
425
+ {"name": "service station", "id": 2300, "trainId": 341},
426
+ {"name": "coffin", "id": 572, "trainId": 342},
427
+ {"name": "drawer", "id": 789, "trainId": 343},
428
+ {"name": "cages", "id": 358, "trainId": 344},
429
+ {"name": "slot machine, coin machine", "id": 2443, "trainId": 345},
430
+ {"name": "balcony", "id": 101, "trainId": 346},
431
+ {"name": "volleyball court", "id": 2969, "trainId": 347},
432
+ {"name": "table tennis", "id": 2692, "trainId": 348},
433
+ {"name": "control table", "id": 606, "trainId": 349},
434
+ {"name": "shirt", "id": 2339, "trainId": 350},
435
+ {"name": "merchandise, ware, product", "id": 1533, "trainId": 351},
436
+ {"name": "railway", "id": 2060, "trainId": 352},
437
+ {"name": "parterre", "id": 1782, "trainId": 353},
438
+ {"name": "chimney", "id": 495, "trainId": 354},
439
+ {"name": "can, tin, tin can", "id": 371, "trainId": 355},
440
+ {"name": "tanks", "id": 2707, "trainId": 356},
441
+ {"name": "fabric, cloth, material, textile", "id": 889, "trainId": 357},
442
+ {"name": "alga, algae", "id": 3156, "trainId": 358},
443
+ {"name": "system", "id": 2683, "trainId": 359},
444
+ {"name": "map", "id": 1499, "trainId": 360},
445
+ {"name": "greenhouse", "id": 1135, "trainId": 361},
446
+ {"name": "mug", "id": 1619, "trainId": 362},
447
+ {"name": "barbecue", "id": 125, "trainId": 363},
448
+ {"name": "trailer", "id": 2838, "trainId": 364},
449
+ {
450
+ "name": "toilet tissue, toilet paper, bathroom tissue",
451
+ "id": 2792,
452
+ "trainId": 365,
453
+ },
454
+ {"name": "organ", "id": 1695, "trainId": 366},
455
+ {"name": "dishrag, dishcloth", "id": 746, "trainId": 367},
456
+ {"name": "island", "id": 1343, "trainId": 368},
457
+ {"name": "keyboard", "id": 1370, "trainId": 369},
458
+ {"name": "trench", "id": 2858, "trainId": 370},
459
+ {"name": "basket, basketball hoop, hoop", "id": 145, "trainId": 371},
460
+ {"name": "steering wheel, wheel", "id": 2565, "trainId": 372},
461
+ {"name": "pitcher, ewer", "id": 1892, "trainId": 373},
462
+ {"name": "goal", "id": 1103, "trainId": 374},
463
+ {"name": "bread, breadstuff, staff of life", "id": 286, "trainId": 375},
464
+ {"name": "beds", "id": 170, "trainId": 376},
465
+ {"name": "wood", "id": 3073, "trainId": 377},
466
+ {"name": "file cabinet", "id": 922, "trainId": 378},
467
+ {"name": "newspaper, paper", "id": 1655, "trainId": 379},
468
+ {"name": "motorboat", "id": 1602, "trainId": 380},
469
+ {"name": "rope", "id": 2160, "trainId": 381},
470
+ {"name": "guitar", "id": 1151, "trainId": 382},
471
+ {"name": "rubble", "id": 2176, "trainId": 383},
472
+ {"name": "scarf", "id": 2239, "trainId": 384},
473
+ {"name": "barrels", "id": 132, "trainId": 385},
474
+ {"name": "cap", "id": 394, "trainId": 386},
475
+ {"name": "leaves", "id": 1424, "trainId": 387},
476
+ {"name": "control tower", "id": 607, "trainId": 388},
477
+ {"name": "dashboard", "id": 700, "trainId": 389},
478
+ {"name": "bandstand", "id": 116, "trainId": 390},
479
+ {"name": "lectern", "id": 1425, "trainId": 391},
480
+ {"name": "switch, electric switch, electrical switch", "id": 2676, "trainId": 392},
481
+ {"name": "baseboard, mopboard, skirting board", "id": 141, "trainId": 393},
482
+ {"name": "shower room", "id": 2360, "trainId": 394},
483
+ {"name": "smoke", "id": 2449, "trainId": 395},
484
+ {"name": "faucet, spigot", "id": 897, "trainId": 396},
485
+ {"name": "bulldozer", "id": 317, "trainId": 397},
486
+ {"name": "saucepan", "id": 2228, "trainId": 398},
487
+ {"name": "shops", "id": 2351, "trainId": 399},
488
+ {"name": "meter", "id": 1543, "trainId": 400},
489
+ {"name": "crevasse", "id": 656, "trainId": 401},
490
+ {"name": "gear", "id": 1088, "trainId": 402},
491
+ {"name": "candelabrum, candelabra", "id": 373, "trainId": 403},
492
+ {"name": "sofa bed", "id": 2472, "trainId": 404},
493
+ {"name": "tunnel", "id": 2892, "trainId": 405},
494
+ {"name": "pallet", "id": 1740, "trainId": 406},
495
+ {"name": "wire, conducting wire", "id": 3067, "trainId": 407},
496
+ {"name": "kettle, boiler", "id": 1367, "trainId": 408},
497
+ {"name": "bidet", "id": 188, "trainId": 409},
498
+ {
499
+ "name": "baby buggy, baby carriage, carriage, perambulator, pram, stroller, go-cart, pushchair, pusher",
500
+ "id": 79,
501
+ "trainId": 410,
502
+ },
503
+ {"name": "music stand", "id": 1633, "trainId": 411},
504
+ {"name": "pipe, tube", "id": 1885, "trainId": 412},
505
+ {"name": "cup", "id": 677, "trainId": 413},
506
+ {"name": "parking meter", "id": 1779, "trainId": 414},
507
+ {"name": "ice hockey rink", "id": 1297, "trainId": 415},
508
+ {"name": "shelter", "id": 2334, "trainId": 416},
509
+ {"name": "weeds", "id": 3027, "trainId": 417},
510
+ {"name": "temple", "id": 2735, "trainId": 418},
511
+ {"name": "patty, cake", "id": 1791, "trainId": 419},
512
+ {"name": "ski slope", "id": 2405, "trainId": 420},
513
+ {"name": "panel", "id": 1748, "trainId": 421},
514
+ {"name": "wallet", "id": 2983, "trainId": 422},
515
+ {"name": "wheel", "id": 3035, "trainId": 423},
516
+ {"name": "towel rack, towel horse", "id": 2824, "trainId": 424},
517
+ {"name": "roundabout", "id": 2168, "trainId": 425},
518
+ {"name": "canister, cannister, tin", "id": 385, "trainId": 426},
519
+ {"name": "rod", "id": 2148, "trainId": 427},
520
+ {"name": "soap dispenser", "id": 2465, "trainId": 428},
521
+ {"name": "bell", "id": 175, "trainId": 429},
522
+ {"name": "canvas", "id": 390, "trainId": 430},
523
+ {"name": "box office, ticket office, ticket booth", "id": 268, "trainId": 431},
524
+ {"name": "teacup", "id": 2722, "trainId": 432},
525
+ {"name": "trellis", "id": 2857, "trainId": 433},
526
+ {"name": "workbench", "id": 3088, "trainId": 434},
527
+ {"name": "valley, vale", "id": 2926, "trainId": 435},
528
+ {"name": "toaster", "id": 2782, "trainId": 436},
529
+ {"name": "knife", "id": 1378, "trainId": 437},
530
+ {"name": "podium", "id": 1934, "trainId": 438},
531
+ {"name": "ramp", "id": 2072, "trainId": 439},
532
+ {"name": "tumble dryer", "id": 2889, "trainId": 440},
533
+ {"name": "fireplug, fire hydrant, plug", "id": 944, "trainId": 441},
534
+ {"name": "gym shoe, sneaker, tennis shoe", "id": 1158, "trainId": 442},
535
+ {"name": "lab bench", "id": 1383, "trainId": 443},
536
+ {"name": "equipment", "id": 867, "trainId": 444},
537
+ {"name": "rocky formation", "id": 2145, "trainId": 445},
538
+ {"name": "plastic", "id": 1915, "trainId": 446},
539
+ {"name": "calendar", "id": 361, "trainId": 447},
540
+ {"name": "caravan", "id": 402, "trainId": 448},
541
+ {"name": "check-in-desk", "id": 482, "trainId": 449},
542
+ {"name": "ticket counter", "id": 2761, "trainId": 450},
543
+ {"name": "brush", "id": 300, "trainId": 451},
544
+ {"name": "mill", "id": 1554, "trainId": 452},
545
+ {"name": "covered bridge", "id": 636, "trainId": 453},
546
+ {"name": "bowling alley", "id": 260, "trainId": 454},
547
+ {"name": "hanger", "id": 1186, "trainId": 455},
548
+ {"name": "excavator", "id": 871, "trainId": 456},
549
+ {"name": "trestle", "id": 2859, "trainId": 457},
550
+ {"name": "revolving door", "id": 2103, "trainId": 458},
551
+ {"name": "blast furnace", "id": 208, "trainId": 459},
552
+ {"name": "scale, weighing machine", "id": 2236, "trainId": 460},
553
+ {"name": "projector", "id": 2012, "trainId": 461},
554
+ {"name": "soap", "id": 2462, "trainId": 462},
555
+ {"name": "locker", "id": 1462, "trainId": 463},
556
+ {"name": "tractor", "id": 2832, "trainId": 464},
557
+ {"name": "stretcher", "id": 2617, "trainId": 465},
558
+ {"name": "frame", "id": 1024, "trainId": 466},
559
+ {"name": "grating", "id": 1129, "trainId": 467},
560
+ {"name": "alembic", "id": 18, "trainId": 468},
561
+ {"name": "candle, taper, wax light", "id": 376, "trainId": 469},
562
+ {"name": "barrier", "id": 134, "trainId": 470},
563
+ {"name": "cardboard", "id": 407, "trainId": 471},
564
+ {"name": "cave", "id": 434, "trainId": 472},
565
+ {"name": "puddle", "id": 2017, "trainId": 473},
566
+ {"name": "tarp", "id": 2717, "trainId": 474},
567
+ {"name": "price tag", "id": 2005, "trainId": 475},
568
+ {"name": "watchtower", "id": 2993, "trainId": 476},
569
+ {"name": "meters", "id": 1545, "trainId": 477},
570
+ {
571
+ "name": "light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb",
572
+ "id": 1445,
573
+ "trainId": 478,
574
+ },
575
+ {"name": "tracks", "id": 2831, "trainId": 479},
576
+ {"name": "hair dryer", "id": 1161, "trainId": 480},
577
+ {"name": "skirt", "id": 2411, "trainId": 481},
578
+ {"name": "viaduct", "id": 2949, "trainId": 482},
579
+ {"name": "paper towel", "id": 1769, "trainId": 483},
580
+ {"name": "coat", "id": 552, "trainId": 484},
581
+ {"name": "sheet", "id": 2327, "trainId": 485},
582
+ {"name": "fire extinguisher, extinguisher, asphyxiator", "id": 939, "trainId": 486},
583
+ {"name": "water wheel", "id": 3013, "trainId": 487},
584
+ {"name": "pottery, clayware", "id": 1986, "trainId": 488},
585
+ {"name": "magazine rack", "id": 1486, "trainId": 489},
586
+ {"name": "teapot", "id": 2723, "trainId": 490},
587
+ {"name": "microphone, mike", "id": 1549, "trainId": 491},
588
+ {"name": "support", "id": 2649, "trainId": 492},
589
+ {"name": "forklift", "id": 1020, "trainId": 493},
590
+ {"name": "canyon", "id": 392, "trainId": 494},
591
+ {"name": "cash register, register", "id": 422, "trainId": 495},
592
+ {"name": "leaf, leafage, foliage", "id": 1419, "trainId": 496},
593
+ {"name": "remote control, remote", "id": 2099, "trainId": 497},
594
+ {"name": "soap dish", "id": 2464, "trainId": 498},
595
+ {"name": "windshield, windscreen", "id": 3058, "trainId": 499},
596
+ {"name": "cat", "id": 430, "trainId": 500},
597
+ {"name": "cue, cue stick, pool cue, pool stick", "id": 675, "trainId": 501},
598
+ {"name": "vent, venthole, vent-hole, blowhole", "id": 2941, "trainId": 502},
599
+ {"name": "videos", "id": 2955, "trainId": 503},
600
+ {"name": "shovel", "id": 2355, "trainId": 504},
601
+ {"name": "eaves", "id": 840, "trainId": 505},
602
+ {"name": "antenna, aerial, transmitting aerial", "id": 32, "trainId": 506},
603
+ {"name": "shipyard", "id": 2338, "trainId": 507},
604
+ {"name": "hen, biddy", "id": 1232, "trainId": 508},
605
+ {"name": "traffic cone", "id": 2834, "trainId": 509},
606
+ {"name": "washing machines", "id": 2991, "trainId": 510},
607
+ {"name": "truck crane", "id": 2879, "trainId": 511},
608
+ {"name": "cds", "id": 444, "trainId": 512},
609
+ {"name": "niche", "id": 1657, "trainId": 513},
610
+ {"name": "scoreboard", "id": 2246, "trainId": 514},
611
+ {"name": "briefcase", "id": 296, "trainId": 515},
612
+ {"name": "boot", "id": 245, "trainId": 516},
613
+ {"name": "sweater, jumper", "id": 2661, "trainId": 517},
614
+ {"name": "hay", "id": 1202, "trainId": 518},
615
+ {"name": "pack", "id": 1714, "trainId": 519},
616
+ {"name": "bottle rack", "id": 251, "trainId": 520},
617
+ {"name": "glacier", "id": 1095, "trainId": 521},
618
+ {"name": "pergola", "id": 1828, "trainId": 522},
619
+ {"name": "building materials", "id": 311, "trainId": 523},
620
+ {"name": "television camera", "id": 2732, "trainId": 524},
621
+ {"name": "first floor", "id": 947, "trainId": 525},
622
+ {"name": "rifle", "id": 2115, "trainId": 526},
623
+ {"name": "tennis table", "id": 2738, "trainId": 527},
624
+ {"name": "stadium", "id": 2525, "trainId": 528},
625
+ {"name": "safety belt", "id": 2194, "trainId": 529},
626
+ {"name": "cover", "id": 634, "trainId": 530},
627
+ {"name": "dish rack", "id": 740, "trainId": 531},
628
+ {"name": "synthesizer", "id": 2682, "trainId": 532},
629
+ {"name": "pumpkin", "id": 2020, "trainId": 533},
630
+ {"name": "gutter", "id": 1156, "trainId": 534},
631
+ {"name": "fruit stand", "id": 1036, "trainId": 535},
632
+ {"name": "ice floe, floe", "id": 1295, "trainId": 536},
633
+ {"name": "handle, grip, handgrip, hold", "id": 1181, "trainId": 537},
634
+ {"name": "wheelchair", "id": 3037, "trainId": 538},
635
+ {"name": "mousepad, mouse mat", "id": 1614, "trainId": 539},
636
+ {"name": "diploma", "id": 736, "trainId": 540},
637
+ {"name": "fairground ride", "id": 893, "trainId": 541},
638
+ {"name": "radio", "id": 2047, "trainId": 542},
639
+ {"name": "hotplate", "id": 1274, "trainId": 543},
640
+ {"name": "junk", "id": 1361, "trainId": 544},
641
+ {"name": "wheelbarrow", "id": 3036, "trainId": 545},
642
+ {"name": "stream", "id": 2606, "trainId": 546},
643
+ {"name": "toll plaza", "id": 2797, "trainId": 547},
644
+ {"name": "punching bag", "id": 2022, "trainId": 548},
645
+ {"name": "trough", "id": 2876, "trainId": 549},
646
+ {"name": "throne", "id": 2758, "trainId": 550},
647
+ {"name": "chair desk", "id": 472, "trainId": 551},
648
+ {"name": "weighbridge", "id": 3028, "trainId": 552},
649
+ {"name": "extractor fan", "id": 882, "trainId": 553},
650
+ {"name": "hanging clothes", "id": 1189, "trainId": 554},
651
+ {"name": "dish, dish aerial, dish antenna, saucer", "id": 743, "trainId": 555},
652
+ {"name": "alarm clock, alarm", "id": 3122, "trainId": 556},
653
+ {"name": "ski lift", "id": 2401, "trainId": 557},
654
+ {"name": "chain", "id": 468, "trainId": 558},
655
+ {"name": "garage", "id": 1061, "trainId": 559},
656
+ {"name": "mechanical shovel", "id": 1523, "trainId": 560},
657
+ {"name": "wine rack", "id": 3059, "trainId": 561},
658
+ {"name": "tramway", "id": 2843, "trainId": 562},
659
+ {"name": "treadmill", "id": 2853, "trainId": 563},
660
+ {"name": "menu", "id": 1529, "trainId": 564},
661
+ {"name": "block", "id": 214, "trainId": 565},
662
+ {"name": "well", "id": 3032, "trainId": 566},
663
+ {"name": "witness stand", "id": 3071, "trainId": 567},
664
+ {"name": "branch", "id": 277, "trainId": 568},
665
+ {"name": "duck", "id": 813, "trainId": 569},
666
+ {"name": "casserole", "id": 426, "trainId": 570},
667
+ {"name": "frying pan", "id": 1039, "trainId": 571},
668
+ {"name": "desk organizer", "id": 727, "trainId": 572},
669
+ {"name": "mast", "id": 1508, "trainId": 573},
670
+ {"name": "spectacles, specs, eyeglasses, glasses", "id": 2490, "trainId": 574},
671
+ {"name": "service elevator", "id": 2299, "trainId": 575},
672
+ {"name": "dollhouse", "id": 768, "trainId": 576},
673
+ {"name": "hammock", "id": 1172, "trainId": 577},
674
+ {"name": "clothes hanging", "id": 537, "trainId": 578},
675
+ {"name": "photocopier", "id": 1847, "trainId": 579},
676
+ {"name": "notepad", "id": 1664, "trainId": 580},
677
+ {"name": "golf cart", "id": 1110, "trainId": 581},
678
+ {"name": "footpath", "id": 1014, "trainId": 582},
679
+ {"name": "cross", "id": 662, "trainId": 583},
680
+ {"name": "baptismal font", "id": 121, "trainId": 584},
681
+ {"name": "boiler", "id": 227, "trainId": 585},
682
+ {"name": "skip", "id": 2410, "trainId": 586},
683
+ {"name": "rotisserie", "id": 2165, "trainId": 587},
684
+ {"name": "tables", "id": 2696, "trainId": 588},
685
+ {"name": "water mill", "id": 3005, "trainId": 589},
686
+ {"name": "helmet", "id": 1231, "trainId": 590},
687
+ {"name": "cover curtain", "id": 635, "trainId": 591},
688
+ {"name": "brick", "id": 292, "trainId": 592},
689
+ {"name": "table runner", "id": 2690, "trainId": 593},
690
+ {"name": "ashtray", "id": 65, "trainId": 594},
691
+ {"name": "street box", "id": 2607, "trainId": 595},
692
+ {"name": "stick", "id": 2574, "trainId": 596},
693
+ {"name": "hangers", "id": 1188, "trainId": 597},
694
+ {"name": "cells", "id": 456, "trainId": 598},
695
+ {"name": "urinal", "id": 2913, "trainId": 599},
696
+ {"name": "centerpiece", "id": 459, "trainId": 600},
697
+ {"name": "portable fridge", "id": 1955, "trainId": 601},
698
+ {"name": "dvds", "id": 827, "trainId": 602},
699
+ {"name": "golf club", "id": 1111, "trainId": 603},
700
+ {"name": "skirting board", "id": 2412, "trainId": 604},
701
+ {"name": "water cooler", "id": 2997, "trainId": 605},
702
+ {"name": "clipboard", "id": 528, "trainId": 606},
703
+ {"name": "camera, photographic camera", "id": 366, "trainId": 607},
704
+ {"name": "pigeonhole", "id": 1863, "trainId": 608},
705
+ {"name": "chips", "id": 500, "trainId": 609},
706
+ {"name": "food processor", "id": 1001, "trainId": 610},
707
+ {"name": "post box", "id": 1958, "trainId": 611},
708
+ {"name": "lid", "id": 1441, "trainId": 612},
709
+ {"name": "drum", "id": 809, "trainId": 613},
710
+ {"name": "blender", "id": 210, "trainId": 614},
711
+ {"name": "cave entrance", "id": 435, "trainId": 615},
712
+ {"name": "dental chair", "id": 718, "trainId": 616},
713
+ {"name": "obelisk", "id": 1674, "trainId": 617},
714
+ {"name": "canoe", "id": 388, "trainId": 618},
715
+ {"name": "mobile", "id": 1572, "trainId": 619},
716
+ {"name": "monitors", "id": 1584, "trainId": 620},
717
+ {"name": "pool ball", "id": 1944, "trainId": 621},
718
+ {"name": "cue rack", "id": 674, "trainId": 622},
719
+ {"name": "baggage carts", "id": 99, "trainId": 623},
720
+ {"name": "shore", "id": 2352, "trainId": 624},
721
+ {"name": "fork", "id": 1019, "trainId": 625},
722
+ {"name": "paper filer", "id": 1763, "trainId": 626},
723
+ {"name": "bicycle rack", "id": 185, "trainId": 627},
724
+ {"name": "coat rack", "id": 554, "trainId": 628},
725
+ {"name": "garland", "id": 1066, "trainId": 629},
726
+ {"name": "sports bag", "id": 2508, "trainId": 630},
727
+ {"name": "fish tank", "id": 951, "trainId": 631},
728
+ {"name": "towel dispenser", "id": 2822, "trainId": 632},
729
+ {"name": "carriage", "id": 415, "trainId": 633},
730
+ {"name": "brochure", "id": 297, "trainId": 634},
731
+ {"name": "plaque", "id": 1914, "trainId": 635},
732
+ {"name": "stringer", "id": 2619, "trainId": 636},
733
+ {"name": "iron", "id": 1338, "trainId": 637},
734
+ {"name": "spoon", "id": 2505, "trainId": 638},
735
+ {"name": "flag pole", "id": 955, "trainId": 639},
736
+ {"name": "toilet brush", "id": 2786, "trainId": 640},
737
+ {"name": "book stand", "id": 238, "trainId": 641},
738
+ {"name": "water faucet, water tap, tap, hydrant", "id": 3000, "trainId": 642},
739
+ {"name": "ticket office", "id": 2763, "trainId": 643},
740
+ {"name": "broom", "id": 299, "trainId": 644},
741
+ {"name": "dvd", "id": 822, "trainId": 645},
742
+ {"name": "ice bucket", "id": 1288, "trainId": 646},
743
+ {"name": "carapace, shell, cuticle, shield", "id": 3101, "trainId": 647},
744
+ {"name": "tureen", "id": 2894, "trainId": 648},
745
+ {"name": "folders", "id": 992, "trainId": 649},
746
+ {"name": "chess", "id": 489, "trainId": 650},
747
+ {"name": "root", "id": 2157, "trainId": 651},
748
+ {"name": "sewing machine", "id": 2309, "trainId": 652},
749
+ {"name": "model", "id": 1576, "trainId": 653},
750
+ {"name": "pen", "id": 1810, "trainId": 654},
751
+ {"name": "violin", "id": 2964, "trainId": 655},
752
+ {"name": "sweatshirt", "id": 2662, "trainId": 656},
753
+ {"name": "recycling materials", "id": 2087, "trainId": 657},
754
+ {"name": "mitten", "id": 1569, "trainId": 658},
755
+ {"name": "chopping board, cutting board", "id": 503, "trainId": 659},
756
+ {"name": "mask", "id": 1505, "trainId": 660},
757
+ {"name": "log", "id": 1468, "trainId": 661},
758
+ {"name": "mouse, computer mouse", "id": 1613, "trainId": 662},
759
+ {"name": "grill", "id": 1138, "trainId": 663},
760
+ {"name": "hole", "id": 1256, "trainId": 664},
761
+ {"name": "target", "id": 2715, "trainId": 665},
762
+ {"name": "trash bag", "id": 2846, "trainId": 666},
763
+ {"name": "chalk", "id": 477, "trainId": 667},
764
+ {"name": "sticks", "id": 2576, "trainId": 668},
765
+ {"name": "balloon", "id": 108, "trainId": 669},
766
+ {"name": "score", "id": 2245, "trainId": 670},
767
+ {"name": "hair spray", "id": 1162, "trainId": 671},
768
+ {"name": "roll", "id": 2149, "trainId": 672},
769
+ {"name": "runner", "id": 2183, "trainId": 673},
770
+ {"name": "engine", "id": 858, "trainId": 674},
771
+ {"name": "inflatable glove", "id": 1324, "trainId": 675},
772
+ {"name": "games", "id": 1055, "trainId": 676},
773
+ {"name": "pallets", "id": 1741, "trainId": 677},
774
+ {"name": "baskets", "id": 149, "trainId": 678},
775
+ {"name": "coop", "id": 615, "trainId": 679},
776
+ {"name": "dvd player", "id": 825, "trainId": 680},
777
+ {"name": "rocking horse", "id": 2143, "trainId": 681},
778
+ {"name": "buckets", "id": 304, "trainId": 682},
779
+ {"name": "bread rolls", "id": 283, "trainId": 683},
780
+ {"name": "shawl", "id": 2322, "trainId": 684},
781
+ {"name": "watering can", "id": 3017, "trainId": 685},
782
+ {"name": "spotlights", "id": 2510, "trainId": 686},
783
+ {"name": "post-it", "id": 1960, "trainId": 687},
784
+ {"name": "bowls", "id": 265, "trainId": 688},
785
+ {"name": "security camera", "id": 2282, "trainId": 689},
786
+ {"name": "runner cloth", "id": 2184, "trainId": 690},
787
+ {"name": "lock", "id": 1461, "trainId": 691},
788
+ {"name": "alarm, warning device, alarm system", "id": 3113, "trainId": 692},
789
+ {"name": "side", "id": 2372, "trainId": 693},
790
+ {"name": "roulette", "id": 2166, "trainId": 694},
791
+ {"name": "bone", "id": 232, "trainId": 695},
792
+ {"name": "cutlery", "id": 693, "trainId": 696},
793
+ {"name": "pool balls", "id": 1945, "trainId": 697},
794
+ {"name": "wheels", "id": 3039, "trainId": 698},
795
+ {"name": "spice rack", "id": 2494, "trainId": 699},
796
+ {"name": "plant pots", "id": 1908, "trainId": 700},
797
+ {"name": "towel ring", "id": 2827, "trainId": 701},
798
+ {"name": "bread box", "id": 280, "trainId": 702},
799
+ {"name": "video", "id": 2950, "trainId": 703},
800
+ {"name": "funfair", "id": 1044, "trainId": 704},
801
+ {"name": "breads", "id": 288, "trainId": 705},
802
+ {"name": "tripod", "id": 2863, "trainId": 706},
803
+ {"name": "ironing board", "id": 1342, "trainId": 707},
804
+ {"name": "skimmer", "id": 2409, "trainId": 708},
805
+ {"name": "hollow", "id": 1258, "trainId": 709},
806
+ {"name": "scratching post", "id": 2249, "trainId": 710},
807
+ {"name": "tricycle", "id": 2862, "trainId": 711},
808
+ {"name": "file box", "id": 920, "trainId": 712},
809
+ {"name": "mountain pass", "id": 1607, "trainId": 713},
810
+ {"name": "tombstones", "id": 2802, "trainId": 714},
811
+ {"name": "cooker", "id": 610, "trainId": 715},
812
+ {"name": "card game, cards", "id": 3129, "trainId": 716},
813
+ {"name": "golf bag", "id": 1108, "trainId": 717},
814
+ {"name": "towel paper", "id": 2823, "trainId": 718},
815
+ {"name": "chaise lounge", "id": 476, "trainId": 719},
816
+ {"name": "sun", "id": 2641, "trainId": 720},
817
+ {"name": "toilet paper holder", "id": 2788, "trainId": 721},
818
+ {"name": "rake", "id": 2070, "trainId": 722},
819
+ {"name": "key", "id": 1368, "trainId": 723},
820
+ {"name": "umbrella stand", "id": 2903, "trainId": 724},
821
+ {"name": "dartboard", "id": 699, "trainId": 725},
822
+ {"name": "transformer", "id": 2844, "trainId": 726},
823
+ {"name": "fireplace utensils", "id": 942, "trainId": 727},
824
+ {"name": "sweatshirts", "id": 2663, "trainId": 728},
825
+ {
826
+ "name": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
827
+ "id": 457,
828
+ "trainId": 729,
829
+ },
830
+ {"name": "tallboy", "id": 2701, "trainId": 730},
831
+ {"name": "stapler", "id": 2540, "trainId": 731},
832
+ {"name": "sauna", "id": 2231, "trainId": 732},
833
+ {"name": "test tube", "id": 2746, "trainId": 733},
834
+ {"name": "palette", "id": 1738, "trainId": 734},
835
+ {"name": "shopping carts", "id": 2350, "trainId": 735},
836
+ {"name": "tools", "id": 2808, "trainId": 736},
837
+ {"name": "push button, push, button", "id": 2025, "trainId": 737},
838
+ {"name": "star", "id": 2541, "trainId": 738},
839
+ {"name": "roof rack", "id": 2156, "trainId": 739},
840
+ {"name": "barbed wire", "id": 126, "trainId": 740},
841
+ {"name": "spray", "id": 2512, "trainId": 741},
842
+ {"name": "ear", "id": 831, "trainId": 742},
843
+ {"name": "sponge", "id": 2503, "trainId": 743},
844
+ {"name": "racket", "id": 2039, "trainId": 744},
845
+ {"name": "tins", "id": 2774, "trainId": 745},
846
+ {"name": "eyeglasses", "id": 886, "trainId": 746},
847
+ {"name": "file", "id": 919, "trainId": 747},
848
+ {"name": "scarfs", "id": 2240, "trainId": 748},
849
+ {"name": "sugar bowl", "id": 2636, "trainId": 749},
850
+ {"name": "flip flop", "id": 963, "trainId": 750},
851
+ {"name": "headstones", "id": 1218, "trainId": 751},
852
+ {"name": "laptop bag", "id": 1406, "trainId": 752},
853
+ {"name": "leash", "id": 1420, "trainId": 753},
854
+ {"name": "climbing frame", "id": 526, "trainId": 754},
855
+ {"name": "suit hanger", "id": 2639, "trainId": 755},
856
+ {"name": "floor spotlight", "id": 975, "trainId": 756},
857
+ {"name": "plate rack", "id": 1921, "trainId": 757},
858
+ {"name": "sewer", "id": 2305, "trainId": 758},
859
+ {"name": "hard drive", "id": 1193, "trainId": 759},
860
+ {"name": "sprinkler", "id": 2517, "trainId": 760},
861
+ {"name": "tools box", "id": 2809, "trainId": 761},
862
+ {"name": "necklace", "id": 1647, "trainId": 762},
863
+ {"name": "bulbs", "id": 314, "trainId": 763},
864
+ {"name": "steel industry", "id": 2560, "trainId": 764},
865
+ {"name": "club", "id": 545, "trainId": 765},
866
+ {"name": "jack", "id": 1345, "trainId": 766},
867
+ {"name": "door bars", "id": 775, "trainId": 767},
868
+ {
869
+ "name": "control panel, instrument panel, control board, board, panel",
870
+ "id": 603,
871
+ "trainId": 768,
872
+ },
873
+ {"name": "hairbrush", "id": 1163, "trainId": 769},
874
+ {"name": "napkin holder", "id": 1641, "trainId": 770},
875
+ {"name": "office", "id": 1678, "trainId": 771},
876
+ {"name": "smoke detector", "id": 2450, "trainId": 772},
877
+ {"name": "utensils", "id": 2915, "trainId": 773},
878
+ {"name": "apron", "id": 42, "trainId": 774},
879
+ {"name": "scissors", "id": 2242, "trainId": 775},
880
+ {"name": "terminal", "id": 2741, "trainId": 776},
881
+ {"name": "grinder", "id": 1143, "trainId": 777},
882
+ {"name": "entry phone", "id": 862, "trainId": 778},
883
+ {"name": "newspaper stand", "id": 1654, "trainId": 779},
884
+ {"name": "pepper shaker", "id": 1826, "trainId": 780},
885
+ {"name": "onions", "id": 1689, "trainId": 781},
886
+ {
887
+ "name": "central processing unit, cpu, c p u , central processor, processor, mainframe",
888
+ "id": 3124,
889
+ "trainId": 782,
890
+ },
891
+ {"name": "tape", "id": 2710, "trainId": 783},
892
+ {"name": "bat", "id": 152, "trainId": 784},
893
+ {"name": "coaster", "id": 549, "trainId": 785},
894
+ {"name": "calculator", "id": 360, "trainId": 786},
895
+ {"name": "potatoes", "id": 1982, "trainId": 787},
896
+ {"name": "luggage rack", "id": 1478, "trainId": 788},
897
+ {"name": "salt", "id": 2203, "trainId": 789},
898
+ {"name": "street number", "id": 2612, "trainId": 790},
899
+ {"name": "viewpoint", "id": 2956, "trainId": 791},
900
+ {"name": "sword", "id": 2681, "trainId": 792},
901
+ {"name": "cd", "id": 437, "trainId": 793},
902
+ {"name": "rowing machine", "id": 2171, "trainId": 794},
903
+ {"name": "plug", "id": 1933, "trainId": 795},
904
+ {"name": "andiron, firedog, dog, dog-iron", "id": 3110, "trainId": 796},
905
+ {"name": "pepper", "id": 1824, "trainId": 797},
906
+ {"name": "tongs", "id": 2803, "trainId": 798},
907
+ {"name": "bonfire", "id": 234, "trainId": 799},
908
+ {"name": "dog dish", "id": 764, "trainId": 800},
909
+ {"name": "belt", "id": 177, "trainId": 801},
910
+ {"name": "dumbbells", "id": 817, "trainId": 802},
911
+ {"name": "videocassette recorder, vcr", "id": 3145, "trainId": 803},
912
+ {"name": "hook", "id": 1262, "trainId": 804},
913
+ {"name": "envelopes", "id": 864, "trainId": 805},
914
+ {"name": "shower faucet", "id": 2359, "trainId": 806},
915
+ {"name": "watch", "id": 2992, "trainId": 807},
916
+ {"name": "padlock", "id": 1725, "trainId": 808},
917
+ {"name": "swimming pool ladder", "id": 2667, "trainId": 809},
918
+ {"name": "spanners", "id": 2484, "trainId": 810},
919
+ {"name": "gravy boat", "id": 1133, "trainId": 811},
920
+ {"name": "notice board", "id": 1667, "trainId": 812},
921
+ {"name": "trash bags", "id": 2847, "trainId": 813},
922
+ {"name": "fire alarm", "id": 932, "trainId": 814},
923
+ {"name": "ladle", "id": 1392, "trainId": 815},
924
+ {"name": "stethoscope", "id": 2573, "trainId": 816},
925
+ {"name": "rocket", "id": 2140, "trainId": 817},
926
+ {"name": "funnel", "id": 1046, "trainId": 818},
927
+ {"name": "bowling pins", "id": 264, "trainId": 819},
928
+ {"name": "valve", "id": 2927, "trainId": 820},
929
+ {"name": "thermometer", "id": 2752, "trainId": 821},
930
+ {"name": "cups", "id": 679, "trainId": 822},
931
+ {"name": "spice jar", "id": 2493, "trainId": 823},
932
+ {"name": "night light", "id": 1658, "trainId": 824},
933
+ {"name": "soaps", "id": 2466, "trainId": 825},
934
+ {"name": "games table", "id": 1057, "trainId": 826},
935
+ {"name": "slotted spoon", "id": 2444, "trainId": 827},
936
+ {"name": "reel", "id": 2093, "trainId": 828},
937
+ {"name": "scourer", "id": 2248, "trainId": 829},
938
+ {"name": "sleeping robe", "id": 2432, "trainId": 830},
939
+ {"name": "desk mat", "id": 726, "trainId": 831},
940
+ {"name": "dumbbell", "id": 816, "trainId": 832},
941
+ {"name": "hammer", "id": 1171, "trainId": 833},
942
+ {"name": "tie", "id": 2766, "trainId": 834},
943
+ {"name": "typewriter", "id": 2900, "trainId": 835},
944
+ {"name": "shaker", "id": 2313, "trainId": 836},
945
+ {"name": "cheese dish", "id": 488, "trainId": 837},
946
+ {"name": "sea star", "id": 2265, "trainId": 838},
947
+ {"name": "racquet", "id": 2043, "trainId": 839},
948
+ {"name": "butane gas cylinder", "id": 332, "trainId": 840},
949
+ {"name": "paper weight", "id": 1771, "trainId": 841},
950
+ {"name": "shaving brush", "id": 2320, "trainId": 842},
951
+ {"name": "sunglasses", "id": 2646, "trainId": 843},
952
+ {"name": "gear shift", "id": 1089, "trainId": 844},
953
+ {"name": "towel rail", "id": 2826, "trainId": 845},
954
+ {"name": "adding machine, totalizer, totaliser", "id": 3148, "trainId": 846},
955
+ ]
956
+
957
+
958
+ def _get_ade20k_full_meta():
959
+ stuff_ids = [k["id"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
960
+ assert len(stuff_ids) == 847, len(stuff_ids)
961
+
962
+ stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
963
+ stuff_classes = [k["name"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
964
+
965
+ ret = {
966
+ "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
967
+ "stuff_classes": stuff_classes,
968
+ }
969
+ return ret
970
+
971
+
972
+ def register_all_ade20k_full(root):
973
+ meta = _get_ade20k_full_meta()
974
+ for name, dirname in [("val", "validation")]:
975
+ image_dir = os.path.join(root, "ADE20K_2021_17_01/images_detectron2", dirname)
976
+ gt_dir = os.path.join(root, "ADE20K_2021_17_01/annotations_detectron2", dirname)
977
+ name = f"ade20k_full_sem_seg_{name}"
978
+ DatasetCatalog.register(
979
+ name,
980
+ lambda x=image_dir, y=gt_dir: load_sem_seg(
981
+ y, x, gt_ext="tif", image_ext="jpg"
982
+ ),
983
+ )
984
+ MetadataCatalog.get(name).set(
985
+ stuff_classes=meta["stuff_classes"][:],
986
+ thing_classes=meta["stuff_classes"][:], # the same as stuff_classes
987
+ image_root=image_dir,
988
+ sem_seg_root=gt_dir,
989
+ evaluator_type="sem_seg",
990
+ ignore_label=65535, # NOTE: gt is saved in 16-bit TIFF images
991
+ )
992
+
993
+
994
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
995
+ register_all_ade20k_full(_root)
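For orientation (not part of the diff): once this module is imported, the registered val split is reachable through Detectron2's catalogs. A minimal usage sketch, assuming DETECTRON2_DATASETS points at a directory that actually contains ADE20K_2021_17_01/images_detectron2 and annotations_detectron2, and using the module path this file lives at:

# Minimal sketch (assumes the dataset files exist on disk).
from detectron2.data import DatasetCatalog, MetadataCatalog

import open_vocab_seg.data.datasets.register_ade20k_full  # noqa: F401 -- runs the registration above

meta = MetadataCatalog.get("ade20k_full_sem_seg_val")
print(len(meta.stuff_classes), meta.ignore_label)      # 847 class names, ignore_label 65535 (16-bit TIFF GT)

dicts = DatasetCatalog.get("ade20k_full_sem_seg_val")  # invokes load_sem_seg(gt_dir, image_dir, ...)
print(dicts[0]["file_name"], dicts[0]["sem_seg_file_name"])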
open_vocab_seg/data/datasets/register_cc3m.py ADDED
@@ -0,0 +1,457 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import os
3
+
4
+ import pandas as pd
5
+ from detectron2.data import DatasetCatalog, MetadataCatalog
6
+ from detectron2.data.datasets import load_sem_seg
7
+ from detectron2.utils.file_io import PathManager
8
+
9
+
10
+ COCO_CATEGORIES = [
11
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
12
+ {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
13
+ {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
14
+ {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
15
+ {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
16
+ {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
17
+ {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
18
+ {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
19
+ {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
20
+ {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
21
+ {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
22
+ {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
23
+ {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
24
+ {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
25
+ {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
26
+ {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
27
+ {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
28
+ {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
29
+ {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
30
+ {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
31
+ {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
32
+ {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
33
+ {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
34
+ {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
35
+ {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
36
+ {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
37
+ {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
38
+ {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
39
+ {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
40
+ {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
41
+ {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
42
+ {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
43
+ {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
44
+ {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
45
+ {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
46
+ {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
47
+ {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
48
+ {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
49
+ {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
50
+ {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
51
+ {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
52
+ {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
53
+ {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
54
+ {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
55
+ {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
56
+ {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
57
+ {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
58
+ {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
59
+ {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
60
+ {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
61
+ {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
62
+ {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
63
+ {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
64
+ {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
65
+ {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
66
+ {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
67
+ {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
68
+ {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
69
+ {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
70
+ {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
71
+ {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
72
+ {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
73
+ {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
74
+ {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
75
+ {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
76
+ {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
77
+ {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
78
+ {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
79
+ {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
80
+ {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
81
+ {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
82
+ {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
83
+ {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
84
+ {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
85
+ {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
86
+ {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
87
+ {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
88
+ {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
89
+ {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
90
+ {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
91
+ {"id": 92, "name": "banner", "supercategory": "textile"},
92
+ {"id": 93, "name": "blanket", "supercategory": "textile"},
93
+ {"id": 94, "name": "branch", "supercategory": "plant"},
94
+ {"id": 95, "name": "bridge", "supercategory": "building"},
95
+ {"id": 96, "name": "building-other", "supercategory": "building"},
96
+ {"id": 97, "name": "bush", "supercategory": "plant"},
97
+ {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
98
+ {"id": 99, "name": "cage", "supercategory": "structural"},
99
+ {"id": 100, "name": "cardboard", "supercategory": "raw-material"},
100
+ {"id": 101, "name": "carpet", "supercategory": "floor"},
101
+ {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
102
+ {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
103
+ {"id": 104, "name": "cloth", "supercategory": "textile"},
104
+ {"id": 105, "name": "clothes", "supercategory": "textile"},
105
+ {"id": 106, "name": "clouds", "supercategory": "sky"},
106
+ {"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
107
+ {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
108
+ {"id": 109, "name": "curtain", "supercategory": "textile"},
109
+ {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
110
+ {"id": 111, "name": "dirt", "supercategory": "ground"},
111
+ {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
112
+ {"id": 113, "name": "fence", "supercategory": "structural"},
113
+ {"id": 114, "name": "floor-marble", "supercategory": "floor"},
114
+ {"id": 115, "name": "floor-other", "supercategory": "floor"},
115
+ {"id": 116, "name": "floor-stone", "supercategory": "floor"},
116
+ {"id": 117, "name": "floor-tile", "supercategory": "floor"},
117
+ {"id": 118, "name": "floor-wood", "supercategory": "floor"},
118
+ {"id": 119, "name": "flower", "supercategory": "plant"},
119
+ {"id": 120, "name": "fog", "supercategory": "water"},
120
+ {"id": 121, "name": "food-other", "supercategory": "food-stuff"},
121
+ {"id": 122, "name": "fruit", "supercategory": "food-stuff"},
122
+ {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
123
+ {"id": 124, "name": "grass", "supercategory": "plant"},
124
+ {"id": 125, "name": "gravel", "supercategory": "ground"},
125
+ {"id": 126, "name": "ground-other", "supercategory": "ground"},
126
+ {"id": 127, "name": "hill", "supercategory": "solid"},
127
+ {"id": 128, "name": "house", "supercategory": "building"},
128
+ {"id": 129, "name": "leaves", "supercategory": "plant"},
129
+ {"id": 130, "name": "light", "supercategory": "furniture-stuff"},
130
+ {"id": 131, "name": "mat", "supercategory": "textile"},
131
+ {"id": 132, "name": "metal", "supercategory": "raw-material"},
132
+ {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
133
+ {"id": 134, "name": "moss", "supercategory": "plant"},
134
+ {"id": 135, "name": "mountain", "supercategory": "solid"},
135
+ {"id": 136, "name": "mud", "supercategory": "ground"},
136
+ {"id": 137, "name": "napkin", "supercategory": "textile"},
137
+ {"id": 138, "name": "net", "supercategory": "structural"},
138
+ {"id": 139, "name": "paper", "supercategory": "raw-material"},
139
+ {"id": 140, "name": "pavement", "supercategory": "ground"},
140
+ {"id": 141, "name": "pillow", "supercategory": "textile"},
141
+ {"id": 142, "name": "plant-other", "supercategory": "plant"},
142
+ {"id": 143, "name": "plastic", "supercategory": "raw-material"},
143
+ {"id": 144, "name": "platform", "supercategory": "ground"},
144
+ {"id": 145, "name": "playingfield", "supercategory": "ground"},
145
+ {"id": 146, "name": "railing", "supercategory": "structural"},
146
+ {"id": 147, "name": "railroad", "supercategory": "ground"},
147
+ {"id": 148, "name": "river", "supercategory": "water"},
148
+ {"id": 149, "name": "road", "supercategory": "ground"},
149
+ {"id": 150, "name": "rock", "supercategory": "solid"},
150
+ {"id": 151, "name": "roof", "supercategory": "building"},
151
+ {"id": 152, "name": "rug", "supercategory": "textile"},
152
+ {"id": 153, "name": "salad", "supercategory": "food-stuff"},
153
+ {"id": 154, "name": "sand", "supercategory": "ground"},
154
+ {"id": 155, "name": "sea", "supercategory": "water"},
155
+ {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
156
+ {"id": 157, "name": "sky-other", "supercategory": "sky"},
157
+ {"id": 158, "name": "skyscraper", "supercategory": "building"},
158
+ {"id": 159, "name": "snow", "supercategory": "ground"},
159
+ {"id": 160, "name": "solid-other", "supercategory": "solid"},
160
+ {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
161
+ {"id": 162, "name": "stone", "supercategory": "solid"},
162
+ {"id": 163, "name": "straw", "supercategory": "plant"},
163
+ {"id": 164, "name": "structural-other", "supercategory": "structural"},
164
+ {"id": 165, "name": "table", "supercategory": "furniture-stuff"},
165
+ {"id": 166, "name": "tent", "supercategory": "building"},
166
+ {"id": 167, "name": "textile-other", "supercategory": "textile"},
167
+ {"id": 168, "name": "towel", "supercategory": "textile"},
168
+ {"id": 169, "name": "tree", "supercategory": "plant"},
169
+ {"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
170
+ {"id": 171, "name": "wall-brick", "supercategory": "wall"},
171
+ {"id": 172, "name": "wall-concrete", "supercategory": "wall"},
172
+ {"id": 173, "name": "wall-other", "supercategory": "wall"},
173
+ {"id": 174, "name": "wall-panel", "supercategory": "wall"},
174
+ {"id": 175, "name": "wall-stone", "supercategory": "wall"},
175
+ {"id": 176, "name": "wall-tile", "supercategory": "wall"},
176
+ {"id": 177, "name": "wall-wood", "supercategory": "wall"},
177
+ {"id": 178, "name": "water-other", "supercategory": "water"},
178
+ {"id": 179, "name": "waterdrops", "supercategory": "water"},
179
+ {"id": 180, "name": "window-blind", "supercategory": "window"},
180
+ {"id": 181, "name": "window-other", "supercategory": "window"},
181
+ {"id": 182, "name": "wood", "supercategory": "solid"},
182
+ ]
183
+
184
+
185
+ ADE20K_150_CATEGORIES = [
186
+ {"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"},
187
+ {"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"},
188
+ {"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"},
189
+ {"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"},
190
+ {"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"},
191
+ {"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"},
192
+ {"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"},
193
+ {"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"},
194
+ {"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "},
195
+ {"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"},
196
+ {"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"},
197
+ {"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"},
198
+ {"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"},
199
+ {"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"},
200
+ {"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"},
201
+ {"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"},
202
+ {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"},
203
+ {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"},
204
+ {"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"},
205
+ {"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"},
206
+ {"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"},
207
+ {"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"},
208
+ {"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"},
209
+ {"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"},
210
+ {"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"},
211
+ {"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"},
212
+ {"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"},
213
+ {"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"},
214
+ {"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"},
215
+ {"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"},
216
+ {"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"},
217
+ {"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"},
218
+ {"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"},
219
+ {"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"},
220
+ {"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"},
221
+ {"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"},
222
+ {"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"},
223
+ {"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"},
224
+ {"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"},
225
+ {"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"},
226
+ {"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"},
227
+ {"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"},
228
+ {"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"},
229
+ {"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"},
230
+ {
231
+ "color": [6, 51, 255],
232
+ "id": 44,
233
+ "isthing": 1,
234
+ "name": "chest of drawers, chest, bureau, dresser",
235
+ },
236
+ {"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"},
237
+ {"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"},
238
+ {"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"},
239
+ {"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"},
240
+ {"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"},
241
+ {"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"},
242
+ {"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"},
243
+ {"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"},
244
+ {"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"},
245
+ {"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"},
246
+ {"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"},
247
+ {
248
+ "color": [255, 71, 0],
249
+ "id": 56,
250
+ "isthing": 1,
251
+ "name": "pool table, billiard table, snooker table",
252
+ },
253
+ {"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"},
254
+ {"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"},
255
+ {"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"},
256
+ {"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"},
257
+ {"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"},
258
+ {"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"},
259
+ {"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"},
260
+ {"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"},
261
+ {
262
+ "color": [0, 255, 133],
263
+ "id": 65,
264
+ "isthing": 1,
265
+ "name": "toilet, can, commode, crapper, pot, potty, stool, throne",
266
+ },
267
+ {"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"},
268
+ {"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"},
269
+ {"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"},
270
+ {"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"},
271
+ {"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"},
272
+ {"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"},
273
+ {"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"},
274
+ {"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"},
275
+ {"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"},
276
+ {"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"},
277
+ {"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"},
278
+ {"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"},
279
+ {"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"},
280
+ {"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"},
281
+ {"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"},
282
+ {"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"},
283
+ {"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"},
284
+ {"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"},
285
+ {"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"},
286
+ {"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"},
287
+ {"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"},
288
+ {"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"},
289
+ {"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"},
290
+ {"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"},
291
+ {"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"},
292
+ {"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"},
293
+ {"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"},
294
+ {"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"},
295
+ {"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"},
296
+ {
297
+ "color": [0, 122, 255],
298
+ "id": 95,
299
+ "isthing": 1,
300
+ "name": "bannister, banister, balustrade, balusters, handrail",
301
+ },
302
+ {
303
+ "color": [0, 255, 163],
304
+ "id": 96,
305
+ "isthing": 0,
306
+ "name": "escalator, moving staircase, moving stairway",
307
+ },
308
+ {
309
+ "color": [255, 153, 0],
310
+ "id": 97,
311
+ "isthing": 1,
312
+ "name": "ottoman, pouf, pouffe, puff, hassock",
313
+ },
314
+ {"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"},
315
+ {"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"},
316
+ {
317
+ "color": [143, 255, 0],
318
+ "id": 100,
319
+ "isthing": 0,
320
+ "name": "poster, posting, placard, notice, bill, card",
321
+ },
322
+ {"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"},
323
+ {"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"},
324
+ {"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"},
325
+ {"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"},
326
+ {
327
+ "color": [133, 0, 255],
328
+ "id": 105,
329
+ "isthing": 0,
330
+ "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
331
+ },
332
+ {"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"},
333
+ {
334
+ "color": [184, 0, 255],
335
+ "id": 107,
336
+ "isthing": 1,
337
+ "name": "washer, automatic washer, washing machine",
338
+ },
339
+ {"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"},
340
+ {"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"},
341
+ {"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"},
342
+ {"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"},
343
+ {"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"},
344
+ {"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"},
345
+ {"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"},
346
+ {"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"},
347
+ {"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"},
348
+ {"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"},
349
+ {"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"},
350
+ {"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"},
351
+ {"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"},
352
+ {"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"},
353
+ {"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"},
354
+ {"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"},
355
+ {"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"},
356
+ {"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"},
357
+ {"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"},
358
+ {"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"},
359
+ {"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"},
360
+ {"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"},
361
+ {"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"},
362
+ {"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"},
363
+ {"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"},
364
+ {"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"},
365
+ {"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"},
366
+ {"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"},
367
+ {"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"},
368
+ {"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"},
369
+ {"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"},
370
+ {"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"},
371
+ {"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"},
372
+ {"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"},
373
+ {"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"},
374
+ {"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"},
375
+ {"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"},
376
+ {"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"},
377
+ {"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"},
378
+ {"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"},
379
+ {"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"},
380
+ {"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"},
381
+ ]
382
+
383
+ TEST_CATEGORIES = [
384
+ {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "Oculus"},
385
+ {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "Ukulele"},
386
+ ]
387
+
388
+ COCO_BASE_CATEGORIES = [
389
+ c
390
+ for i, c in enumerate(COCO_CATEGORIES)
391
+ if c["id"] - 1
392
+ not in [20, 24, 32, 33, 40, 56, 86, 99, 105, 123, 144, 147, 148, 168, 171]
393
+ ]
394
+ COCO_NOVEL_CATEGORIES = [
395
+ c
396
+ for i, c in enumerate(COCO_CATEGORIES)
397
+ if c["id"] - 1
398
+ in [20, 24, 32, 33, 40, 56, 86, 99, 105, 123, 144, 147, 148, 168, 171]
399
+ ]
400
+
401
+
402
+ def load_cc_image(csv_file, img_key='filepath', caption_key='title', sep="\t"):
403
+ print(f'Loading csv data from {csv_file}.')
404
+ df = pd.read_csv(csv_file, sep=sep)
405
+
406
+ input_files = df[img_key].tolist()
407
+ captions = df[caption_key].tolist()
408
+
409
+ print("Loaded {} images".format(len(input_files)))
410
+
411
+ dataset_dicts = []
412
+ for (img_path, text) in zip(input_files, captions):
413
+ record = {}
414
+ record["file_name"] = img_path
415
+ record["caption"] = text
416
+ dataset_dicts.append(record)
417
+
418
+ return dataset_dicts
419
+
420
+
421
+ def _get_coco_stuff_meta(cat_list):
422
+ # Id 0 is reserved for ignore_label; we change ignore_label from 0
423
+ # to 255 in our pre-processing.
424
+ stuff_ids = [k["id"] for k in cat_list]
425
+
426
+ # For semantic segmentation, this maps ids in the dataset (used when
427
+ # processing results) to contiguous ids (used in models)
428
+ stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
429
+ stuff_classes = [k["name"] for k in cat_list]
430
+
431
+ ret = {
432
+ "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
433
+ "stuff_classes": stuff_classes,
434
+ }
435
+ return ret
436
+
437
+
438
+ def register_cc_3m(csv_file):
439
+
440
+ meta = _get_coco_stuff_meta(TEST_CATEGORIES)
441
+ name = "cc_3m_train"
442
+
443
+ DatasetCatalog.register(
444
+ name,
445
+ lambda x=csv_file: load_cc_image(x),
446
+ )
447
+ MetadataCatalog.get(name).set(
448
+ csv_file=csv_file,
449
+ evaluator_type="dummy",
450
+ ignore_label=255,
451
+ **meta,
452
+ )
453
+
454
+
455
+ # _csv_file = "/home/jeffliang/zsseg/datasets/coco/coco_train_merge_captions.csv"
456
+ _csv_file = "/home/jeffliang/zsseg/configs/masked_images/pred/samples.csv"
457
+ register_cc_3m(_csv_file)
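For orientation (not part of the diff): load_cc_image reads a tab-separated file with a filepath and a title column. A minimal sketch with placeholder image paths (not files from this repo):

# Minimal sketch of the TSV format load_cc_image() reads (placeholder paths).
import pandas as pd

rows = [
    {"filepath": "/data/cc3m/000001.jpg", "title": "a dog running on the beach"},
    {"filepath": "/data/cc3m/000002.jpg", "title": "a red car parked on a street"},
]
pd.DataFrame(rows).to_csv("samples.tsv", sep="\t", index=False)

from open_vocab_seg.data.datasets.register_cc3m import load_cc_image

records = load_cc_image("samples.tsv")
print(records[0])  # {'file_name': '/data/cc3m/000001.jpg', 'caption': 'a dog running on the beach'}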
open_vocab_seg/data/datasets/register_coco_stuff.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import os
3
+
4
+ from detectron2.data import DatasetCatalog, MetadataCatalog
5
+ from detectron2.data.datasets import load_sem_seg
6
+
7
+
8
+ COCO_CATEGORIES = [
9
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
10
+ {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
11
+ {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
12
+ {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
13
+ {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
14
+ {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
15
+ {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
16
+ {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
17
+ {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
18
+ {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
19
+ {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
20
+ {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
21
+ {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
22
+ {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
23
+ {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
24
+ {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
25
+ {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
26
+ {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
27
+ {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
28
+ {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
29
+ {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
30
+ {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
31
+ {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
32
+ {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
33
+ {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
34
+ {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
35
+ {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
36
+ {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
37
+ {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
38
+ {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
39
+ {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
40
+ {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
41
+ {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
42
+ {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
43
+ {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
44
+ {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
45
+ {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
46
+ {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
47
+ {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
48
+ {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
49
+ {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
50
+ {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
51
+ {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
52
+ {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
53
+ {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
54
+ {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
55
+ {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
56
+ {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
57
+ {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
58
+ {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
59
+ {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
60
+ {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
61
+ {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
62
+ {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
63
+ {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
64
+ {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
65
+ {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
66
+ {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
67
+ {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
68
+ {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
69
+ {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
70
+ {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
71
+ {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
72
+ {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
73
+ {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
74
+ {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
75
+ {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
76
+ {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
77
+ {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
78
+ {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
79
+ {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
80
+ {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
81
+ {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
82
+ {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
83
+ {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
84
+ {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
85
+ {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
86
+ {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
87
+ {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
88
+ {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
89
+ {"id": 92, "name": "banner", "supercategory": "textile"},
90
+ {"id": 93, "name": "blanket", "supercategory": "textile"},
91
+ {"id": 94, "name": "branch", "supercategory": "plant"},
92
+ {"id": 95, "name": "bridge", "supercategory": "building"},
93
+ {"id": 96, "name": "building-other", "supercategory": "building"},
94
+ {"id": 97, "name": "bush", "supercategory": "plant"},
95
+ {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
96
+ {"id": 99, "name": "cage", "supercategory": "structural"},
97
+ {"id": 100, "name": "cardboard", "supercategory": "raw-material"},
98
+ {"id": 101, "name": "carpet", "supercategory": "floor"},
99
+ {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
100
+ {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
101
+ {"id": 104, "name": "cloth", "supercategory": "textile"},
102
+ {"id": 105, "name": "clothes", "supercategory": "textile"},
103
+ {"id": 106, "name": "clouds", "supercategory": "sky"},
104
+ {"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
105
+ {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
106
+ {"id": 109, "name": "curtain", "supercategory": "textile"},
107
+ {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
108
+ {"id": 111, "name": "dirt", "supercategory": "ground"},
109
+ {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
110
+ {"id": 113, "name": "fence", "supercategory": "structural"},
111
+ {"id": 114, "name": "floor-marble", "supercategory": "floor"},
112
+ {"id": 115, "name": "floor-other", "supercategory": "floor"},
113
+ {"id": 116, "name": "floor-stone", "supercategory": "floor"},
114
+ {"id": 117, "name": "floor-tile", "supercategory": "floor"},
115
+ {"id": 118, "name": "floor-wood", "supercategory": "floor"},
116
+ {"id": 119, "name": "flower", "supercategory": "plant"},
117
+ {"id": 120, "name": "fog", "supercategory": "water"},
118
+ {"id": 121, "name": "food-other", "supercategory": "food-stuff"},
119
+ {"id": 122, "name": "fruit", "supercategory": "food-stuff"},
120
+ {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
121
+ {"id": 124, "name": "grass", "supercategory": "plant"},
122
+ {"id": 125, "name": "gravel", "supercategory": "ground"},
123
+ {"id": 126, "name": "ground-other", "supercategory": "ground"},
124
+ {"id": 127, "name": "hill", "supercategory": "solid"},
125
+ {"id": 128, "name": "house", "supercategory": "building"},
126
+ {"id": 129, "name": "leaves", "supercategory": "plant"},
127
+ {"id": 130, "name": "light", "supercategory": "furniture-stuff"},
128
+ {"id": 131, "name": "mat", "supercategory": "textile"},
129
+ {"id": 132, "name": "metal", "supercategory": "raw-material"},
130
+ {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
131
+ {"id": 134, "name": "moss", "supercategory": "plant"},
132
+ {"id": 135, "name": "mountain", "supercategory": "solid"},
133
+ {"id": 136, "name": "mud", "supercategory": "ground"},
134
+ {"id": 137, "name": "napkin", "supercategory": "textile"},
135
+ {"id": 138, "name": "net", "supercategory": "structural"},
136
+ {"id": 139, "name": "paper", "supercategory": "raw-material"},
137
+ {"id": 140, "name": "pavement", "supercategory": "ground"},
138
+ {"id": 141, "name": "pillow", "supercategory": "textile"},
139
+ {"id": 142, "name": "plant-other", "supercategory": "plant"},
140
+ {"id": 143, "name": "plastic", "supercategory": "raw-material"},
141
+ {"id": 144, "name": "platform", "supercategory": "ground"},
142
+ {"id": 145, "name": "playingfield", "supercategory": "ground"},
143
+ {"id": 146, "name": "railing", "supercategory": "structural"},
144
+ {"id": 147, "name": "railroad", "supercategory": "ground"},
145
+ {"id": 148, "name": "river", "supercategory": "water"},
146
+ {"id": 149, "name": "road", "supercategory": "ground"},
147
+ {"id": 150, "name": "rock", "supercategory": "solid"},
148
+ {"id": 151, "name": "roof", "supercategory": "building"},
149
+ {"id": 152, "name": "rug", "supercategory": "textile"},
150
+ {"id": 153, "name": "salad", "supercategory": "food-stuff"},
151
+ {"id": 154, "name": "sand", "supercategory": "ground"},
152
+ {"id": 155, "name": "sea", "supercategory": "water"},
153
+ {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
154
+ {"id": 157, "name": "sky-other", "supercategory": "sky"},
155
+ {"id": 158, "name": "skyscraper", "supercategory": "building"},
156
+ {"id": 159, "name": "snow", "supercategory": "ground"},
157
+ {"id": 160, "name": "solid-other", "supercategory": "solid"},
158
+ {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
159
+ {"id": 162, "name": "stone", "supercategory": "solid"},
160
+ {"id": 163, "name": "straw", "supercategory": "plant"},
161
+ {"id": 164, "name": "structural-other", "supercategory": "structural"},
162
+ {"id": 165, "name": "table", "supercategory": "furniture-stuff"},
163
+ {"id": 166, "name": "tent", "supercategory": "building"},
164
+ {"id": 167, "name": "textile-other", "supercategory": "textile"},
165
+ {"id": 168, "name": "towel", "supercategory": "textile"},
166
+ {"id": 169, "name": "tree", "supercategory": "plant"},
167
+ {"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
168
+ {"id": 171, "name": "wall-brick", "supercategory": "wall"},
169
+ {"id": 172, "name": "wall-concrete", "supercategory": "wall"},
170
+ {"id": 173, "name": "wall-other", "supercategory": "wall"},
171
+ {"id": 174, "name": "wall-panel", "supercategory": "wall"},
172
+ {"id": 175, "name": "wall-stone", "supercategory": "wall"},
173
+ {"id": 176, "name": "wall-tile", "supercategory": "wall"},
174
+ {"id": 177, "name": "wall-wood", "supercategory": "wall"},
175
+ {"id": 178, "name": "water-other", "supercategory": "water"},
176
+ {"id": 179, "name": "waterdrops", "supercategory": "water"},
177
+ {"id": 180, "name": "window-blind", "supercategory": "window"},
178
+ {"id": 181, "name": "window-other", "supercategory": "window"},
179
+ {"id": 182, "name": "wood", "supercategory": "solid"},
180
+ ]
181
+
182
+ def _get_coco_stuff_meta(cat_list):
183
+ # Id 0 is reserved for ignore_label; we change ignore_label from 0
184
+ # to 255 in our pre-processing.
185
+ stuff_ids = [k["id"] for k in cat_list]
186
+
187
+ # For semantic segmentation, this maps ids in the dataset (used when
188
+ # processing results) to contiguous ids (used in models)
189
+ stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
190
+ stuff_classes = [k["name"] for k in cat_list]
191
+
192
+ ret = {
193
+ "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
194
+ "stuff_classes": stuff_classes,
195
+ }
196
+ return ret
197
+
198
+
199
+ def register_all_coco_stuff_10k(root):
200
+ root = os.path.join(root, "coco", "coco_stuff_10k")
201
+ meta = _get_coco_stuff_meta(COCO_CATEGORIES)
202
+ for name, image_dirname, sem_seg_dirname in [
203
+ ("train", "images_detectron2/train", "annotations_detectron2/train"),
204
+ ]:
205
+ image_dir = os.path.join(root, image_dirname)
206
+ gt_dir = os.path.join(root, sem_seg_dirname)
207
+ name = f"coco_2017_{name}_stuff_10k_sem_seg"
208
+ DatasetCatalog.register(
209
+ name,
210
+ lambda x=image_dir, y=gt_dir: load_sem_seg(
211
+ y, x, gt_ext="png", image_ext="jpg"
212
+ ),
213
+ )
214
+ MetadataCatalog.get(name).set(
215
+ image_root=image_dir,
216
+ sem_seg_root=gt_dir,
217
+ evaluator_type="sem_seg",
218
+ ignore_label=255,
219
+ **meta,
220
+ )
221
+
222
+
223
+ def register_all_coco_stuff(root):
224
+ root = os.path.join(root, "coco")
225
+ meta = _get_coco_stuff_meta(COCO_CATEGORIES)
226
+
227
+ for name, image_dirname, sem_seg_dirname in [
228
+ ("train", "train2017", "stuffthingmaps_detectron2/train2017"),
229
+ ]:
230
+ image_dir = os.path.join(root, image_dirname)
231
+ gt_dir = os.path.join(root, sem_seg_dirname)
232
+ all_name = f"coco_2017_{name}_stuff_sem_seg"
233
+ DatasetCatalog.register(
234
+ all_name,
235
+ lambda x=image_dir, y=gt_dir: load_sem_seg(
236
+ y, x, gt_ext="png", image_ext="jpg"
237
+ ),
238
+ )
239
+ MetadataCatalog.get(all_name).set(
240
+ image_root=image_dir,
241
+ sem_seg_root=gt_dir,
242
+ evaluator_type="sem_seg",
243
+ ignore_label=255,
244
+ **meta,
245
+ )
246
+
247
+
248
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
249
+ register_all_coco_stuff_10k(_root)
250
+ register_all_coco_stuff(_root)
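For orientation (not part of the diff): the meta dict attached above is what downstream evaluation uses to translate between dataset ids and contiguous training ids. A minimal sketch, assuming this module has been imported so the registrations ran and the COCO files exist under $DETECTRON2_DATASETS/coco:

# Minimal sketch: inspect the registered COCO-Stuff training split.
from detectron2.data import MetadataCatalog

meta = MetadataCatalog.get("coco_2017_train_stuff_sem_seg")
print(len(meta.stuff_classes))                     # one entry per category in COCO_CATEGORIES above
print(meta.stuff_dataset_id_to_contiguous_id[92])  # dataset id 92 ("banner") -> its contiguous index
print(meta.ignore_label)                           # 255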
open_vocab_seg/data/datasets/register_pascal_context.py ADDED
@@ -0,0 +1,588 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import os
3
+
4
+ from detectron2.data import DatasetCatalog, MetadataCatalog
5
+ from detectron2.data.datasets import load_sem_seg
6
+
7
+ PASCALCONTEX59_NAMES = (
8
+ "aeroplane",
9
+ "bicycle",
10
+ "bird",
11
+ "boat",
12
+ "bottle",
13
+ "bus",
14
+ "car",
15
+ "cat",
16
+ "chair",
17
+ "cow",
18
+ "table",
19
+ "dog",
20
+ "horse",
21
+ "motorbike",
22
+ "person",
23
+ "pottedplant",
24
+ "sheep",
25
+ "sofa",
26
+ "train",
27
+ "tvmonitor",
28
+ "bag",
29
+ "bed",
30
+ "bench",
31
+ "book",
32
+ "building",
33
+ "cabinet",
34
+ "ceiling",
35
+ "cloth",
36
+ "computer",
37
+ "cup",
38
+ "door",
39
+ "fence",
40
+ "floor",
41
+ "flower",
42
+ "food",
43
+ "grass",
44
+ "ground",
45
+ "keyboard",
46
+ "light",
47
+ "mountain",
48
+ "mouse",
49
+ "curtain",
50
+ "platform",
51
+ "sign",
52
+ "plate",
53
+ "road",
54
+ "rock",
55
+ "shelves",
56
+ "sidewalk",
57
+ "sky",
58
+ "snow",
59
+ "bedclothes",
60
+ "track",
61
+ "tree",
62
+ "truck",
63
+ "wall",
64
+ "water",
65
+ "window",
66
+ "wood",
67
+ )
68
+
69
+ PASCALCONTEX459_NAMES = (
70
+ "accordion",
71
+ "aeroplane",
72
+ "air conditioner",
73
+ "antenna",
74
+ "artillery",
75
+ "ashtray",
76
+ "atrium",
77
+ "baby carriage",
78
+ "bag",
79
+ "ball",
80
+ "balloon",
81
+ "bamboo weaving",
82
+ "barrel",
83
+ "baseball bat",
84
+ "basket",
85
+ "basketball backboard",
86
+ "bathtub",
87
+ "bed",
88
+ "bedclothes",
89
+ "beer",
90
+ "bell",
91
+ "bench",
92
+ "bicycle",
93
+ "binoculars",
94
+ "bird",
95
+ "bird cage",
96
+ "bird feeder",
97
+ "bird nest",
98
+ "blackboard",
99
+ "board",
100
+ "boat",
101
+ "bone",
102
+ "book",
103
+ "bottle",
104
+ "bottle opener",
105
+ "bowl",
106
+ "box",
107
+ "bracelet",
108
+ "brick",
109
+ "bridge",
110
+ "broom",
111
+ "brush",
112
+ "bucket",
113
+ "building",
114
+ "bus",
115
+ "cabinet",
116
+ "cabinet door",
117
+ "cage",
118
+ "cake",
119
+ "calculator",
120
+ "calendar",
121
+ "camel",
122
+ "camera",
123
+ "camera lens",
124
+ "can",
125
+ "candle",
126
+ "candle holder",
127
+ "cap",
128
+ "car",
129
+ "card",
130
+ "cart",
131
+ "case",
132
+ "casette recorder",
133
+ "cash register",
134
+ "cat",
135
+ "cd",
136
+ "cd player",
137
+ "ceiling",
138
+ "cell phone",
139
+ "cello",
140
+ "chain",
141
+ "chair",
142
+ "chessboard",
143
+ "chicken",
144
+ "chopstick",
145
+ "clip",
146
+ "clippers",
147
+ "clock",
148
+ "closet",
149
+ "cloth",
150
+ "clothes tree",
151
+ "coffee",
152
+ "coffee machine",
153
+ "comb",
154
+ "computer",
155
+ "concrete",
156
+ "cone",
157
+ "container",
158
+ "control booth",
159
+ "controller",
160
+ "cooker",
161
+ "copying machine",
162
+ "coral",
163
+ "cork",
164
+ "corkscrew",
165
+ "counter",
166
+ "court",
167
+ "cow",
168
+ "crabstick",
169
+ "crane",
170
+ "crate",
171
+ "cross",
172
+ "crutch",
173
+ "cup",
174
+ "curtain",
175
+ "cushion",
176
+ "cutting board",
177
+ "dais",
178
+ "disc",
179
+ "disc case",
180
+ "dishwasher",
181
+ "dock",
182
+ "dog",
183
+ "dolphin",
184
+ "door",
185
+ "drainer",
186
+ "dray",
187
+ "drink dispenser",
188
+ "drinking machine",
189
+ "drop",
190
+ "drug",
191
+ "drum",
192
+ "drum kit",
193
+ "duck",
194
+ "dumbbell",
195
+ "earphone",
196
+ "earrings",
197
+ "egg",
198
+ "electric fan",
199
+ "electric iron",
200
+ "electric pot",
201
+ "electric saw",
202
+ "electronic keyboard",
203
+ "engine",
204
+ "envelope",
205
+ "equipment",
206
+ "escalator",
207
+ "exhibition booth",
208
+ "extinguisher",
209
+ "eyeglass",
210
+ "fan",
211
+ "faucet",
212
+ "fax machine",
213
+ "fence",
214
+ "ferris wheel",
215
+ "fire extinguisher",
216
+ "fire hydrant",
217
+ "fire place",
218
+ "fish",
219
+ "fish tank",
220
+ "fishbowl",
221
+ "fishing net",
222
+ "fishing pole",
223
+ "flag",
224
+ "flagstaff",
225
+ "flame",
226
+ "flashlight",
227
+ "floor",
228
+ "flower",
229
+ "fly",
230
+ "foam",
231
+ "food",
232
+ "footbridge",
233
+ "forceps",
234
+ "fork",
235
+ "forklift",
236
+ "fountain",
237
+ "fox",
238
+ "frame",
239
+ "fridge",
240
+ "frog",
241
+ "fruit",
242
+ "funnel",
243
+ "furnace",
244
+ "game controller",
245
+ "game machine",
246
+ "gas cylinder",
247
+ "gas hood",
248
+ "gas stove",
249
+ "gift box",
250
+ "glass",
251
+ "glass marble",
252
+ "globe",
253
+ "glove",
254
+ "goal",
255
+ "grandstand",
256
+ "grass",
257
+ "gravestone",
258
+ "ground",
259
+ "guardrail",
260
+ "guitar",
261
+ "gun",
262
+ "hammer",
263
+ "hand cart",
264
+ "handle",
265
+ "handrail",
266
+ "hanger",
267
+ "hard disk drive",
268
+ "hat",
269
+ "hay",
270
+ "headphone",
271
+ "heater",
272
+ "helicopter",
273
+ "helmet",
274
+ "holder",
275
+ "hook",
276
+ "horse",
277
+ "horse-drawn carriage",
278
+ "hot-air balloon",
279
+ "hydrovalve",
280
+ "ice",
281
+ "inflator pump",
282
+ "ipod",
283
+ "iron",
284
+ "ironing board",
285
+ "jar",
286
+ "kart",
287
+ "kettle",
288
+ "key",
289
+ "keyboard",
290
+ "kitchen range",
291
+ "kite",
292
+ "knife",
293
+ "knife block",
294
+ "ladder",
295
+ "ladder truck",
296
+ "ladle",
297
+ "laptop",
298
+ "leaves",
299
+ "lid",
300
+ "life buoy",
301
+ "light",
302
+ "light bulb",
303
+ "lighter",
304
+ "line",
305
+ "lion",
306
+ "lobster",
307
+ "lock",
308
+ "machine",
309
+ "mailbox",
310
+ "mannequin",
311
+ "map",
312
+ "mask",
313
+ "mat",
314
+ "match book",
315
+ "mattress",
316
+ "menu",
317
+ "metal",
318
+ "meter box",
319
+ "microphone",
320
+ "microwave",
321
+ "mirror",
322
+ "missile",
323
+ "model",
324
+ "money",
325
+ "monkey",
326
+ "mop",
327
+ "motorbike",
328
+ "mountain",
329
+ "mouse",
330
+ "mouse pad",
331
+ "musical instrument",
332
+ "napkin",
333
+ "net",
334
+ "newspaper",
335
+ "oar",
336
+ "ornament",
337
+ "outlet",
338
+ "oven",
339
+ "oxygen bottle",
340
+ "pack",
341
+ "pan",
342
+ "paper",
343
+ "paper box",
344
+ "paper cutter",
345
+ "parachute",
346
+ "parasol",
347
+ "parterre",
348
+ "patio",
349
+ "pelage",
350
+ "pen",
351
+ "pen container",
352
+ "pencil",
353
+ "person",
354
+ "photo",
355
+ "piano",
356
+ "picture",
357
+ "pig",
358
+ "pillar",
359
+ "pillow",
360
+ "pipe",
361
+ "pitcher",
362
+ "plant",
363
+ "plastic",
364
+ "plate",
365
+ "platform",
366
+ "player",
367
+ "playground",
368
+ "pliers",
369
+ "plume",
370
+ "poker",
371
+ "poker chip",
372
+ "pole",
373
+ "pool table",
374
+ "postcard",
375
+ "poster",
376
+ "pot",
377
+ "pottedplant",
378
+ "printer",
379
+ "projector",
380
+ "pumpkin",
381
+ "rabbit",
382
+ "racket",
383
+ "radiator",
384
+ "radio",
385
+ "rail",
386
+ "rake",
387
+ "ramp",
388
+ "range hood",
389
+ "receiver",
390
+ "recorder",
391
+ "recreational machines",
392
+ "remote control",
393
+ "road",
394
+ "robot",
395
+ "rock",
396
+ "rocket",
397
+ "rocking horse",
398
+ "rope",
399
+ "rug",
400
+ "ruler",
401
+ "runway",
402
+ "saddle",
403
+ "sand",
404
+ "saw",
405
+ "scale",
406
+ "scanner",
407
+ "scissors",
408
+ "scoop",
409
+ "screen",
410
+ "screwdriver",
411
+ "sculpture",
412
+ "scythe",
413
+ "sewer",
414
+ "sewing machine",
415
+ "shed",
416
+ "sheep",
417
+ "shell",
418
+ "shelves",
419
+ "shoe",
420
+ "shopping cart",
421
+ "shovel",
422
+ "sidecar",
423
+ "sidewalk",
424
+ "sign",
425
+ "signal light",
426
+ "sink",
427
+ "skateboard",
428
+ "ski",
429
+ "sky",
430
+ "sled",
431
+ "slippers",
432
+ "smoke",
433
+ "snail",
434
+ "snake",
435
+ "snow",
436
+ "snowmobiles",
437
+ "sofa",
438
+ "spanner",
439
+ "spatula",
440
+ "speaker",
441
+ "speed bump",
442
+ "spice container",
443
+ "spoon",
444
+ "sprayer",
445
+ "squirrel",
446
+ "stage",
447
+ "stair",
448
+ "stapler",
449
+ "stick",
450
+ "sticky note",
451
+ "stone",
452
+ "stool",
453
+ "stove",
454
+ "straw",
455
+ "stretcher",
456
+ "sun",
457
+ "sunglass",
458
+ "sunshade",
459
+ "surveillance camera",
460
+ "swan",
461
+ "sweeper",
462
+ "swim ring",
463
+ "swimming pool",
464
+ "swing",
465
+ "switch",
466
+ "table",
467
+ "tableware",
468
+ "tank",
469
+ "tap",
470
+ "tape",
471
+ "tarp",
472
+ "telephone",
473
+ "telephone booth",
474
+ "tent",
475
+ "tire",
476
+ "toaster",
477
+ "toilet",
478
+ "tong",
479
+ "tool",
480
+ "toothbrush",
481
+ "towel",
482
+ "toy",
483
+ "toy car",
484
+ "track",
485
+ "train",
486
+ "trampoline",
487
+ "trash bin",
488
+ "tray",
489
+ "tree",
490
+ "tricycle",
491
+ "tripod",
492
+ "trophy",
493
+ "truck",
494
+ "tube",
495
+ "turtle",
496
+ "tvmonitor",
497
+ "tweezers",
498
+ "typewriter",
499
+ "umbrella",
500
+ "unknown",
501
+ "vacuum cleaner",
502
+ "vending machine",
503
+ "video camera",
504
+ "video game console",
505
+ "video player",
506
+ "video tape",
507
+ "violin",
508
+ "wakeboard",
509
+ "wall",
510
+ "wallet",
511
+ "wardrobe",
512
+ "washing machine",
513
+ "watch",
514
+ "water",
515
+ "water dispenser",
516
+ "water pipe",
517
+ "water skate board",
518
+ "watermelon",
519
+ "whale",
520
+ "wharf",
521
+ "wheel",
522
+ "wheelchair",
523
+ "window",
524
+ "window blinds",
525
+ "wineglass",
526
+ "wire",
527
+ "wood",
528
+ "wool",
529
+
530
+ )
531
+
532
+
533
+ def _get_voc_meta(cat_list):
534
+ ret = {
535
+ "stuff_classes": cat_list,
536
+ }
537
+ return ret
538
+
539
+
540
+ def register_pascal_context_59(root):
541
+ root = os.path.join(root, "VOCdevkit/VOC2010")
542
+ meta = _get_voc_meta(PASCALCONTEX59_NAMES)
543
+ for name, image_dirname, sem_seg_dirname in [
544
+ ("val", "JPEGImages", "annotations_detectron2/pc59_val"),
545
+ ]:
546
+ image_dir = os.path.join(root, image_dirname)
547
+ gt_dir = os.path.join(root, sem_seg_dirname)
548
+ all_name = f"pascal_context_59_sem_seg_{name}"
549
+ DatasetCatalog.register(
550
+ all_name,
551
+ lambda x=image_dir, y=gt_dir: load_sem_seg(
552
+ y, x, gt_ext="png", image_ext="jpg"
553
+ ),
554
+ )
555
+ MetadataCatalog.get(all_name).set(
556
+ image_root=image_dir,
557
+ sem_seg_root=gt_dir,
558
+ evaluator_type="sem_seg",
559
+ ignore_label=255,
560
+ **meta,
561
+ )
562
+
563
+ def register_pascal_context_459(root):
564
+ root = os.path.join(root, "VOCdevkit/VOC2010")
565
+ meta = _get_voc_meta(PASCALCONTEX459_NAMES)
566
+ for name, image_dirname, sem_seg_dirname in [
567
+ ("val", "JPEGImages", "annotations_detectron2/pc459_val"),
568
+ ]:
569
+ image_dir = os.path.join(root, image_dirname)
570
+ gt_dir = os.path.join(root, sem_seg_dirname)
571
+ all_name = f"pascal_context_459_sem_seg_{name}"
572
+ DatasetCatalog.register(
573
+ all_name,
574
+ lambda x=image_dir, y=gt_dir: load_sem_seg(
575
+ y, x, gt_ext="tif", image_ext="jpg"
576
+ ),
577
+ )
578
+ MetadataCatalog.get(all_name).set(
579
+ image_root=image_dir,
580
+ sem_seg_root=gt_dir,
581
+ evaluator_type="sem_seg",
582
+ ignore_label=65535, # NOTE: gt is saved in 16-bit TIFF images
583
+ **meta,
584
+ )
585
+
586
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
587
+ register_pascal_context_59(_root)
588
+ register_pascal_context_459(_root)
open_vocab_seg/data/datasets/register_voc_seg.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import os
3
+
4
+ from detectron2.data import DatasetCatalog, MetadataCatalog
5
+ from detectron2.data.datasets import load_sem_seg
6
+
7
+ PASCALVOC20_NAMES = (
8
+ "aeroplane",
9
+ "bicycle",
10
+ "bird",
11
+ "boat",
12
+ "bottle",
13
+ "bus",
14
+ "car",
15
+ "cat",
16
+ "chair",
17
+ "cow",
18
+ "diningtable",
19
+ "dog",
20
+ "horse",
21
+ "motorbike",
22
+ "person",
23
+ "pottedplant",
24
+ "sheep",
25
+ "sofa",
26
+ "train",
27
+ "tvmonitor",
28
+ )
29
+
30
+ def _get_voc_meta(cat_list):
31
+ ret = {
32
+ "stuff_classes": cat_list,
33
+ }
34
+ return ret
35
+
36
+
37
+ def register_pascalvoc(root):
38
+ root = os.path.join(root, "VOCdevkit/VOC2012")
39
+ meta = _get_voc_meta(PASCALVOC20_NAMES)
40
+
41
+ for name, image_dirname, sem_seg_dirname in [
42
+ ("val", "JPEGImages", "annotations_detectron2/val"),
43
+ ]:
44
+ image_dir = os.path.join(root, image_dirname)
45
+ gt_dir = os.path.join(root, sem_seg_dirname)
46
+ all_name = f"pascalvoc20_sem_seg_{name}"
47
+ DatasetCatalog.register(
48
+ all_name,
49
+ lambda x=image_dir, y=gt_dir: load_sem_seg(
50
+ y, x, gt_ext="png", image_ext="jpg"
51
+ ),
52
+ )
53
+ MetadataCatalog.get(all_name).set(
54
+ image_root=image_dir,
55
+ sem_seg_root=gt_dir,
56
+ evaluator_type="sem_seg",
57
+ ignore_label=255,
58
+ **meta,
59
+ )
60
+
61
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
62
+ register_pascalvoc(_root)
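A minimal usage sketch for the registration above, not a definitive recipe: it assumes detectron2 and the open_vocab_seg package import cleanly, that importing this module has run the registration, and that the VOC2012 images and annotations actually exist under $DETECTRON2_DATASETS (default "datasets"), since load_sem_seg scans those directories lazily.

import open_vocab_seg.data.datasets.register_voc_seg  # noqa: F401  (runs the registration above)
from detectron2.data import DatasetCatalog, MetadataCatalog

name = "pascalvoc20_sem_seg_val"
meta = MetadataCatalog.get(name)
print(meta.stuff_classes[:5])        # ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle')

dicts = DatasetCatalog.get(name)     # list of {"file_name", "sem_seg_file_name", ...} records
print(len(dicts), dicts[0]["file_name"])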
open_vocab_seg/evaluation/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ from .generalized_sem_seg_evaluation import GeneralizedSemSegEvaluator
open_vocab_seg/evaluation/generalized_sem_seg_evaluation.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import itertools
5
+ import json
6
+ import numpy as np
7
+ import os
8
+ from collections import OrderedDict
9
+ import PIL.Image as Image
10
+ import torch
11
+
12
+ from detectron2.data import DatasetCatalog, MetadataCatalog
13
+ from detectron2.utils.comm import all_gather, is_main_process, synchronize
14
+ from detectron2.utils.file_io import PathManager
15
+
16
+ from detectron2.evaluation import SemSegEvaluator
17
+
18
+
19
+ class GeneralizedSemSegEvaluator(SemSegEvaluator):
20
+ """
21
+ Evaluate semantic segmentation metrics.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ dataset_name,
27
+ distributed=True,
28
+ output_dir=None,
29
+ *,
30
+ num_classes=None,
31
+ ignore_label=None,
32
+ post_process_func=None,
33
+ ):
34
+ super().__init__(
35
+ dataset_name,
36
+ distributed=distributed,
37
+ output_dir=output_dir,
38
+ num_classes=num_classes,
39
+ ignore_label=ignore_label,
40
+ )
41
+ meta = MetadataCatalog.get(dataset_name)
42
+ try:
43
+ self._evaluation_set = meta.evaluation_set
44
+ except AttributeError:
45
+ self._evaluation_set = None
46
+ self.post_process_func = (
47
+ post_process_func
48
+ if post_process_func is not None
49
+ else lambda x, **kwargs: x
50
+ )
51
+
52
+ def process(self, inputs, outputs):
53
+ """
54
+ Args:
55
+ inputs: the inputs to a model.
56
+ It is a list of dicts. Each dict corresponds to an image and
57
+ contains keys like "height", "width", "file_name".
58
+ outputs: the outputs of a model. It is either a list of semantic segmentation predictions
59
+ (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic
60
+ segmentation prediction in the same format.
61
+ """
62
+ for input, output in zip(inputs, outputs):
63
+ output = self.post_process_func(
64
+ output["sem_seg"], image=np.array(Image.open(input["file_name"]))
65
+ )
66
+ output = output.argmax(dim=0).to(self._cpu_device)
67
+ pred = np.array(output, dtype=int)
68
+ with PathManager.open(
69
+ self.input_file_to_gt_file[input["file_name"]], "rb"
70
+ ) as f:
71
+ gt = np.array(Image.open(f), dtype=int)
72
+
73
+ gt[gt == self._ignore_label] = self._num_classes
74
+
75
+ self._conf_matrix += np.bincount(
76
+ (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
77
+ minlength=self._conf_matrix.size,
78
+ ).reshape(self._conf_matrix.shape)
79
+
80
+ self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))
81
+
82
+ def evaluate(self):
83
+ """
84
+ Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):
85
+
86
+ * Mean intersection-over-union averaged across classes (mIoU)
87
+ * Frequency Weighted IoU (fwIoU)
88
+ * Mean pixel accuracy averaged across classes (mACC)
89
+ * Pixel Accuracy (pACC)
90
+ """
91
+ if self._distributed:
92
+ synchronize()
93
+ conf_matrix_list = all_gather(self._conf_matrix)
94
+ self._predictions = all_gather(self._predictions)
95
+ self._predictions = list(itertools.chain(*self._predictions))
96
+ if not is_main_process():
97
+ return
98
+
99
+ self._conf_matrix = np.zeros_like(self._conf_matrix)
100
+ for conf_matrix in conf_matrix_list:
101
+ self._conf_matrix += conf_matrix
102
+
103
+ if self._output_dir:
104
+ PathManager.mkdirs(self._output_dir)
105
+ file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
106
+ with PathManager.open(file_path, "w") as f:
107
+ f.write(json.dumps(self._predictions))
108
+
109
+ acc = np.full(self._num_classes, np.nan, dtype=float)
110
+ iou = np.full(self._num_classes, np.nan, dtype=float)
111
+ tp = self._conf_matrix.diagonal()[:-1].astype(float)
112
+ pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(float)
113
+ class_weights = pos_gt / np.sum(pos_gt)
114
+ pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(float)
115
+ acc_valid = pos_gt > 0
116
+ acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
117
+ iou_valid = (pos_gt + pos_pred) > 0
118
+ union = pos_gt + pos_pred - tp
119
+ iou[acc_valid] = tp[acc_valid] / union[acc_valid]
120
+ macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
121
+ miou = np.sum(iou[acc_valid]) / np.sum(iou_valid)
122
+ fiou = np.sum(iou[acc_valid] * class_weights[acc_valid])
123
+ pacc = np.sum(tp) / np.sum(pos_gt)
124
+
125
+ res = {}
126
+ res["mIoU"] = 100 * miou
127
+ res["fwIoU"] = 100 * fiou
128
+ for i, name in enumerate(self._class_names):
129
+ res["IoU-{}".format(name)] = 100 * iou[i]
130
+ res["mACC"] = 100 * macc
131
+ res["pACC"] = 100 * pacc
132
+ for i, name in enumerate(self._class_names):
133
+ res["ACC-{}".format(name)] = 100 * acc[i]
134
+ if self._evaluation_set is not None:
135
+ for set_name, set_inds in self._evaluation_set.items():
136
+ iou_list = []
137
+ set_inds = np.array(set_inds, int)
138
+ mask = np.zeros((len(iou),)).astype(bool)
139
+ mask[set_inds] = 1
140
+ miou = np.sum(iou[mask][acc_valid[mask]]) / np.sum(iou_valid[mask])
141
+ pacc = np.sum(tp[mask]) / np.sum(pos_gt[mask])
142
+ res["mIoU-{}".format(set_name)] = 100 * miou
143
+ res["pAcc-{}".format(set_name)] = 100 * pacc
144
+ iou_list.append(miou)
145
+ miou = np.sum(iou[~mask][acc_valid[~mask]]) / np.sum(iou_valid[~mask])
146
+ pacc = np.sum(tp[~mask]) / np.sum(pos_gt[~mask])
147
+ res["mIoU-un{}".format(set_name)] = 100 * miou
148
+ res["pAcc-un{}".format(set_name)] = 100 * pacc
149
+ iou_list.append(miou)
150
+ res["hIoU-{}".format(set_name)] = (
151
+ 100 * len(iou_list) / sum([1 / iou for iou in iou_list])
152
+ )
153
+ if self._output_dir:
154
+ file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
155
+ with PathManager.open(file_path, "wb") as f:
156
+ torch.save(res, f)
157
+ results = OrderedDict({"sem_seg": res})
158
+ self._logger.info(results)
159
+ return results
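As a side note on the update in process() above: each (prediction, ground truth) pixel pair is packed into a single bin index and counted with np.bincount. A self-contained toy sketch of that trick, plain NumPy, with two classes plus the extra ignore bucket:

import numpy as np

num_classes = 2
pred = np.array([0, 1, 1, 0])
gt = np.array([0, 1, 0, 2])  # the last pixel carried the ignore label, remapped to num_classes

conf = np.bincount(
    (num_classes + 1) * pred + gt,
    minlength=(num_classes + 1) ** 2,
).reshape(num_classes + 1, num_classes + 1)

# conf[i, j] counts pixels predicted as class i whose ground truth is class j;
# the last column collects pixels whose ground truth was the ignore label.
print(conf)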
open_vocab_seg/mask_former_model.py ADDED
@@ -0,0 +1,254 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.data import MetadataCatalog
12
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
13
+ from detectron2.modeling.backbone import Backbone
14
+ from detectron2.modeling.postprocessing import sem_seg_postprocess
15
+ from detectron2.structures import ImageList
16
+
17
+ from .modeling.criterion import SetCriterion
18
+ from .modeling.matcher import HungarianMatcher
19
+
20
+
21
+ @META_ARCH_REGISTRY.register()
22
+ class MaskFormer(nn.Module):
23
+ """
24
+ Main class for mask classification semantic segmentation architectures.
25
+ """
26
+
27
+ @configurable
28
+ def __init__(
29
+ self,
30
+ *,
31
+ backbone: Backbone,
32
+ sem_seg_head: nn.Module,
33
+ criterion: nn.Module,
34
+ num_queries: int,
35
+ panoptic_on: bool,
36
+ object_mask_threshold: float,
37
+ overlap_threshold: float,
38
+ metadata,
39
+ size_divisibility: int,
40
+ sem_seg_postprocess_before_inference: bool,
41
+ pixel_mean: Tuple[float],
42
+ pixel_std: Tuple[float],
43
+ ):
44
+ """
45
+ Args:
46
+ backbone: a backbone module, must follow detectron2's backbone interface
47
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
48
+ criterion: a module that defines the loss
49
+ num_queries: int, number of queries
50
+ panoptic_on: bool, whether to output panoptic segmentation prediction
51
+ object_mask_threshold: float, threshold to filter query based on classification score
52
+ for panoptic segmentation inference
53
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
54
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
55
+ segmentation inference
56
+ size_divisibility: Some backbones require the input height and width to be divisible by a
57
+ specific integer. We can use this to override such requirement.
58
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
59
+ to original input size before semantic segmentation inference or after.
60
+ For high-resolution dataset like Mapillary, resizing predictions before
61
+ inference will cause OOM error.
62
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
63
+ the per-channel mean and std to be used to normalize the input image
64
+ """
65
+ super().__init__()
66
+ self.backbone = backbone
67
+ self.sem_seg_head = sem_seg_head
68
+ self.criterion = criterion
69
+ self.num_queries = num_queries
70
+ self.overlap_threshold = overlap_threshold
71
+ self.panoptic_on = panoptic_on
72
+ self.object_mask_threshold = object_mask_threshold
73
+ self.metadata = metadata
74
+ if size_divisibility < 0:
75
+ # use backbone size_divisibility if not set
76
+ size_divisibility = self.backbone.size_divisibility
77
+ self.size_divisibility = size_divisibility
78
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
79
+ self.register_buffer(
80
+ "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
81
+ )
82
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
83
+
84
+ @classmethod
85
+ def from_config(cls, cfg):
86
+ backbone = build_backbone(cfg)
87
+ sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
88
+
89
+ # Loss parameters:
90
+ deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
91
+ no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
92
+ dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
93
+ mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
94
+
95
+ # building criterion
96
+ matcher = HungarianMatcher(
97
+ cost_class=1,
98
+ cost_mask=mask_weight,
99
+ cost_dice=dice_weight,
100
+ )
101
+
102
+ weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight}
103
+ if deep_supervision:
104
+ dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
105
+ aux_weight_dict = {}
106
+ for i in range(dec_layers - 1):
107
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
108
+ weight_dict.update(aux_weight_dict)
109
+
110
+ losses = ["labels", "masks"]
111
+
112
+ criterion = SetCriterion(
113
+ sem_seg_head.num_classes,
114
+ matcher=matcher,
115
+ weight_dict=weight_dict,
116
+ eos_coef=no_object_weight,
117
+ losses=losses,
118
+ )
119
+
120
+ return {
121
+ "backbone": backbone,
122
+ "sem_seg_head": sem_seg_head,
123
+ "criterion": criterion,
124
+ "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
125
+ "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
126
+ "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
127
+ "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
128
+ "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
129
+ "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
130
+ "sem_seg_postprocess_before_inference": (
131
+ cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
132
+ or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
133
+ ),
134
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
135
+ "pixel_std": cfg.MODEL.PIXEL_STD,
136
+ }
137
+
138
+ @property
139
+ def device(self):
140
+ return self.pixel_mean.device
141
+
142
+ def forward(self, batched_inputs):
143
+ """
144
+ Args:
145
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
146
+ Each item in the list contains the inputs for one image.
147
+ For now, each item in the list is a dict that contains:
148
+ * "image": Tensor, image in (C, H, W) format.
149
+ * "instances": per-region ground truth
150
+ * Other information that's included in the original dicts, such as:
151
+ "height", "width" (int): the output resolution of the model (may be different
152
+ from input resolution), used in inference.
153
+ Returns:
154
+ list[dict]:
155
+ each dict has the results for one image. The dict contains the following keys:
156
+
157
+ * "sem_seg":
158
+ A Tensor that represents the
159
+ per-pixel segmentation prediced by the head.
160
+ The prediction has shape KxHxW that represents the logits of
161
+ each class for each pixel.
162
+ * "panoptic_seg":
163
+ A tuple that represent panoptic output
164
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
165
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
166
+ Each dict contains keys "id", "category_id", "isthing".
167
+ """
168
+ images = [x["image"].to(self.device) for x in batched_inputs]
169
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
170
+ images = ImageList.from_tensors(images, self.size_divisibility)
171
+
172
+ features = self.backbone(images.tensor)
173
+ outputs = self.sem_seg_head(features)
174
+
175
+ if self.training:
176
+ # mask classification target
177
+ if "instances" in batched_inputs[0]:
178
+ gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
179
+ targets = self.prepare_targets(gt_instances, images)
180
+ else:
181
+ targets = None
182
+
183
+ # bipartite matching-based loss
184
+ losses = self.criterion(outputs, targets)
185
+
186
+ for k in list(losses.keys()):
187
+ if k in self.criterion.weight_dict:
188
+ losses[k] *= self.criterion.weight_dict[k]
189
+ else:
190
+ # remove this loss if not specified in `weight_dict`
191
+ losses.pop(k)
192
+
193
+ return losses
194
+ else:
195
+ mask_cls_results = outputs["pred_logits"]
196
+ mask_pred_results = outputs["pred_masks"]
197
+ # upsample masks
198
+ mask_pred_results = F.interpolate(
199
+ mask_pred_results,
200
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
201
+ mode="bilinear",
202
+ align_corners=False,
203
+ )
204
+
205
+ processed_results = []
206
+ for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
207
+ mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
208
+ ):
209
+ height = input_per_image.get("height", image_size[0])
210
+ width = input_per_image.get("width", image_size[1])
211
+
212
+ if self.sem_seg_postprocess_before_inference:
213
+ mask_pred_result = sem_seg_postprocess(
214
+ mask_pred_result, image_size, height, width
215
+ )
216
+
217
+ # semantic segmentation inference
218
+ r = self.semantic_inference(mask_cls_result, mask_pred_result)
219
+ if not self.sem_seg_postprocess_before_inference:
220
+ r = sem_seg_postprocess(r, image_size, height, width)
221
+ processed_results.append({"sem_seg": r})
222
+
223
+ # panoptic segmentation inference
224
+ if self.panoptic_on:
225
+ panoptic_r = self.panoptic_inference(
226
+ mask_cls_result, mask_pred_result
227
+ )
228
+ processed_results[-1]["panoptic_seg"] = panoptic_r
229
+
230
+ return processed_results
231
+
232
+ def prepare_targets(self, targets, images):
233
+ h, w = images.tensor.shape[-2:]
234
+ new_targets = []
235
+ for targets_per_image in targets:
236
+ # pad gt
237
+ gt_masks = targets_per_image.gt_masks
238
+ padded_masks = torch.zeros(
239
+ (gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device
240
+ )
241
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
242
+ new_targets.append(
243
+ {
244
+ "labels": targets_per_image.gt_classes,
245
+ "masks": padded_masks,
246
+ }
247
+ )
248
+ return new_targets
249
+
250
+ def semantic_inference(self, mask_cls, mask_pred):
251
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
252
+ mask_pred = mask_pred.sigmoid()
253
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
254
+ return semseg
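A small shape sketch of semantic_inference() above: per-query class probabilities (Q, K) are combined with per-query mask probabilities (Q, H, W) into dense per-class scores (K, H, W). Plain torch, toy sizes only:

import torch
import torch.nn.functional as F

Q, K, H, W = 100, 19, 4, 4                   # toy sizes; K excludes the "no object" class
mask_cls = F.softmax(torch.randn(Q, K + 1), dim=-1)[..., :-1]  # drop the no-object column
mask_pred = torch.randn(Q, H, W).sigmoid()

semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
print(semseg.shape)                          # torch.Size([19, 4, 4])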
open_vocab_seg/modeling/.DS_Store ADDED
Binary file (6.15 kB).
 
open_vocab_seg/modeling/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ from .backbone.swin import D2SwinTransformer
5
+ from .backbone.clip_resnet import D2ModifiedResNet
6
+ from .heads.mask_former_head import MaskFormerHead
7
+ from .heads.open_vocab_mask_former_head import OpenVocabMaskFormerHead
8
+ from .heads.pixel_decoder import BasePixelDecoder
open_vocab_seg/modeling/backbone/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
open_vocab_seg/modeling/backbone/clip_resnet.py ADDED
@@ -0,0 +1,206 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ from collections import OrderedDict
5
+ import torch
6
+ import torch.nn as nn
7
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1, dilation=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+
20
+ self.conv2 = nn.Conv2d(
21
+ planes, planes, 3, padding=1 * dilation, bias=False, dilation=dilation
22
+ )
23
+ self.bn2 = nn.BatchNorm2d(planes)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+
30
+ self.relu = nn.ReLU(inplace=True)
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(
37
+ OrderedDict(
38
+ [
39
+ ("-1", nn.AvgPool2d(stride)),
40
+ (
41
+ "0",
42
+ nn.Conv2d(
43
+ inplanes,
44
+ planes * self.expansion,
45
+ 1,
46
+ stride=1,
47
+ bias=False,
48
+ ),
49
+ ),
50
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
51
+ ]
52
+ )
53
+ )
54
+
55
+ def forward(self, x: torch.Tensor):
56
+ identity = x
57
+
58
+ out = self.relu(self.bn1(self.conv1(x)))
59
+ out = self.relu(self.bn2(self.conv2(out)))
60
+ out = self.avgpool(out)
61
+ out = self.bn3(self.conv3(out))
62
+
63
+ if self.downsample is not None:
64
+ identity = self.downsample(x)
65
+
66
+ out += identity
67
+ out = self.relu(out)
68
+ return out
69
+
70
+
71
+ class ModifiedResNet(nn.Module):
72
+ """
73
+ A ResNet class that is similar to torchvision's but contains the following changes:
74
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
75
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
76
+ - The final pooling layer is a QKV attention instead of an average pool
77
+ """
78
+
79
+ def __init__(self, layers, width=64, strides=[2, 1, 2, 2, 2], multi_grid=[1, 1, 1]):
80
+ super().__init__()
81
+
82
+ # the 3-layer stem
83
+ self.conv1 = nn.Conv2d(
84
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
85
+ )
86
+ self.bn1 = nn.BatchNorm2d(width // 2)
87
+ self.conv2 = nn.Conv2d(
88
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
89
+ )
90
+ self.bn2 = nn.BatchNorm2d(width // 2)
91
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
92
+ self.bn3 = nn.BatchNorm2d(width)
93
+ self.avgpool = nn.AvgPool2d(strides[0]) if strides[0] > 1 else nn.Identity()
94
+ self.relu = nn.ReLU(inplace=True)
95
+
96
+ # residual layers
97
+ self._inplanes = width # this is a *mutable* variable used during construction
98
+ self.layer1 = self._make_layer(width, layers[0], stride=strides[1])
99
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=strides[2])
100
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=strides[3])
101
+ self.layer4 = self._make_layer(
102
+ width * 8, layers[3], stride=strides[4], dilations=multi_grid
103
+ )
104
+ self.num_features = [width * 4, width * 8, width * 16, width * 32]
105
+
106
+ def _make_layer(self, planes, blocks, stride=1, dilations=None):
107
+ if dilations is None:
108
+ dilations = [1] * blocks
109
+ layers = [Bottleneck(self._inplanes, planes, stride, dilation=dilations[0])]
110
+ self._inplanes = planes * Bottleneck.expansion
111
+
112
+ for i in range(1, blocks):
113
+ layers.append(Bottleneck(self._inplanes, planes, dilation=dilations[i]))
114
+
115
+ return nn.Sequential(*layers)
116
+
117
+ def forward(self, x):
118
+ def stem(x):
119
+ for conv, bn in [
120
+ (self.conv1, self.bn1),
121
+ (self.conv2, self.bn2),
122
+ (self.conv3, self.bn3),
123
+ ]:
124
+ x = self.relu(bn(conv(x)))
125
+ x = self.avgpool(x)
126
+ return x
127
+
128
+ output = {}
129
+ x = x.type(self.conv1.weight.dtype)
130
+ x = stem(x) # 1/4,1/4
131
+ x = self.layer1(x)
132
+ output["res2"] = x
133
+ x = self.layer2(x) # 1/8,1/8
134
+ output["res3"] = x
135
+ x = self.layer3(x) # 1/16,1/16
136
+ output["res4"] = x
137
+ x = self.layer4(x) # 1/32,1/32
138
+ output["res5"] = x
139
+ return output
140
+
141
+
142
+ @BACKBONE_REGISTRY.register()
143
+ class D2ModifiedResNet(ModifiedResNet, Backbone):
144
+ def __init__(self, cfg, input_shape):
145
+ depth = cfg.MODEL.RESNETS.DEPTH
146
+ num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
147
+ width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
148
+ bottleneck_channels = num_groups * width_per_group
149
+ num_blocks_per_stage = {
150
+ 18: [2, 2, 2, 2],
151
+ 34: [3, 4, 6, 3],
152
+ 50: [3, 4, 6, 3],
153
+ 101: [3, 4, 23, 3],
154
+ 152: [3, 8, 36, 3],
155
+ }[depth]
156
+ strides = [2, 1, 2, 2, 2]
157
+ multi_grid = cfg.MODEL.RESNETS.RES5_MULTI_GRID
158
+ if cfg.MODEL.RESNETS.STEM_TYPE == "deeplab":
159
+ strides = [1, 1, 2, 2, 2]
160
+ super().__init__(
161
+ num_blocks_per_stage,
162
+ bottleneck_channels,
163
+ strides=strides,
164
+ multi_grid=multi_grid,
165
+ )
166
+ self._out_features = cfg.MODEL.RESNETS.OUT_FEATURES
167
+
168
+ self._out_feature_strides = {
169
+ "res2": 4,
170
+ "res3": 8,
171
+ "res4": 16,
172
+ "res5": 32,
173
+ }
174
+ self._out_feature_channels = {
175
+ "res2": self.num_features[0],
176
+ "res3": self.num_features[1],
177
+ "res4": self.num_features[2],
178
+ "res5": self.num_features[3],
179
+ }
180
+
181
+ def forward(self, x):
182
+ """
183
+ Args:
184
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
185
+ Returns:
186
+ dict[str->Tensor]: names and the corresponding features
187
+ """
188
+ outputs = {}
189
+ y = super().forward(x)
190
+ for k in y.keys():
191
+ if k in self._out_features:
192
+ outputs[k] = y[k]
193
+ return outputs
194
+
195
+ def output_shape(self):
196
+ return {
197
+ name: ShapeSpec(
198
+ channels=self._out_feature_channels[name],
199
+ stride=self._out_feature_strides[name],
200
+ )
201
+ for name in self._out_features
202
+ }
203
+
204
+ @property
205
+ def size_divisibility(self):
206
+ return 32
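A quick smoke-test sketch for the backbone above (assumes torch and detectron2 are installed so the module imports): layers=[3, 4, 6, 3] matches the ResNet-50 entry of num_blocks_per_stage, and the default strides [2, 1, 2, 2, 2] yield a res2..res5 pyramid at strides 4/8/16/32.

import torch
from open_vocab_seg.modeling.backbone.clip_resnet import ModifiedResNet

# ResNet-50-style configuration with the default anti-aliased strides.
model = ModifiedResNet(layers=[3, 4, 6, 3], width=64).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
for name, feat in feats.items():
    print(name, tuple(feat.shape))
# res2 (1, 256, 56, 56), res3 (1, 512, 28, 28), res4 (1, 1024, 14, 14), res5 (1, 2048, 7, 7)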
open_vocab_seg/modeling/backbone/swin.py ADDED
@@ -0,0 +1,832 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
10
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import torch.utils.checkpoint as checkpoint
17
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18
+
19
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
20
+
21
+
22
+ class Mlp(nn.Module):
23
+ """Multilayer perceptron."""
24
+
25
+ def __init__(
26
+ self,
27
+ in_features,
28
+ hidden_features=None,
29
+ out_features=None,
30
+ act_layer=nn.GELU,
31
+ drop=0.0,
32
+ ):
33
+ super().__init__()
34
+ out_features = out_features or in_features
35
+ hidden_features = hidden_features or in_features
36
+ self.fc1 = nn.Linear(in_features, hidden_features)
37
+ self.act = act_layer()
38
+ self.fc2 = nn.Linear(hidden_features, out_features)
39
+ self.drop = nn.Dropout(drop)
40
+
41
+ def forward(self, x):
42
+ x = self.fc1(x)
43
+ x = self.act(x)
44
+ x = self.drop(x)
45
+ x = self.fc2(x)
46
+ x = self.drop(x)
47
+ return x
48
+
49
+
50
+ def window_partition(x, window_size):
51
+ """
52
+ Args:
53
+ x: (B, H, W, C)
54
+ window_size (int): window size
55
+ Returns:
56
+ windows: (num_windows*B, window_size, window_size, C)
57
+ """
58
+ B, H, W, C = x.shape
59
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
60
+ windows = (
61
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
62
+ )
63
+ return windows
64
+
65
+
66
+ def window_reverse(windows, window_size, H, W):
67
+ """
68
+ Args:
69
+ windows: (num_windows*B, window_size, window_size, C)
70
+ window_size (int): Window size
71
+ H (int): Height of image
72
+ W (int): Width of image
73
+ Returns:
74
+ x: (B, H, W, C)
75
+ """
76
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
77
+ x = windows.view(
78
+ B, H // window_size, W // window_size, window_size, window_size, -1
79
+ )
80
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
81
+ return x
82
+
83
+
84
+ class WindowAttention(nn.Module):
85
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
86
+ It supports both shifted and non-shifted windows.
87
+ Args:
88
+ dim (int): Number of input channels.
89
+ window_size (tuple[int]): The height and width of the window.
90
+ num_heads (int): Number of attention heads.
91
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
92
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
93
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
94
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ dim,
100
+ window_size,
101
+ num_heads,
102
+ qkv_bias=True,
103
+ qk_scale=None,
104
+ attn_drop=0.0,
105
+ proj_drop=0.0,
106
+ ):
107
+
108
+ super().__init__()
109
+ self.dim = dim
110
+ self.window_size = window_size # Wh, Ww
111
+ self.num_heads = num_heads
112
+ head_dim = dim // num_heads
113
+ self.scale = qk_scale or head_dim ** -0.5
114
+
115
+ # define a parameter table of relative position bias
116
+ self.relative_position_bias_table = nn.Parameter(
117
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
118
+ ) # 2*Wh-1 * 2*Ww-1, nH
119
+
120
+ # get pair-wise relative position index for each token inside the window
121
+ coords_h = torch.arange(self.window_size[0])
122
+ coords_w = torch.arange(self.window_size[1])
123
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
124
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
125
+ relative_coords = (
126
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
127
+ ) # 2, Wh*Ww, Wh*Ww
128
+ relative_coords = relative_coords.permute(
129
+ 1, 2, 0
130
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
131
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
132
+ relative_coords[:, :, 1] += self.window_size[1] - 1
133
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
134
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
135
+ self.register_buffer("relative_position_index", relative_position_index)
136
+
137
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
138
+ self.attn_drop = nn.Dropout(attn_drop)
139
+ self.proj = nn.Linear(dim, dim)
140
+ self.proj_drop = nn.Dropout(proj_drop)
141
+
142
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
143
+ self.softmax = nn.Softmax(dim=-1)
144
+
145
+ def forward(self, x, mask=None):
146
+ """Forward function.
147
+ Args:
148
+ x: input features with shape of (num_windows*B, N, C)
149
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
150
+ """
151
+ B_, N, C = x.shape
152
+ qkv = (
153
+ self.qkv(x)
154
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
155
+ .permute(2, 0, 3, 1, 4)
156
+ )
157
+ q, k, v = (
158
+ qkv[0],
159
+ qkv[1],
160
+ qkv[2],
161
+ ) # make torchscript happy (cannot use tensor as tuple)
162
+
163
+ q = q * self.scale
164
+ attn = q @ k.transpose(-2, -1)
165
+
166
+ relative_position_bias = self.relative_position_bias_table[
167
+ self.relative_position_index.view(-1)
168
+ ].view(
169
+ self.window_size[0] * self.window_size[1],
170
+ self.window_size[0] * self.window_size[1],
171
+ -1,
172
+ ) # Wh*Ww,Wh*Ww,nH
173
+ relative_position_bias = relative_position_bias.permute(
174
+ 2, 0, 1
175
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
176
+ attn = attn + relative_position_bias.unsqueeze(0)
177
+
178
+ if mask is not None:
179
+ nW = mask.shape[0]
180
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
181
+ 1
182
+ ).unsqueeze(0)
183
+ attn = attn.view(-1, self.num_heads, N, N)
184
+ attn = self.softmax(attn)
185
+ else:
186
+ attn = self.softmax(attn)
187
+
188
+ attn = self.attn_drop(attn)
189
+
190
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
191
+ x = self.proj(x)
192
+ x = self.proj_drop(x)
193
+ return x
194
+
195
+
196
+ class SwinTransformerBlock(nn.Module):
197
+ """Swin Transformer Block.
198
+ Args:
199
+ dim (int): Number of input channels.
200
+ num_heads (int): Number of attention heads.
201
+ window_size (int): Window size.
202
+ shift_size (int): Shift size for SW-MSA.
203
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
204
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
205
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
206
+ drop (float, optional): Dropout rate. Default: 0.0
207
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
208
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
209
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
210
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
211
+ """
212
+
213
+ def __init__(
214
+ self,
215
+ dim,
216
+ num_heads,
217
+ window_size=7,
218
+ shift_size=0,
219
+ mlp_ratio=4.0,
220
+ qkv_bias=True,
221
+ qk_scale=None,
222
+ drop=0.0,
223
+ attn_drop=0.0,
224
+ drop_path=0.0,
225
+ act_layer=nn.GELU,
226
+ norm_layer=nn.LayerNorm,
227
+ ):
228
+ super().__init__()
229
+ self.dim = dim
230
+ self.num_heads = num_heads
231
+ self.window_size = window_size
232
+ self.shift_size = shift_size
233
+ self.mlp_ratio = mlp_ratio
234
+ assert (
235
+ 0 <= self.shift_size < self.window_size
236
+ ), "shift_size must be in [0, window_size)"
237
+
238
+ self.norm1 = norm_layer(dim)
239
+ self.attn = WindowAttention(
240
+ dim,
241
+ window_size=to_2tuple(self.window_size),
242
+ num_heads=num_heads,
243
+ qkv_bias=qkv_bias,
244
+ qk_scale=qk_scale,
245
+ attn_drop=attn_drop,
246
+ proj_drop=drop,
247
+ )
248
+
249
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
250
+ self.norm2 = norm_layer(dim)
251
+ mlp_hidden_dim = int(dim * mlp_ratio)
252
+ self.mlp = Mlp(
253
+ in_features=dim,
254
+ hidden_features=mlp_hidden_dim,
255
+ act_layer=act_layer,
256
+ drop=drop,
257
+ )
258
+
259
+ self.H = None
260
+ self.W = None
261
+
262
+ def forward(self, x, mask_matrix):
263
+ """Forward function.
264
+ Args:
265
+ x: Input feature, tensor size (B, H*W, C).
266
+ H, W: Spatial resolution of the input feature.
267
+ mask_matrix: Attention mask for cyclic shift.
268
+ """
269
+ B, L, C = x.shape
270
+ H, W = self.H, self.W
271
+ assert L == H * W, "input feature has wrong size"
272
+
273
+ shortcut = x
274
+ x = self.norm1(x)
275
+ x = x.view(B, H, W, C)
276
+
277
+ # pad feature maps to multiples of window size
278
+ pad_l = pad_t = 0
279
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
280
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
281
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
282
+ _, Hp, Wp, _ = x.shape
283
+
284
+ # cyclic shift
285
+ if self.shift_size > 0:
286
+ shifted_x = torch.roll(
287
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
288
+ )
289
+ attn_mask = mask_matrix
290
+ else:
291
+ shifted_x = x
292
+ attn_mask = None
293
+
294
+ # partition windows
295
+ x_windows = window_partition(
296
+ shifted_x, self.window_size
297
+ ) # nW*B, window_size, window_size, C
298
+ x_windows = x_windows.view(
299
+ -1, self.window_size * self.window_size, C
300
+ ) # nW*B, window_size*window_size, C
301
+
302
+ # W-MSA/SW-MSA
303
+ attn_windows = self.attn(
304
+ x_windows, mask=attn_mask
305
+ ) # nW*B, window_size*window_size, C
306
+
307
+ # merge windows
308
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
309
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
310
+
311
+ # reverse cyclic shift
312
+ if self.shift_size > 0:
313
+ x = torch.roll(
314
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
315
+ )
316
+ else:
317
+ x = shifted_x
318
+
319
+ if pad_r > 0 or pad_b > 0:
320
+ x = x[:, :H, :W, :].contiguous()
321
+
322
+ x = x.view(B, H * W, C)
323
+
324
+ # FFN
325
+ x = shortcut + self.drop_path(x)
326
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
327
+
328
+ return x
329
+
330
+
331
+ class PatchMerging(nn.Module):
332
+ """Patch Merging Layer
333
+ Args:
334
+ dim (int): Number of input channels.
335
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
336
+ """
337
+
338
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
339
+ super().__init__()
340
+ self.dim = dim
341
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
342
+ self.norm = norm_layer(4 * dim)
343
+
344
+ def forward(self, x, H, W):
345
+ """Forward function.
346
+ Args:
347
+ x: Input feature, tensor size (B, H*W, C).
348
+ H, W: Spatial resolution of the input feature.
349
+ """
350
+ B, L, C = x.shape
351
+ assert L == H * W, "input feature has wrong size"
352
+
353
+ x = x.view(B, H, W, C)
354
+
355
+ # padding
356
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
357
+ if pad_input:
358
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
359
+
360
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
361
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
362
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
363
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
364
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
365
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
366
+
367
+ x = self.norm(x)
368
+ x = self.reduction(x)
369
+
370
+ return x
371
+
372
+
373
+ class BasicLayer(nn.Module):
374
+ """A basic Swin Transformer layer for one stage.
375
+ Args:
376
+ dim (int): Number of feature channels
377
+ depth (int): Depths of this stage.
378
+ num_heads (int): Number of attention head.
379
+ window_size (int): Local window size. Default: 7.
380
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
381
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
382
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
383
+ drop (float, optional): Dropout rate. Default: 0.0
384
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
385
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
386
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
387
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
388
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
389
+ """
390
+
391
+ def __init__(
392
+ self,
393
+ dim,
394
+ depth,
395
+ num_heads,
396
+ window_size=7,
397
+ mlp_ratio=4.0,
398
+ qkv_bias=True,
399
+ qk_scale=None,
400
+ drop=0.0,
401
+ attn_drop=0.0,
402
+ drop_path=0.0,
403
+ norm_layer=nn.LayerNorm,
404
+ downsample=None,
405
+ use_checkpoint=False,
406
+ ):
407
+ super().__init__()
408
+ self.window_size = window_size
409
+ self.shift_size = window_size // 2
410
+ self.depth = depth
411
+ self.use_checkpoint = use_checkpoint
412
+
413
+ # build blocks
414
+ self.blocks = nn.ModuleList(
415
+ [
416
+ SwinTransformerBlock(
417
+ dim=dim,
418
+ num_heads=num_heads,
419
+ window_size=window_size,
420
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
421
+ mlp_ratio=mlp_ratio,
422
+ qkv_bias=qkv_bias,
423
+ qk_scale=qk_scale,
424
+ drop=drop,
425
+ attn_drop=attn_drop,
426
+ drop_path=drop_path[i]
427
+ if isinstance(drop_path, list)
428
+ else drop_path,
429
+ norm_layer=norm_layer,
430
+ )
431
+ for i in range(depth)
432
+ ]
433
+ )
434
+
435
+ # patch merging layer
436
+ if downsample is not None:
437
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
438
+ else:
439
+ self.downsample = None
440
+
441
+ def forward(self, x, H, W):
442
+ """Forward function.
443
+ Args:
444
+ x: Input feature, tensor size (B, H*W, C).
445
+ H, W: Spatial resolution of the input feature.
446
+ """
447
+
448
+ # calculate attention mask for SW-MSA
449
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
450
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
451
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
452
+ h_slices = (
453
+ slice(0, -self.window_size),
454
+ slice(-self.window_size, -self.shift_size),
455
+ slice(-self.shift_size, None),
456
+ )
457
+ w_slices = (
458
+ slice(0, -self.window_size),
459
+ slice(-self.window_size, -self.shift_size),
460
+ slice(-self.shift_size, None),
461
+ )
462
+ cnt = 0
463
+ for h in h_slices:
464
+ for w in w_slices:
465
+ img_mask[:, h, w, :] = cnt
466
+ cnt += 1
467
+
468
+ mask_windows = window_partition(
469
+ img_mask, self.window_size
470
+ ) # nW, window_size, window_size, 1
471
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
472
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
473
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
474
+ attn_mask == 0, float(0.0)
475
+ )
476
+
477
+ for blk in self.blocks:
478
+ blk.H, blk.W = H, W
479
+ if self.use_checkpoint:
480
+ x = checkpoint.checkpoint(blk, x, attn_mask)
481
+ else:
482
+ x = blk(x, attn_mask)
483
+ if self.downsample is not None:
484
+ x_down = self.downsample(x, H, W)
485
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
486
+ return x, H, W, x_down, Wh, Ww
487
+ else:
488
+ return x, H, W, x, H, W
489
+
490
+
491
+ class PatchEmbed(nn.Module):
492
+ """Image to Patch Embedding
493
+ Args:
494
+ patch_size (int): Patch token size. Default: 4.
495
+ in_chans (int): Number of input image channels. Default: 3.
496
+ embed_dim (int): Number of linear projection output channels. Default: 96.
497
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
498
+ """
499
+
500
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
501
+ super().__init__()
502
+ patch_size = to_2tuple(patch_size)
503
+ self.patch_size = patch_size
504
+
505
+ self.in_chans = in_chans
506
+ self.embed_dim = embed_dim
507
+
508
+ self.proj = nn.Conv2d(
509
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
510
+ )
511
+ if norm_layer is not None:
512
+ self.norm = norm_layer(embed_dim)
513
+ else:
514
+ self.norm = None
515
+
516
+ def forward(self, x):
517
+ """Forward function."""
518
+ # padding
519
+ _, _, H, W = x.size()
520
+ if W % self.patch_size[1] != 0:
521
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
522
+ if H % self.patch_size[0] != 0:
523
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
524
+
525
+ x = self.proj(x) # B C Wh Ww
526
+ if self.norm is not None:
527
+ Wh, Ww = x.size(2), x.size(3)
528
+ x = x.flatten(2).transpose(1, 2)
529
+ x = self.norm(x)
530
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
531
+
532
+ return x
533
+
534
+
535
+ class SwinTransformer(nn.Module):
536
+ """Swin Transformer backbone.
537
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
538
+ https://arxiv.org/pdf/2103.14030
539
+ Args:
540
+ pretrain_img_size (int): Input image size for training the pretrained model,
541
+ used in absolute position embedding. Default 224.
542
+ patch_size (int | tuple(int)): Patch size. Default: 4.
543
+ in_chans (int): Number of input image channels. Default: 3.
544
+ embed_dim (int): Number of linear projection output channels. Default: 96.
545
+ depths (tuple[int]): Depths of each Swin Transformer stage.
546
+ num_heads (tuple[int]): Number of attention head of each stage.
547
+ window_size (int): Window size. Default: 7.
548
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
549
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
550
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
551
+ drop_rate (float): Dropout rate.
552
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
553
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
554
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
555
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
556
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
557
+ out_indices (Sequence[int]): Output from which stages.
558
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
559
+ -1 means not freezing any parameters.
560
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
561
+ """
562
+
563
+ def __init__(
564
+ self,
565
+ pretrain_img_size=224,
566
+ patch_size=4,
567
+ in_chans=3,
568
+ embed_dim=96,
569
+ depths=[2, 2, 6, 2],
570
+ num_heads=[3, 6, 12, 24],
571
+ window_size=7,
572
+ mlp_ratio=4.0,
573
+ qkv_bias=True,
574
+ qk_scale=None,
575
+ drop_rate=0.0,
576
+ attn_drop_rate=0.0,
577
+ drop_path_rate=0.2,
578
+ norm_layer=nn.LayerNorm,
579
+ ape=False,
580
+ patch_norm=True,
581
+ out_indices=(0, 1, 2, 3),
582
+ norm_indices=None,
583
+ frozen_stages=-1,
584
+ use_checkpoint=False,
585
+ projection=False,
586
+ project_dim=256,
587
+ ):
588
+ super().__init__()
589
+
590
+ self.pretrain_img_size = pretrain_img_size
591
+ self.num_layers = len(depths)
592
+ self.embed_dim = embed_dim
593
+ self.ape = ape
594
+ self.patch_norm = patch_norm
595
+ self.out_indices = out_indices
596
+ self.norm_indices = norm_indices if norm_indices is not None else out_indices
597
+ self.frozen_stages = frozen_stages
598
+
599
+ # split image into non-overlapping patches
600
+ self.patch_embed = PatchEmbed(
601
+ patch_size=patch_size,
602
+ in_chans=in_chans,
603
+ embed_dim=embed_dim,
604
+ norm_layer=norm_layer if self.patch_norm else None,
605
+ )
606
+
607
+ # absolute position embedding
608
+ if self.ape:
609
+ pretrain_img_size = to_2tuple(pretrain_img_size)
610
+ patch_size = to_2tuple(patch_size)
611
+ patches_resolution = [
612
+ pretrain_img_size[0] // patch_size[0],
613
+ pretrain_img_size[1] // patch_size[1],
614
+ ]
615
+
616
+ self.absolute_pos_embed = nn.Parameter(
617
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
618
+ )
619
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
620
+
621
+ self.pos_drop = nn.Dropout(p=drop_rate)
622
+
623
+ # stochastic depth
624
+ dpr = [
625
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
626
+ ] # stochastic depth decay rule
627
+
628
+ # build layers
629
+ self.layers = nn.ModuleList()
630
+ for i_layer in range(self.num_layers):
631
+ layer = BasicLayer(
632
+ dim=int(embed_dim * 2 ** i_layer),
633
+ depth=depths[i_layer],
634
+ num_heads=num_heads[i_layer],
635
+ window_size=window_size,
636
+ mlp_ratio=mlp_ratio,
637
+ qkv_bias=qkv_bias,
638
+ qk_scale=qk_scale,
639
+ drop=drop_rate,
640
+ attn_drop=attn_drop_rate,
641
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
642
+ norm_layer=norm_layer,
643
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
644
+ use_checkpoint=use_checkpoint,
645
+ )
646
+ self.layers.append(layer)
647
+
648
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
649
+ self.num_features = num_features
650
+
651
+ # add a norm layer for each output
652
+ for i_layer in self.norm_indices:
653
+ if i_layer >= len(self.num_features):
654
+ continue
655
+ layer = norm_layer(num_features[i_layer])
656
+ layer_name = f"norm{i_layer}"
657
+ self.add_module(layer_name, layer)
658
+ # add projector head
659
+ self.projection = projection
660
+ if projection:
661
+ self.project_dim = project_dim
662
+ self.norm = norm_layer(self.num_features[-1])
663
+ self.projector = nn.Linear(self.num_features[-1], project_dim, bias=False)
664
+ self._freeze_stages()
665
+
666
+ def _freeze_stages(self):
667
+ if self.frozen_stages >= 0:
668
+ self.patch_embed.eval()
669
+ for param in self.patch_embed.parameters():
670
+ param.requires_grad = False
671
+
672
+ if self.frozen_stages >= 1 and self.ape:
673
+ self.absolute_pos_embed.requires_grad = False
674
+
675
+ if self.frozen_stages >= 2:
676
+ self.pos_drop.eval()
677
+ for i in range(0, self.frozen_stages - 1):
678
+ m = self.layers[i]
679
+ m.eval()
680
+ for param in m.parameters():
681
+ param.requires_grad = False
682
+
683
+ def init_weights(self, pretrained=None):
684
+ """Initialize the weights in backbone.
685
+ Args:
686
+ pretrained (str, optional): Path to pre-trained weights.
687
+ Defaults to None.
688
+ """
689
+
690
+ def _init_weights(m):
691
+ if isinstance(m, nn.Linear):
692
+ trunc_normal_(m.weight, std=0.02)
693
+ if isinstance(m, nn.Linear) and m.bias is not None:
694
+ nn.init.constant_(m.bias, 0)
695
+ elif isinstance(m, nn.LayerNorm):
696
+ nn.init.constant_(m.bias, 0)
697
+ nn.init.constant_(m.weight, 1.0)
698
+
699
+ def forward(self, x):
700
+ """Forward function."""
701
+ x = self.patch_embed(x)
702
+
703
+ Wh, Ww = x.size(2), x.size(3)
704
+ if self.ape:
705
+ # interpolate the position embedding to the corresponding size
706
+ absolute_pos_embed = F.interpolate(
707
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
708
+ )
709
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
710
+ else:
711
+ x = x.flatten(2).transpose(1, 2)
712
+ x = self.pos_drop(x)
713
+
714
+ outs = {}
715
+ for i in range(self.num_layers):
716
+ layer = self.layers[i]
717
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
718
+
719
+ if i in self.out_indices:
720
+ if i in self.norm_indices:
721
+ norm_layer = getattr(self, f"norm{i}")
722
+ x_out = norm_layer(x_out)
723
+ out = (
724
+ x_out.view(-1, H, W, self.num_features[i])
725
+ .permute(0, 3, 1, 2)
726
+ .contiguous()
727
+ )
728
+ outs["res{}".format(i + 2)] = out
729
+ if self.projection:
730
+ x_out = self.norm(x_out)
731
+ x_out = x_out.view(-1, H, W, self.num_features[-1]).contiguous()
732
+ outs["fc"] = self.projector(x_out).permute(0, 3, 1, 2)
733
+
734
+ return outs
735
+
736
+ def train(self, mode=True):
737
+ """Convert the model into training mode while keeping frozen stages frozen."""
738
+ super(SwinTransformer, self).train(mode)
739
+ self._freeze_stages()
740
+
741
+
742
+ @BACKBONE_REGISTRY.register()
743
+ class D2SwinTransformer(SwinTransformer, Backbone):
744
+ def __init__(self, cfg, input_shape):
745
+
746
+ pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
747
+ patch_size = cfg.MODEL.SWIN.PATCH_SIZE
748
+ in_chans = 3
749
+ embed_dim = cfg.MODEL.SWIN.EMBED_DIM
750
+ depths = cfg.MODEL.SWIN.DEPTHS
751
+ num_heads = cfg.MODEL.SWIN.NUM_HEADS
752
+ window_size = cfg.MODEL.SWIN.WINDOW_SIZE
753
+ mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
754
+ qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
755
+ qk_scale = cfg.MODEL.SWIN.QK_SCALE
756
+ drop_rate = cfg.MODEL.SWIN.DROP_RATE
757
+ attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
758
+ drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
759
+ norm_layer = nn.LayerNorm
760
+ ape = cfg.MODEL.SWIN.APE
761
+ patch_norm = cfg.MODEL.SWIN.PATCH_NORM
762
+ norm_indices = cfg.MODEL.SWIN.NORM_INDICES
763
+ projection = cfg.MODEL.SWIN.PROJECTION
764
+ project_dim = cfg.MODEL.SWIN.PROJECT_DIM
765
+ super().__init__(
766
+ pretrain_img_size,
767
+ patch_size,
768
+ in_chans,
769
+ embed_dim,
770
+ depths,
771
+ num_heads,
772
+ window_size,
773
+ mlp_ratio,
774
+ qkv_bias,
775
+ qk_scale,
776
+ drop_rate,
777
+ attn_drop_rate,
778
+ drop_path_rate,
779
+ norm_layer,
780
+ ape,
781
+ patch_norm,
782
+ norm_indices=norm_indices,
783
+ projection=projection,
784
+ project_dim=project_dim,
785
+ )
786
+
787
+ self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
788
+
789
+ self._out_feature_strides = {
790
+ "res2": 4,
791
+ "res3": 8,
792
+ "res4": 16,
793
+ "res5": 32,
794
+ "fc": 32,
795
+ }
796
+ self._out_feature_channels = {
797
+ "res2": self.num_features[0],
798
+ "res3": self.num_features[1],
799
+ "res4": self.num_features[2],
800
+ "res5": self.num_features[3],
801
+ "fc": self.num_features[3],
802
+ }
803
+
804
+ def forward(self, x):
805
+ """
806
+ Args:
807
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
808
+ Returns:
809
+ dict[str->Tensor]: names and the corresponding features
810
+ """
811
+ assert (
812
+ x.dim() == 4
813
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
814
+ outputs = {}
815
+ y = super().forward(x)
816
+ for k in y.keys():
817
+ if k in self._out_features:
818
+ outputs[k] = y[k]
819
+ return outputs
820
+
821
+ def output_shape(self):
822
+ return {
823
+ name: ShapeSpec(
824
+ channels=self._out_feature_channels[name],
825
+ stride=self._out_feature_strides[name],
826
+ )
827
+ for name in self._out_features
828
+ }
829
+
830
+ @property
831
+ def size_divisibility(self):
832
+ return 32
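A minimal, self-contained sketch of the contract encoded above: each `res{i}` feature has stride `2**i`, the optional `fc` output shares the stride of `res5`, and `size_divisibility = 32` means detectron2 pads inputs to a multiple of 32 before the forward pass. The helper below is not part of the repo and assumes only the stride table and divisibility shown above.

```python
import math

def expected_feature_sizes(h: int, w: int, size_divisibility: int = 32):
    """Spatial size of each backbone output for an (h, w) input, after padding."""
    ph = math.ceil(h / size_divisibility) * size_divisibility
    pw = math.ceil(w / size_divisibility) * size_divisibility
    strides = {"res2": 4, "res3": 8, "res4": 16, "res5": 32, "fc": 32}
    return {name: (ph // s, pw // s) for name, s in strides.items()}

print(expected_feature_sizes(480, 640))
# {'res2': (120, 160), 'res3': (60, 80), 'res4': (30, 40), 'res5': (15, 20), 'fc': (15, 20)}
```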
open_vocab_seg/modeling/clip_adapter/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ from .text_template import (
5
+ PredefinedPromptExtractor,
6
+ ImageNetPromptExtractor,
7
+ VILDPromptExtractor,
8
+ )
9
+ from .adapter import ClipAdapter, MaskFormerClipAdapter
10
+
11
+
12
+ def build_text_prompt(cfg):
13
+ if cfg.TEXT_TEMPLATES == "predefined":
14
+ text_templates = PredefinedPromptExtractor(cfg.PREDEFINED_PROMPT_TEMPLATES)
15
+ elif cfg.TEXT_TEMPLATES == "imagenet":
16
+ text_templates = ImageNetPromptExtractor()
17
+ elif cfg.TEXT_TEMPLATES == "vild":
18
+ text_templates = VILDPromptExtractor()
19
+ else:
20
+ raise NotImplementedError(
21
+ "Prompt learner {} is not supported".format(cfg.TEXT_TEMPLATES)
22
+ )
23
+ return text_templates
24
+
25
+ from .clip import tokenize
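`build_text_prompt` only reads attribute-style config fields, so a small namespace is enough to exercise the dispatch above. A hedged usage sketch (it assumes the `open_vocab_seg` package and its detectron2/CLIP dependencies are importable; the field names are exactly the ones read in the function):

```python
from types import SimpleNamespace
from open_vocab_seg.modeling.clip_adapter import build_text_prompt

cfg = SimpleNamespace(
    TEXT_TEMPLATES="vild",  # one of "predefined", "imagenet", "vild"
    PREDEFINED_PROMPT_TEMPLATES=["a photo of a {}."],  # only read when TEXT_TEMPLATES == "predefined"
)
prompt_extractor = build_text_prompt(cfg)  # -> VILDPromptExtractor instance
```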
open_vocab_seg/modeling/clip_adapter/adapter.py ADDED
@@ -0,0 +1,206 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+ # Modified by Feng Liang from
4
+ # https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/adapter.py
5
+
6
+ from typing import List
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from detectron2.structures import BitMasks
11
+ from .utils import build_clip_model, crop_with_mask
12
+ from .text_template import PromptExtractor
13
+
14
+
15
+ PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
16
+ PIXEL_STD = (0.26862954, 0.26130258, 0.27577711)
17
+
18
+
19
+ class ClipAdapter(nn.Module):
20
+ def __init__(self, clip_model_name: str, mask_prompt_depth: int, text_templates: PromptExtractor):
21
+ super().__init__()
22
+ self.clip_model = build_clip_model(clip_model_name, mask_prompt_depth)
23
+ self.text_templates = text_templates
24
+ self.text_templates.init_buffer(self.clip_model)
25
+ self.text_feature_buffer = {}
26
+
27
+ def forward(self, image: torch.Tensor, text: List[str], **kwargs):
28
+ image = self._preprocess_image(image, **kwargs)
29
+ text_feature = self.get_text_features(text) # k,feat_dim
30
+ image_features = self.get_image_features(image)
31
+ return self.get_sim_logits(text_feature, image_features)
32
+
33
+ def _preprocess_image(self, image: torch.Tensor):
34
+ return image
35
+
36
+ def _get_text_features(self, noun_list: List[str]):
37
+ left_noun_list = [
38
+ noun for noun in noun_list if noun not in self.text_feature_buffer
39
+ ]
40
+ if len(left_noun_list) > 0:
41
+ left_text_features = self.text_templates(
42
+ left_noun_list, self.clip_model
43
+ )
44
+ self.text_feature_buffer.update(
45
+ {
46
+ noun: text_feature
47
+ for noun, text_feature in zip(
48
+ left_noun_list, left_text_features
49
+ )
50
+ }
51
+ )
52
+ return torch.stack([self.text_feature_buffer[noun] for noun in noun_list])
53
+
54
+
55
+ def get_text_features(self, noun_list: List[str]):
56
+ return self._get_text_features(noun_list)
57
+
58
+ def get_image_features(self, image: torch.Tensor):
59
+ image_features = self.clip_model.visual(image)
60
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
61
+ return image_features
62
+
63
+ def get_sim_logits(
64
+ self,
65
+ text_features: torch.Tensor,
66
+ image_features: torch.Tensor,
67
+ temperature: float = 100,
68
+ ):
69
+ return temperature * image_features @ text_features.T
70
+
71
+ def normalize_feature(self, feat: torch.Tensor):
72
+ return feat / feat.norm(dim=-1, keepdim=True)
73
+
74
+
75
+ class MaskFormerClipAdapter(ClipAdapter):
76
+ def __init__(
77
+ self,
78
+ clip_model_name: str,
79
+ text_templates: PromptExtractor,
80
+ mask_fill: str = "mean",
81
+ mask_expand_ratio: float = 1.0,
82
+ mask_thr: float = 0.5,
83
+ mask_matting: bool = False,
84
+ region_resized: bool = True,
85
+ mask_prompt_depth: int = 0,
86
+ mask_prompt_fwd: bool = False,
87
+ ):
88
+ super().__init__(clip_model_name, mask_prompt_depth, text_templates)
89
+ self.non_object_embedding = nn.Parameter(
90
+ torch.empty(1, self.clip_model.text_projection.shape[-1])
91
+ )
92
+ nn.init.normal_(
93
+ self.non_object_embedding.data,
94
+ std=self.clip_model.transformer.width ** -0.5,
95
+ )
96
+ # for test
97
+ self.mask_fill = mask_fill
98
+ if self.mask_fill == "zero":
99
+ self.mask_fill = (0.0, 0.0, 0.0)
100
+ elif self.mask_fill == "mean":
101
+ self.mask_fill = [255.0 * c for c in PIXEL_MEAN]
102
+ else:
103
+ raise NotImplementedError(
104
+ "Unknown mask_fill method: {}".format(self.mask_fill)
105
+ )
106
+ self.mask_expand_ratio = mask_expand_ratio
107
+ self.mask_thr = mask_thr
108
+ self.mask_matting = mask_matting
109
+ self.region_resized = region_resized
110
+ self.mask_prompt_fwd = mask_prompt_fwd
111
+ self.register_buffer(
112
+ "pixel_mean", torch.Tensor(PIXEL_MEAN).reshape(1, 3, 1, 1) * 255.0
113
+ )
114
+ self.register_buffer(
115
+ "pixel_std", torch.Tensor(PIXEL_STD).reshape(1, 3, 1, 1) * 255.0
116
+ )
117
+
118
+ def forward(
119
+ self,
120
+ image: torch.Tensor,
121
+ text: List[str],
122
+ mask: torch.Tensor,
123
+ normalize: bool = True,
124
+ fwd_w_region_mask: bool = False,
125
+ ):
126
+ (regions, unnorm_regions), region_masks, valid_flag = self._preprocess_image(image, mask, normalize=normalize)
127
+ if regions is None:
128
+ return None, valid_flag
129
+ if isinstance(regions, list):
130
+ assert NotImplementedError
131
+ image_features = torch.cat(
132
+ [self.get_image_features(image_i) for image_i in regions], dim=0
133
+ )
134
+ else:
135
+ if self.mask_prompt_fwd:
136
+ image_features = self.get_image_features(regions, region_masks)
137
+ else:
138
+ image_features = self.get_image_features(regions)
139
+ text_feature = self.get_text_features(text) # k,feat_dim
140
+ return self.get_sim_logits(text_feature, image_features), unnorm_regions, valid_flag
141
+
142
+ def get_image_features(self, image, region_masks=None):
143
+ image_features = self.clip_model.visual(image, region_masks)
144
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
145
+ return image_features
146
+
147
+ def _preprocess_image(
148
+ self, image: torch.Tensor, mask: torch.Tensor, normalize: bool = True
149
+ ):
150
+ """crop, mask and normalize the image
151
+
152
+ Args:
153
+ image ([type]): [C,H,W]
154
+ mask ([type]): [K,H,W
155
+ normalize (bool, optional): [description]. Defaults to True.
156
+ """
157
+ dtype = mask.dtype
158
+ bin_mask = mask > self.mask_thr
159
+ valid = bin_mask.sum(dim=(-1, -2)) > 0
160
+ bin_mask = bin_mask[valid]
161
+ mask = mask[valid]
162
+ if not self.mask_matting:
163
+ mask = bin_mask
164
+ bin_mask = BitMasks(bin_mask)
165
+ bboxes = bin_mask.get_bounding_boxes()
166
+ # crop,mask
167
+ regions = []
168
+ region_masks = []
169
+ for bbox, single_mask in zip(bboxes, mask):
170
+ region, region_mask = crop_with_mask(
171
+ image.type(dtype),
172
+ single_mask.type(dtype),
173
+ bbox,
174
+ fill=self.mask_fill,
175
+ expand_ratio=self.mask_expand_ratio,
176
+ )
177
+ regions.append(region.unsqueeze(0))
178
+ region_masks.append(region_mask.unsqueeze(0))
179
+ if len(regions) == 0:
180
+ return None, valid
181
+ unnorm_regions = regions
182
+ if normalize:
183
+ regions = [(r - self.pixel_mean) / self.pixel_std for r in regions]
184
+ # resize
185
+ if self.region_resized:
186
+ regions = [
187
+ F.interpolate(r, size=(224, 224), mode="bicubic") for r in regions
188
+ ]
189
+ regions = torch.cat(regions)
190
+ region_masks = [
191
+ F.interpolate(r, size=(224, 224), mode="nearest") for r in region_masks
192
+ ]
193
+ region_masks = torch.cat(region_masks)
194
+ unnorm_regions = [
195
+ F.interpolate(r, size=(224, 224), mode="bicubic") for r in unnorm_regions
196
+ ]
197
+ unnorm_regions = torch.cat(unnorm_regions)
198
+ return (regions, unnorm_regions), region_masks, valid
199
+
200
+ def get_text_features(self, noun_list: List[str]):
201
+ object_text_features = self._get_text_features(noun_list)
202
+ non_object_text_features = (
203
+ self.non_object_embedding
204
+ / self.non_object_embedding.norm(dim=-1, keepdim=True)
205
+ )
206
+ return torch.cat([object_text_features, non_object_text_features], dim=0)
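The scoring in `get_sim_logits` is a temperature-scaled cosine similarity, since both operands are L2-normalized first. A self-contained sketch with random stand-in features (the temperature of 100 comes from the code above; the shapes are illustrative):

```python
import torch

torch.manual_seed(0)
image_features = torch.randn(5, 512)  # 5 mask regions (stand-ins for CLIP region embeddings)
text_features = torch.randn(8, 512)   # 7 class prompts + 1 "no object" row

# L2-normalize, as get_image_features / get_text_features do
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

logits = 100.0 * image_features @ text_features.T  # [5, 8], cosine similarity * temperature
probs = logits.softmax(dim=-1)                     # per-region class distribution
print(probs.shape)                                 # torch.Size([5, 8])
```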
open_vocab_seg/modeling/clip_adapter/clip/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .clip import *
open_vocab_seg/modeling/clip_adapter/clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
open_vocab_seg/modeling/clip_adapter/clip/clip.py ADDED
@@ -0,0 +1,285 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from collections import OrderedDict
6
+ from typing import Union, List
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model import build_model
14
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ try:
17
+ from torchvision.transforms import InterpolationMode
18
+
19
+ BICUBIC = InterpolationMode.BICUBIC
20
+ except ImportError:
21
+ BICUBIC = Image.BICUBIC
22
+
23
+
24
+ if torch.__version__.split(".") < ["1", "7", "1"]:
25
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
26
+
27
+
28
+ __all__ = ["available_models", "load", "tokenize"]
29
+ _tokenizer = _Tokenizer()
30
+
31
+ _MODELS = {
32
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
33
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
34
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
35
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40
+ }
41
+
42
+
43
+ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
44
+ os.makedirs(root, exist_ok=True)
45
+ filename = os.path.basename(url)
46
+
47
+ expected_sha256 = url.split("/")[-2]
48
+ download_target = os.path.join(root, filename)
49
+
50
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
51
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
52
+
53
+ if os.path.isfile(download_target):
54
+ if (
55
+ hashlib.sha256(open(download_target, "rb").read()).hexdigest()
56
+ == expected_sha256
57
+ ):
58
+ return download_target
59
+ else:
60
+ warnings.warn(
61
+ f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
62
+ )
63
+
64
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
65
+ with tqdm(
66
+ total=int(source.info().get("Content-Length")),
67
+ ncols=80,
68
+ unit="iB",
69
+ unit_scale=True,
70
+ ) as loop:
71
+ while True:
72
+ buffer = source.read(8192)
73
+ if not buffer:
74
+ break
75
+
76
+ output.write(buffer)
77
+ loop.update(len(buffer))
78
+
79
+ if (
80
+ hashlib.sha256(open(download_target, "rb").read()).hexdigest()
81
+ != expected_sha256
82
+ ):
83
+ raise RuntimeError(
84
+ f"Model has been downloaded but the SHA256 checksum does not not match"
85
+ )
86
+
87
+ return download_target
88
+
89
+
90
+ def _transform(n_px):
91
+ return Compose(
92
+ [
93
+ Resize(n_px, interpolation=BICUBIC),
94
+ CenterCrop(n_px),
95
+ lambda image: image.convert("RGB"),
96
+ ToTensor(),
97
+ Normalize(
98
+ (0.48145466, 0.4578275, 0.40821073),
99
+ (0.26862954, 0.26130258, 0.27577711),
100
+ ),
101
+ ]
102
+ )
103
+
104
+
105
+ def available_models() -> List[str]:
106
+ """Returns the names of available CLIP models"""
107
+ return list(_MODELS.keys())
108
+
109
+
110
+ def load(
111
+ name: str,
112
+ mask_prompt_depth: int = 0,
113
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
114
+ jit=False,
115
+ ):
116
+ """Load a CLIP model
117
+
118
+ Parameters
119
+ ----------
120
+ name : str
121
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
122
+
123
+ device : Union[str, torch.device]
124
+ The device to put the loaded model
125
+
126
+ jit : bool
127
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
128
+
129
+ Returns
130
+ -------
131
+ model : torch.nn.Module
132
+ The CLIP model
133
+
134
+ preprocess : Callable[[PIL.Image], torch.Tensor]
135
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
136
+ """
137
+ if name in _MODELS:
138
+ model_path = _download(_MODELS[name])
139
+ elif os.path.isfile(name):
140
+ model_path = name
141
+ else:
142
+ raise RuntimeError(
143
+ f"Model {name} not found; available models = {available_models()}"
144
+ )
145
+
146
+ try:
147
+ # loading JIT archive
148
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
149
+ state_dict = None
150
+ except RuntimeError:
151
+ # loading saved state dict
152
+ if jit:
153
+ warnings.warn(
154
+ f"File {model_path} is not a JIT archive. Loading as a state dict instead"
155
+ )
156
+ jit = False
157
+ state_dict = torch.load(model_path, map_location="cpu")
158
+ if 'state_dict' in state_dict:
159
+ new_state_dict = OrderedDict()
160
+ for k, v in state_dict['state_dict'].items():
161
+ if k.startswith('module.'):
162
+ name = k[7:] # remove `module.`
163
+ new_state_dict[name] = v
164
+ state_dict = new_state_dict
165
+
166
+ if not jit:
167
+ model = build_model(state_dict or model.state_dict(), mask_prompt_depth).to(device)
168
+ if str(device) == "cpu":
169
+ model.float()
170
+ return model, _transform(model.visual.input_resolution)
171
+
172
+ # patch the device names
173
+ device_holder = torch.jit.trace(
174
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
175
+ )
176
+ device_node = [
177
+ n
178
+ for n in device_holder.graph.findAllNodes("prim::Constant")
179
+ if "Device" in repr(n)
180
+ ][-1]
181
+
182
+ def patch_device(module):
183
+ try:
184
+ graphs = [module.graph] if hasattr(module, "graph") else []
185
+ except RuntimeError:
186
+ graphs = []
187
+
188
+ if hasattr(module, "forward1"):
189
+ graphs.append(module.forward1.graph)
190
+
191
+ for graph in graphs:
192
+ for node in graph.findAllNodes("prim::Constant"):
193
+ if "value" in node.attributeNames() and str(node["value"]).startswith(
194
+ "cuda"
195
+ ):
196
+ node.copyAttributes(device_node)
197
+
198
+ model.apply(patch_device)
199
+ patch_device(model.encode_image)
200
+ patch_device(model.encode_text)
201
+
202
+ # patch dtype to float32 on CPU
203
+ if str(device) == "cpu":
204
+ float_holder = torch.jit.trace(
205
+ lambda: torch.ones([]).float(), example_inputs=[]
206
+ )
207
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
208
+ float_node = float_input.node()
209
+
210
+ def patch_float(module):
211
+ try:
212
+ graphs = [module.graph] if hasattr(module, "graph") else []
213
+ except RuntimeError:
214
+ graphs = []
215
+
216
+ if hasattr(module, "forward1"):
217
+ graphs.append(module.forward1.graph)
218
+
219
+ for graph in graphs:
220
+ for node in graph.findAllNodes("aten::to"):
221
+ inputs = list(node.inputs())
222
+ for i in [
223
+ 1,
224
+ 2,
225
+ ]: # dtype can be the second or third argument to aten::to()
226
+ if inputs[i].node()["value"] == 5:
227
+ inputs[i].node().copyAttributes(float_node)
228
+
229
+ model.apply(patch_float)
230
+ patch_float(model.encode_image)
231
+ patch_float(model.encode_text)
232
+
233
+ model.float()
234
+
235
+ return model, _transform(model.input_resolution.item())
236
+
237
+
238
+ def tokenize(
239
+ texts: Union[str, List[str]],
240
+ context_length: int = 77,
241
+ truncate: bool = False,
242
+ return_length: bool = False,
243
+ ) -> torch.LongTensor:
244
+ """
245
+ Returns the tokenized representation of given input string(s)
246
+
247
+ Parameters
248
+ ----------
249
+ texts : Union[str, List[str]]
250
+ An input string or a list of input strings to tokenize
251
+
252
+ context_length : int
253
+ The context length to use; all CLIP models use 77 as the context length
254
+
255
+ truncate: bool
256
+ Whether to truncate the text in case its encoding is longer than the context length
257
+
258
+ Returns
259
+ -------
260
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
261
+ """
262
+ if isinstance(texts, str):
263
+ texts = [texts]
264
+
265
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
266
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
267
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
268
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
269
+ length = []
270
+ for i, tokens in enumerate(all_tokens):
271
+ if len(tokens) > context_length:
272
+ if truncate:
273
+ tokens = tokens[:context_length]
274
+ tokens[-1] = eot_token
275
+ length.append(context_length)
276
+ else:
277
+ raise RuntimeError(
278
+ f"Input {texts[i]} is too long for context length {context_length}"
279
+ )
280
+ else:
281
+ length.append(len(tokens))
282
+ result[i, : len(tokens)] = torch.tensor(tokens)
283
+ if return_length:
284
+ return result, length
285
+ return result
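A short usage sketch of `tokenize` (it assumes torch, ftfy, regex and the bundled BPE vocabulary are available): each prompt is wrapped in `<|startoftext|> ... <|endoftext|>` and right-padded with zeros to `context_length`.

```python
from open_vocab_seg.modeling.clip_adapter.clip import tokenize

tokens = tokenize(["a photo of a cat.", "a photo of a dog."])
print(tokens.shape, tokens.dtype)   # torch.Size([2, 77]) torch.int64

# return_length=True also reports how many positions are actually used per prompt
tokens, lengths = tokenize("a photo of a cat.", return_length=True)
print(lengths)  # e.g. [8] (SOT + BPE tokens + EOT; the exact count depends on the BPE merges)
```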
open_vocab_seg/modeling/clip_adapter/clip/model.py ADDED
@@ -0,0 +1,613 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+ # Modified by Feng Liang from https://github.com/openai/CLIP/blob/main/clip/model.py
4
+
5
+ from collections import OrderedDict
6
+ from typing import Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
+
14
+ class Bottleneck(nn.Module):
15
+ expansion = 4
16
+
17
+ def __init__(self, inplanes, planes, stride=1):
18
+ super().__init__()
19
+
20
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
21
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
22
+ self.bn1 = nn.BatchNorm2d(planes)
23
+
24
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
25
+ self.bn2 = nn.BatchNorm2d(planes)
26
+
27
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
28
+
29
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
30
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
31
+
32
+ self.relu = nn.ReLU(inplace=True)
33
+ self.downsample = None
34
+ self.stride = stride
35
+
36
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
37
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
38
+ self.downsample = nn.Sequential(
39
+ OrderedDict(
40
+ [
41
+ ("-1", nn.AvgPool2d(stride)),
42
+ (
43
+ "0",
44
+ nn.Conv2d(
45
+ inplanes,
46
+ planes * self.expansion,
47
+ 1,
48
+ stride=1,
49
+ bias=False,
50
+ ),
51
+ ),
52
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
53
+ ]
54
+ )
55
+ )
56
+
57
+ def forward(self, x: torch.Tensor):
58
+ identity = x
59
+
60
+ out = self.relu(self.bn1(self.conv1(x)))
61
+ out = self.relu(self.bn2(self.conv2(out)))
62
+ out = self.avgpool(out)
63
+ out = self.bn3(self.conv3(out))
64
+
65
+ if self.downsample is not None:
66
+ identity = self.downsample(x)
67
+
68
+ out += identity
69
+ out = self.relu(out)
70
+ return out
71
+
72
+
73
+ class AttentionPool2d(nn.Module):
74
+ def __init__(
75
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
76
+ ):
77
+ super().__init__()
78
+ self.positional_embedding = nn.Parameter(
79
+ torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5
80
+ )
81
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
82
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
83
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
84
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
85
+ self.num_heads = num_heads
86
+ self.grid_size = spacial_dim
87
+
88
+ def forward(self, x, mask=None, return_cls=True):
89
+ b, c, gh, gw = x.shape
90
+ # mask out irrelevant features
91
+ if mask is not None:
92
+ mask = F.interpolate(mask[:, None, ...], size=(gh, gw)).squeeze(
93
+ 1
94
+ ) # [N,H,W] -> [N,grid,grid]
95
+ mask = (mask > 0.5).reshape(mask.shape[0], -1)
96
+ mask = torch.cat([mask, mask.new_ones(mask.shape[0], 1)], dim=1)
97
+ if x.size()[0] == 1:
98
+ x = x.expand(mask.shape[0], c, gh, gw)
99
+
100
+ x = x.reshape(x.shape[0], c, gh * gw).permute(2, 0, 1) # NCHW -> (HW)NC
101
+
102
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
103
+ positional_embedding = self.positional_embedding
104
+ if not (self.positional_embedding.shape[0] == x.shape[0]):
105
+ cls_pos = positional_embedding[0:1, :]
106
+ per_pos_embedding = (
107
+ F.interpolate(
108
+ positional_embedding[1:, :]
109
+ .permute(1, 0)
110
+ .view(1, -1, self.grid_size, self.grid_size),
111
+ size=(gh, gw),
112
+ mode="bicubic",
113
+ )
114
+ .reshape(-1, gh * gw)
115
+ .permute(1, 0)
116
+ )
117
+ positional_embedding = torch.cat([cls_pos, per_pos_embedding])
118
+
119
+ x = x + positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
120
+ x, _ = F.multi_head_attention_forward(
121
+ query=x,
122
+ key=x,
123
+ value=x,
124
+ embed_dim_to_check=x.shape[-1],
125
+ num_heads=self.num_heads,
126
+ q_proj_weight=self.q_proj.weight,
127
+ k_proj_weight=self.k_proj.weight,
128
+ v_proj_weight=self.v_proj.weight,
129
+ in_proj_weight=None,
130
+ in_proj_bias=torch.cat(
131
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
132
+ ),
133
+ bias_k=None,
134
+ bias_v=None,
135
+ add_zero_attn=False,
136
+ dropout_p=0,
137
+ out_proj_weight=self.c_proj.weight,
138
+ out_proj_bias=self.c_proj.bias,
139
+ use_separate_proj_weight=True,
140
+ training=self.training,
141
+ need_weights=False,
142
+ key_padding_mask=mask,
143
+ )
144
+
145
+ if return_cls:
146
+ return x[0]
147
+ else:
148
+ return x
149
+
150
+
151
+ class ModifiedResNet(nn.Module):
152
+ """
153
+ A ResNet class that is similar to torchvision's but contains the following changes:
154
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
155
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
156
+ - The final pooling layer is a QKV attention instead of an average pool
157
+ """
158
+
159
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
160
+ super().__init__()
161
+ self.output_dim = output_dim
162
+ self.input_resolution = input_resolution
163
+
164
+ # the 3-layer stem
165
+ self.conv1 = nn.Conv2d(
166
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
167
+ )
168
+ self.bn1 = nn.BatchNorm2d(width // 2)
169
+ self.conv2 = nn.Conv2d(
170
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
171
+ )
172
+ self.bn2 = nn.BatchNorm2d(width // 2)
173
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
174
+ self.bn3 = nn.BatchNorm2d(width)
175
+ self.avgpool = nn.AvgPool2d(2)
176
+ self.relu = nn.ReLU(inplace=True)
177
+
178
+ # residual layers
179
+ self._inplanes = width # this is a *mutable* variable used during construction
180
+ self.layer1 = self._make_layer(width, layers[0])
181
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
182
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
183
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
184
+
185
+ embed_dim = width * 32 # the ResNet feature dimension
186
+ self.attnpool = AttentionPool2d(
187
+ input_resolution // 32, embed_dim, heads, output_dim
188
+ )
189
+
190
+ def _make_layer(self, planes, blocks, stride=1):
191
+ layers = [Bottleneck(self._inplanes, planes, stride)]
192
+
193
+ self._inplanes = planes * Bottleneck.expansion
194
+ for _ in range(1, blocks):
195
+ layers.append(Bottleneck(self._inplanes, planes))
196
+
197
+ return nn.Sequential(*layers)
198
+
199
+ def forward(self, x, mask: torch.Tensor = None, return_cls=True):
200
+ def stem(x):
201
+ for conv, bn in [
202
+ (self.conv1, self.bn1),
203
+ (self.conv2, self.bn2),
204
+ (self.conv3, self.bn3),
205
+ ]:
206
+ x = self.relu(bn(conv(x)))
207
+ x = self.avgpool(x)
208
+ return x
209
+
210
+ x = x.type(self.conv1.weight.dtype)
211
+ x = stem(x) # 1/4,1/4
212
+ x = self.layer1(x)
213
+ x = self.layer2(x) # 1/8,1/8
214
+ x = self.layer3(x) # 1/16,1/16
215
+ x = self.layer4(x) # 1/32,1/32
216
+ b, c, gh, gw = x.shape
217
+ x = self.attnpool(x, mask, return_cls)
218
+ if not return_cls:
219
+ return x[1:].permute(1, 0, 2).reshape(b, gh, gw, x.shape[-1]) # N,L,C
220
+ return x
221
+
222
+
223
+ class LayerNorm(nn.LayerNorm):
224
+ """Subclass torch's LayerNorm to handle fp16."""
225
+
226
+ def forward(self, x: torch.Tensor):
227
+ orig_type = x.dtype
228
+ ret = super().forward(x.type(torch.float32))
229
+ return ret.type(orig_type)
230
+
231
+
232
+ class QuickGELU(nn.Module):
233
+ def forward(self, x: torch.Tensor):
234
+ return x * torch.sigmoid(1.702 * x)
235
+
236
+
237
+ class ResidualAttentionBlock(nn.Module):
238
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
239
+ super().__init__()
240
+
241
+ self.attn = nn.MultiheadAttention(d_model, n_head)
242
+ self.ln_1 = LayerNorm(d_model)
243
+ self.mlp = nn.Sequential(
244
+ OrderedDict(
245
+ [
246
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
247
+ ("gelu", QuickGELU()),
248
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
249
+ ]
250
+ )
251
+ )
252
+ self.ln_2 = LayerNorm(d_model)
253
+ self.attn_mask = attn_mask
254
+
255
+ def attention(self, x: torch.Tensor, **kwargs):
256
+ self.attn_mask = (
257
+ self.attn_mask.to(dtype=x.dtype, device=x.device)
258
+ if self.attn_mask is not None
259
+ else None
260
+ )
261
+ return self.attn(
262
+ x, x, x, need_weights=False, attn_mask=self.attn_mask, **kwargs
263
+ )[0]
264
+
265
+ def forward(self, x: torch.Tensor, **kwargs):
266
+ x = x + self.attention(self.ln_1(x), **kwargs)
267
+ x = x + self.mlp(self.ln_2(x))
268
+ return x
269
+
270
+
271
+ class Transformer(nn.Module):
272
+ def __init__(
273
+ self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
274
+ ):
275
+ super().__init__()
276
+ self.width = width
277
+ self.layers = layers
278
+ self.resblocks = nn.Sequential(
279
+ *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
280
+ )
281
+
282
+ def forward(self, x: torch.Tensor, **kwargs):
283
+ for block in self.resblocks:
284
+ x = block(x, **kwargs)
285
+ return x
286
+
287
+
288
+ class VisionTransformer(nn.Module):
289
+ def __init__(
290
+ self,
291
+ input_resolution: int,
292
+ patch_size: int,
293
+ mask_prompt_depth: int,
294
+ width: int,
295
+ layers: int,
296
+ heads: int,
297
+ output_dim: int,
298
+ ):
299
+ super().__init__()
300
+ self.input_resolution = input_resolution
301
+ self.output_dim = output_dim
302
+ self.conv1 = nn.Conv2d(
303
+ in_channels=3,
304
+ out_channels=width,
305
+ kernel_size=patch_size,
306
+ stride=patch_size,
307
+ bias=False,
308
+ )
309
+
310
+ scale = width ** -0.5
311
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
312
+ self.positional_embedding = nn.Parameter(
313
+ scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
314
+ )
315
+ self.grid_size = input_resolution // patch_size
316
+ self.ln_pre = LayerNorm(width)
317
+
318
+ self.transformer = Transformer(width, layers, heads)
319
+
320
+ self.ln_post = LayerNorm(width)
321
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
322
+
323
+ self.mask_pool = nn.AvgPool2d(patch_size, stride=patch_size)
324
+ self.mask_prompt_depth = mask_prompt_depth
325
+ self.mask_embedding = nn.Parameter(torch.zeros(self.mask_prompt_depth, self.grid_size * self.grid_size, width))
326
+
327
+ def forward(self, x: torch.Tensor, m: torch.Tensor = None):
328
+ x = self.conv1(x) # shape = [*, width, grid, grid]
329
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
330
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
331
+ if m is not None:
332
+ m = self.mask_pool(m.to(torch.float).squeeze()).reshape(m.shape[0], -1).unsqueeze(-1)
333
+ m = torch.ceil(m)
334
+ if self.mask_embedding.shape[1] == 1:
335
+ mask_embedding = self.mask_embedding.to(x.dtype).repeat(1, x.shape[1], 1)
336
+ else:
337
+ mask_embedding = self.mask_embedding.to(x.dtype)
338
+ x = x * m + mask_embedding[0].unsqueeze(0) * (1 - m)
339
+
340
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
341
+ x = x + self.positional_embedding.to(x.dtype)
342
+ x = self.ln_pre(x)
343
+
344
+ x = x.permute(1, 0, 2) # NLD -> LND
345
+ if m is not None:
346
+ for i, blk in enumerate(self.transformer.resblocks):
347
+ d = i + 1
348
+ x = blk(x)
349
+ if d < self.mask_prompt_depth:
350
+ masked_x = x[1:, :, :] * m.permute(1, 0, 2) + \
351
+ mask_embedding[d].unsqueeze(0).permute(1, 0, 2) * (1 - m.permute(1, 0, 2))
352
+ x = torch.cat([x[:1, :, :], masked_x], dim=0)
353
+ else:
354
+ x = self.transformer(x)
355
+ x = x.permute(1, 0, 2) # LND -> NLD
356
+
357
+ x = self.ln_post(x[:, 0, :])
358
+
359
+ if self.proj is not None:
360
+ x = x @ self.proj
361
+
362
+ return x
363
+
364
+
365
+
366
+ class CLIP(nn.Module):
367
+ def __init__(
368
+ self,
369
+ embed_dim: int,
370
+ # vision
371
+ image_resolution: int,
372
+ vision_layers: Union[Tuple[int, int, int, int], int],
373
+ vision_width: int,
374
+ vision_patch_size: int,
375
+ mask_prompt_depth: int,
376
+ # text
377
+ context_length: int,
378
+ vocab_size: int,
379
+ transformer_width: int,
380
+ transformer_heads: int,
381
+ transformer_layers: int,
382
+ ):
383
+ super().__init__()
384
+
385
+ self.context_length = context_length
386
+
387
+ if isinstance(vision_layers, (tuple, list)):
388
+ vision_heads = vision_width * 32 // 64
389
+ self.visual = ModifiedResNet(
390
+ layers=vision_layers,
391
+ output_dim=embed_dim,
392
+ heads=vision_heads,
393
+ input_resolution=image_resolution,
394
+ width=vision_width,
395
+ )
396
+ else:
397
+ vision_heads = vision_width // 64
398
+ self.visual = VisionTransformer(
399
+ input_resolution=image_resolution,
400
+ patch_size=vision_patch_size,
401
+ mask_prompt_depth=mask_prompt_depth,
402
+ width=vision_width,
403
+ layers=vision_layers,
404
+ heads=vision_heads,
405
+ output_dim=embed_dim,
406
+ )
407
+
408
+ self.transformer = Transformer(
409
+ width=transformer_width,
410
+ layers=transformer_layers,
411
+ heads=transformer_heads,
412
+ attn_mask=self.build_attention_mask(),
413
+ )
414
+
415
+ self.vocab_size = vocab_size
416
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
417
+ self.positional_embedding = nn.Parameter(
418
+ torch.empty(self.context_length, transformer_width)
419
+ )
420
+ self.ln_final = LayerNorm(transformer_width)
421
+
422
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
423
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
424
+
425
+ self.initialize_parameters()
426
+
427
+ def initialize_parameters(self):
428
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
429
+ nn.init.normal_(self.positional_embedding, std=0.01)
430
+
431
+ if isinstance(self.visual, ModifiedResNet):
432
+ if self.visual.attnpool is not None:
433
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
434
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
435
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
436
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
437
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
438
+
439
+ for resnet_block in [
440
+ self.visual.layer1,
441
+ self.visual.layer2,
442
+ self.visual.layer3,
443
+ self.visual.layer4,
444
+ ]:
445
+ for name, param in resnet_block.named_parameters():
446
+ if name.endswith("bn3.weight"):
447
+ nn.init.zeros_(param)
448
+
449
+ proj_std = (self.transformer.width ** -0.5) * (
450
+ (2 * self.transformer.layers) ** -0.5
451
+ )
452
+ attn_std = self.transformer.width ** -0.5
453
+ fc_std = (2 * self.transformer.width) ** -0.5
454
+ for block in self.transformer.resblocks:
455
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
456
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
457
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
458
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
459
+
460
+ if self.text_projection is not None:
461
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
462
+
463
+ def build_attention_mask(self):
464
+ # lazily create causal attention mask, with full attention between the vision tokens
465
+ # pytorch uses additive attention mask; fill with -inf
466
+ mask = torch.empty(self.context_length, self.context_length)
467
+ mask.fill_(float("-inf"))
468
+ mask.triu_(1) # zero out the lower diagonal
469
+ return mask
470
+
471
+ @property
472
+ def dtype(self):
473
+ return self.visual.conv1.weight.dtype
474
+
475
+ def encode_image(self, image, **kwargs):
476
+ return self.visual(image.type(self.dtype), **kwargs)
477
+
478
+ def encode_text(self, text):
479
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
480
+
481
+ x = x + self.positional_embedding.type(self.dtype)
482
+ x = x.permute(1, 0, 2) # NLD -> LND
483
+ x = self.transformer(x)
484
+ x = x.permute(1, 0, 2) # LND -> NLD
485
+ x = self.ln_final(x).type(self.dtype)
486
+
487
+ # x.shape = [batch_size, n_ctx, transformer.width]
488
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
489
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
490
+
491
+ return x
492
+
493
+ def forward(self, image, text):
494
+ image_features = self.encode_image(image)
495
+ text_features = self.encode_text(text)
496
+
497
+ # normalized features
498
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
499
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
500
+
501
+ # cosine similarity as logits
502
+ logit_scale = self.logit_scale.exp()
503
+ logits_per_image = logit_scale * image_features @ text_features.t()
504
+ logits_per_text = logit_scale * text_features @ image_features.t()
505
+
506
+ # shape = [global_batch_size, global_batch_size]
507
+ return logits_per_image, logits_per_text
508
+
509
+
510
+ def convert_weights(model: nn.Module):
511
+ """Convert applicable model parameters to fp16"""
512
+
513
+ def _convert_weights_to_fp16(l):
514
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
515
+ l.weight.data = l.weight.data.half()
516
+ if l.bias is not None:
517
+ l.bias.data = l.bias.data.half()
518
+
519
+ if isinstance(l, nn.MultiheadAttention):
520
+ for attr in [
521
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
522
+ "in_proj_bias",
523
+ "bias_k",
524
+ "bias_v",
525
+ ]:
526
+ tensor = getattr(l, attr)
527
+ if tensor is not None:
528
+ tensor.data = tensor.data.half()
529
+
530
+ for name in ["text_projection", "proj"]:
531
+ if hasattr(l, name):
532
+ attr = getattr(l, name)
533
+ if attr is not None:
534
+ attr.data = attr.data.half()
535
+
536
+ model.apply(_convert_weights_to_fp16)
537
+
538
+
539
+ def build_model(state_dict: dict, mask_prompt_depth: int = 0):
540
+ vit = "visual.proj" in state_dict
541
+
542
+ if vit:
543
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
544
+ vision_layers = len(
545
+ [
546
+ k
547
+ for k in state_dict.keys()
548
+ if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
549
+ ]
550
+ )
551
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
552
+ grid_size = round(
553
+ (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
554
+ )
555
+ image_resolution = vision_patch_size * grid_size
556
+ else:
557
+ assert mask_prompt_depth == 0, 'ResNets do not support mask prompt tuning'
558
+ counts: list = [
559
+ len(
560
+ set(
561
+ k.split(".")[2]
562
+ for k in state_dict
563
+ if k.startswith(f"visual.layer{b}")
564
+ )
565
+ )
566
+ for b in [1, 2, 3, 4]
567
+ ]
568
+ vision_layers = tuple(counts)
569
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
570
+ output_width = round(
571
+ (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
572
+ )
573
+ vision_patch_size = None
574
+ assert (
575
+ output_width ** 2 + 1
576
+ == state_dict["visual.attnpool.positional_embedding"].shape[0]
577
+ )
578
+ image_resolution = output_width * 32
579
+
580
+ embed_dim = state_dict["text_projection"].shape[1]
581
+ context_length = state_dict["positional_embedding"].shape[0]
582
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
583
+ transformer_width = state_dict["ln_final.weight"].shape[0]
584
+ transformer_heads = transformer_width // 64
585
+ transformer_layers = len(
586
+ set(
587
+ k.split(".")[2]
588
+ for k in state_dict
589
+ if k.startswith(f"transformer.resblocks")
590
+ )
591
+ )
592
+
593
+ model = CLIP(
594
+ embed_dim,
595
+ image_resolution,
596
+ vision_layers,
597
+ vision_width,
598
+ vision_patch_size,
599
+ mask_prompt_depth,
600
+ context_length,
601
+ vocab_size,
602
+ transformer_width,
603
+ transformer_heads,
604
+ transformer_layers,
605
+ )
606
+
607
+ for key in ["input_resolution", "context_length", "vocab_size"]:
608
+ if key in state_dict:
609
+ del state_dict[key]
610
+
611
+ convert_weights(model)
612
+ model.load_state_dict(state_dict, strict=False)
613
+ return model.eval()
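`build_attention_mask` produces the additive causal mask used by the text transformer: zeros on and below the diagonal, `-inf` strictly above it, so each token attends only to itself and earlier tokens. A tiny self-contained sketch with a toy context length (the real models use 77):

```python
import torch

context_length = 5
mask = torch.empty(context_length, context_length)
mask.fill_(float("-inf"))
mask.triu_(1)  # -inf strictly above the diagonal, 0 on and below
print(mask)
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0.,   0., -inf, -inf, -inf],
#         [0.,   0.,   0., -inf, -inf],
#         [0.,   0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.,   0.]])
```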
open_vocab_seg/modeling/clip_adapter/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,150 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(
13
+ os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
14
+ )
15
+
16
+
17
+ @lru_cache()
18
+ def bytes_to_unicode():
19
+ """
20
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
21
+ The reversible bpe codes work on unicode strings.
22
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
23
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
24
+ This is a significant percentage of your normal, say, 32K bpe vocab.
25
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
26
+ The tables also avoid mapping to whitespace/control characters that the bpe code barfs on.
27
+ """
28
+ bs = (
29
+ list(range(ord("!"), ord("~") + 1))
30
+ + list(range(ord("¡"), ord("¬") + 1))
31
+ + list(range(ord("®"), ord("ÿ") + 1))
32
+ )
33
+ cs = bs[:]
34
+ n = 0
35
+ for b in range(2 ** 8):
36
+ if b not in bs:
37
+ bs.append(b)
38
+ cs.append(2 ** 8 + n)
39
+ n += 1
40
+ cs = [chr(n) for n in cs]
41
+ return dict(zip(bs, cs))
42
+
43
+
44
+ def get_pairs(word):
45
+ """Return set of symbol pairs in a word.
46
+ Word is represented as tuple of symbols (symbols being variable-length strings).
47
+ """
48
+ pairs = set()
49
+ prev_char = word[0]
50
+ for char in word[1:]:
51
+ pairs.add((prev_char, char))
52
+ prev_char = char
53
+ return pairs
54
+
55
+
56
+ def basic_clean(text):
57
+ text = ftfy.fix_text(text)
58
+ text = html.unescape(html.unescape(text))
59
+ return text.strip()
60
+
61
+
62
+ def whitespace_clean(text):
63
+ text = re.sub(r"\s+", " ", text)
64
+ text = text.strip()
65
+ return text
66
+
67
+
68
+ class SimpleTokenizer(object):
69
+ def __init__(self, bpe_path: str = default_bpe()):
70
+ self.byte_encoder = bytes_to_unicode()
71
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
72
+ merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
73
+ merges = merges[1 : 49152 - 256 - 2 + 1]
74
+ merges = [tuple(merge.split()) for merge in merges]
75
+ vocab = list(bytes_to_unicode().values())
76
+ vocab = vocab + [v + "</w>" for v in vocab]
77
+ for merge in merges:
78
+ vocab.append("".join(merge))
79
+ vocab.extend(["<|startoftext|>", "<|endoftext|>"])
80
+ self.encoder = dict(zip(vocab, range(len(vocab))))
81
+ self.decoder = {v: k for k, v in self.encoder.items()}
82
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
83
+ self.cache = {
84
+ "<|startoftext|>": "<|startoftext|>",
85
+ "<|endoftext|>": "<|endoftext|>",
86
+ }
87
+ self.pat = re.compile(
88
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
89
+ re.IGNORECASE,
90
+ )
91
+
92
+ def bpe(self, token):
93
+ if token in self.cache:
94
+ return self.cache[token]
95
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
96
+ pairs = get_pairs(word)
97
+
98
+ if not pairs:
99
+ return token + "</w>"
100
+
101
+ while True:
102
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
103
+ if bigram not in self.bpe_ranks:
104
+ break
105
+ first, second = bigram
106
+ new_word = []
107
+ i = 0
108
+ while i < len(word):
109
+ try:
110
+ j = word.index(first, i)
111
+ new_word.extend(word[i:j])
112
+ i = j
113
+ except ValueError:
114
+ new_word.extend(word[i:])
115
+ break
116
+
117
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
118
+ new_word.append(first + second)
119
+ i += 2
120
+ else:
121
+ new_word.append(word[i])
122
+ i += 1
123
+ new_word = tuple(new_word)
124
+ word = new_word
125
+ if len(word) == 1:
126
+ break
127
+ else:
128
+ pairs = get_pairs(word)
129
+ word = " ".join(word)
130
+ self.cache[token] = word
131
+ return word
132
+
133
+ def encode(self, text):
134
+ bpe_tokens = []
135
+ text = whitespace_clean(basic_clean(text)).lower()
136
+ for token in re.findall(self.pat, text):
137
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
138
+ bpe_tokens.extend(
139
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
140
+ )
141
+ return bpe_tokens
142
+
143
+ def decode(self, tokens):
144
+ text = "".join([self.decoder[token] for token in tokens])
145
+ text = (
146
+ bytearray([self.byte_decoder[c] for c in text])
147
+ .decode("utf-8", errors="replace")
148
+ .replace("</w>", " ")
149
+ )
150
+ return text
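A round-trip sketch for the tokenizer above (it assumes ftfy and regex are installed and `bpe_simple_vocab_16e6.txt.gz` sits next to this file, as in the repo). Note that `encode` lower-cases and whitespace-normalizes its input and does not add the start/end tokens; `tokenize` in clip.py adds those.

```python
from open_vocab_seg.modeling.clip_adapter.clip.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("A   photo of a Cat!")
print(ids)              # a list of BPE token ids
print(tok.decode(ids))  # "a photo of a cat ! " -- decode re-inserts the end-of-word spaces
```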
open_vocab_seg/modeling/clip_adapter/text_template.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+ # Modified by Feng Liang from
4
+ # https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/text_prompt.py
5
+ # https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/utils.py
6
+
7
+ from typing import List
8
+
9
+ # import clip
10
+ from .clip import tokenize
11
+ import torch
12
+ from torch import nn
13
+
14
+ IMAGENET_PROMPT = [
15
+ "a bad photo of a {}.",
16
+ "a photo of many {}.",
17
+ "a sculpture of a {}.",
18
+ "a photo of the hard to see {}.",
19
+ "a low resolution photo of the {}.",
20
+ "a rendering of a {}.",
21
+ "graffiti of a {}.",
22
+ "a bad photo of the {}.",
23
+ "a cropped photo of the {}.",
24
+ "a tattoo of a {}.",
25
+ "the embroidered {}.",
26
+ "a photo of a hard to see {}.",
27
+ "a bright photo of a {}.",
28
+ "a photo of a clean {}.",
29
+ "a photo of a dirty {}.",
30
+ "a dark photo of the {}.",
31
+ "a drawing of a {}.",
32
+ "a photo of my {}.",
33
+ "the plastic {}.",
34
+ "a photo of the cool {}.",
35
+ "a close-up photo of a {}.",
36
+ "a black and white photo of the {}.",
37
+ "a painting of the {}.",
38
+ "a painting of a {}.",
39
+ "a pixelated photo of the {}.",
40
+ "a sculpture of the {}.",
41
+ "a bright photo of the {}.",
42
+ "a cropped photo of a {}.",
43
+ "a plastic {}.",
44
+ "a photo of the dirty {}.",
45
+ "a jpeg corrupted photo of a {}.",
46
+ "a blurry photo of the {}.",
47
+ "a photo of the {}.",
48
+ "a good photo of the {}.",
49
+ "a rendering of the {}.",
50
+ "a {} in a video game.",
51
+ "a photo of one {}.",
52
+ "a doodle of a {}.",
53
+ "a close-up photo of the {}.",
54
+ "a photo of a {}.",
55
+ "the origami {}.",
56
+ "the {} in a video game.",
57
+ "a sketch of a {}.",
58
+ "a doodle of the {}.",
59
+ "a origami {}.",
60
+ "a low resolution photo of a {}.",
61
+ "the toy {}.",
62
+ "a rendition of the {}.",
63
+ "a photo of the clean {}.",
64
+ "a photo of a large {}.",
65
+ "a rendition of a {}.",
66
+ "a photo of a nice {}.",
67
+ "a photo of a weird {}.",
68
+ "a blurry photo of a {}.",
69
+ "a cartoon {}.",
70
+ "art of a {}.",
71
+ "a sketch of the {}.",
72
+ "a embroidered {}.",
73
+ "a pixelated photo of a {}.",
74
+ "itap of the {}.",
75
+ "a jpeg corrupted photo of the {}.",
76
+ "a good photo of a {}.",
77
+ "a plushie {}.",
78
+ "a photo of the nice {}.",
79
+ "a photo of the small {}.",
80
+ "a photo of the weird {}.",
81
+ "the cartoon {}.",
82
+ "art of the {}.",
83
+ "a drawing of the {}.",
84
+ "a photo of the large {}.",
85
+ "a black and white photo of a {}.",
86
+ "the plushie {}.",
87
+ "a dark photo of a {}.",
88
+ "itap of a {}.",
89
+ "graffiti of the {}.",
90
+ "a toy {}.",
91
+ "itap of my {}.",
92
+ "a photo of a cool {}.",
93
+ "a photo of a small {}.",
94
+ "a tattoo of the {}.",
95
+ ]
96
+
97
+ VILD_PROMPT = [
98
+ "a photo of a {}.",
99
+ "This is a photo of a {}",
100
+ "There is a {} in the scene",
101
+ "There is the {} in the scene",
102
+ "a photo of a {} in the scene",
103
+ "a photo of a small {}.",
104
+ "a photo of a medium {}.",
105
+ "a photo of a large {}.",
106
+ "This is a photo of a small {}.",
107
+ "This is a photo of a medium {}.",
108
+ "This is a photo of a large {}.",
109
+ "There is a small {} in the scene.",
110
+ "There is a medium {} in the scene.",
111
+ "There is a large {} in the scene.",
112
+ ]
113
+
114
+ class PromptExtractor(nn.Module):
115
+ def __init__(self):
116
+ super().__init__()
117
+ self._buffer_init = False
118
+
119
+ def init_buffer(self, clip_model):
120
+ self._buffer_init = True
121
+
122
+ def forward(self, noun_list: List[str], clip_model: nn.Module):
123
+ raise NotImplementedError()
124
+
125
+
126
+ class PredefinedPromptExtractor(PromptExtractor):
127
+ def __init__(self, templates: List[str]):
128
+ super().__init__()
129
+ self.templates = templates
130
+
131
+ def forward(self, noun_list: List[str], clip_model: nn.Module):
132
+ text_features_bucket = []
133
+ for template in self.templates:
134
+ noun_tokens = [tokenize(template.format(noun)) for noun in noun_list]
135
+ text_inputs = torch.cat(noun_tokens).to(
136
+ clip_model.text_projection.data.device
137
+ )
138
+ text_features = clip_model.encode_text(text_inputs)
139
+ text_features /= text_features.norm(dim=-1, keepdim=True)
140
+ text_features_bucket.append(text_features)
141
+ del text_inputs
142
+ # ensemble by averaging
143
+ text_features = torch.stack(text_features_bucket).mean(dim=0)
144
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
145
+
146
+ return text_features
147
+
148
+
149
+ class ImageNetPromptExtractor(PredefinedPromptExtractor):
150
+ def __init__(self):
151
+ super().__init__(IMAGENET_PROMPT)
152
+
153
+
154
+ class VILDPromptExtractor(PredefinedPromptExtractor):
155
+ def __init__(self):
156
+ super().__init__(VILD_PROMPT)
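Prompt ensembling in `PredefinedPromptExtractor.forward` expands every class name into all templates, encodes each prompt with CLIP, averages the per-template embeddings and re-normalizes. A sketch of just the template-expansion step, using the VILD templates defined above (the CLIP encoder is omitted; importing this module assumes the package and its tokenizer dependencies are installed):

```python
from open_vocab_seg.modeling.clip_adapter.text_template import VILD_PROMPT

class_name = "traffic light"
prompts = [template.format(class_name) for template in VILD_PROMPT]
print(len(prompts))  # 14 prompts for a single class name
print(prompts[0])    # "a photo of a traffic light."
```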
open_vocab_seg/modeling/clip_adapter/utils.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ from typing import Tuple
5
+ import numpy as np
6
+ import torch
7
+ from .clip import load as clip_load
8
+ from detectron2.utils.comm import get_local_rank, synchronize
9
+
10
+
11
+ def expand_box(
12
+ x1: float,
13
+ y1: float,
14
+ x2: float,
15
+ y2: float,
16
+ expand_ratio: float = 1.0,
17
+ max_h: int = None,
18
+ max_w: int = None,
19
+ ):
20
+ cx = 0.5 * (x1 + x2)
21
+ cy = 0.5 * (y1 + y2)
22
+ w = x2 - x1
23
+ h = y2 - y1
24
+ w = w * expand_ratio
25
+ h = h * expand_ratio
26
+ box = [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h]
27
+ if max_h is not None:
28
+ box[1] = max(0, box[1])
29
+ box[3] = min(max_h - 1, box[3])
30
+ if max_w is not None:
31
+ box[0] = max(0, box[0])
32
+ box[2] = min(max_w - 1, box[2])
33
+ return [int(b) for b in box]
34
+
35
+
36
+ def mask2box(mask: torch.Tensor):
37
+ # use naive way
38
+ row = torch.nonzero(mask.sum(dim=0))[:, 0]
39
+ if len(row) == 0:
40
+ return None
41
+ x1 = row.min()
42
+ x2 = row.max()
43
+ col = torch.nonzero(mask.sum(dim=1))[:, 0]
44
+ y1 = col.min()
45
+ y2 = col.max()
46
+ return x1, y1, x2 + 1, y2 + 1
47
+
48
+
49
+ def crop_with_mask(
50
+ image: torch.Tensor,
51
+ mask: torch.Tensor,
52
+ bbox: torch.Tensor,
53
+ fill: Tuple[float, float, float] = (0, 0, 0),
54
+ expand_ratio: float = 1.0,
55
+ ):
56
+ l, t, r, b = expand_box(*bbox, expand_ratio)
57
+ _, h, w = image.shape
58
+ l = max(l, 0)
59
+ t = max(t, 0)
60
+ r = min(r, w)
61
+ b = min(b, h)
62
+ new_image = torch.cat(
63
+ [image.new_full((1, b - t, r - l), fill_value=val) for val in fill]
64
+ )
65
+ mask_bool = mask.bool()
66
+ return image[:, t:b, l:r] * mask[None, t:b, l:r] + (~ mask_bool[None, t:b, l:r]) * new_image, mask[None, t:b, l:r]
67
+
68
+
69
+ def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True):
70
+ rank = get_local_rank()
71
+ if rank == 0:
72
+ # download on rank 0 only
73
+ model, _ = clip_load(model, mask_prompt_depth=mask_prompt_depth, device="cpu")
74
+ synchronize()
75
+ if rank != 0:
76
+ model, _ = clip_load(model, mask_prompt_depth=mask_prompt_depth, device="cpu")
77
+ synchronize()
78
+ if frozen:
79
+ for param in model.parameters():
80
+ param.requires_grad = False
81
+ return model
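A small sketch of `expand_box`, which is how `crop_with_mask` keeps some context around each mask before the crop is fed to CLIP: the box is grown around its center by `expand_ratio` and clamped to the image bounds. Importing it assumes the repo and detectron2 are on the path, since this module imports `detectron2.utils.comm`.

```python
from open_vocab_seg.modeling.clip_adapter.utils import expand_box

# a 40x40 box centered at (60, 50), grown by 1.5x and clamped to a 100x100 image
box = expand_box(40.0, 30.0, 80.0, 70.0, expand_ratio=1.5, max_h=100, max_w=100)
print(box)  # [30, 20, 90, 80]
```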
open_vocab_seg/modeling/criterion.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
4
+
5
+ """
6
+ MaskFormer criterion.
7
+ """
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ from detectron2.utils.comm import get_world_size
13
+
14
+ from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list
15
+
16
+
17
+ def dice_loss(inputs, targets, num_masks):
18
+ """
19
+ Compute the DICE loss, similar to generalized IOU for masks
20
+ Args:
21
+ inputs: A float tensor of arbitrary shape.
22
+ The predictions for each example.
23
+ targets: A float tensor with the same shape as inputs. Stores the binary
24
+ classification label for each element in inputs
25
+ (0 for the negative class and 1 for the positive class).
26
+ """
27
+ inputs = inputs.sigmoid()
28
+ inputs = inputs.flatten(1)
29
+ numerator = 2 * (inputs * targets).sum(-1)
30
+ denominator = inputs.sum(-1) + targets.sum(-1)
31
+ loss = 1 - (numerator + 1) / (denominator + 1)
32
+ return loss.sum() / num_masks
33
+
34
+
35
+ def sigmoid_focal_loss(
36
+ inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2
37
+ ):
38
+ """
39
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
40
+ Args:
41
+ inputs: A float tensor of arbitrary shape.
42
+ The predictions for each example.
43
+ targets: A float tensor with the same shape as inputs. Stores the binary
44
+ classification label for each element in inputs
45
+ (0 for the negative class and 1 for the positive class).
46
+ alpha: (optional) Weighting factor in range (0,1) to balance
47
+ positive vs negative examples. Default = -1 (no weighting).
48
+ gamma: Exponent of the modulating factor (1 - p_t) to
49
+ balance easy vs hard examples.
50
+ Returns:
51
+ Loss tensor
52
+ """
53
+ prob = inputs.sigmoid()
54
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
55
+ p_t = prob * targets + (1 - prob) * (1 - targets)
56
+ loss = ce_loss * ((1 - p_t) ** gamma)
57
+
58
+ if alpha >= 0:
59
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
60
+ loss = alpha_t * loss
61
+
62
+ return loss.mean(1).sum() / num_masks
63
+
64
+
65
+ class SetCriterion(nn.Module):
66
+ """This class computes the loss for MaskFormer.
67
+ The process happens in two steps:
68
+ 1) we compute a Hungarian assignment between the ground-truth masks and the outputs of the model
69
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and mask)
70
+ """
71
+
72
+ def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
73
+ """Create the criterion.
74
+ Parameters:
75
+ num_classes: number of object categories, omitting the special no-object category
76
+ matcher: module able to compute a matching between targets and proposals
77
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
78
+ eos_coef: relative classification weight applied to the no-object category
79
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
80
+ """
81
+ super().__init__()
82
+ self.num_classes = num_classes
83
+ self.matcher = matcher
84
+ self.weight_dict = weight_dict
85
+ self.eos_coef = eos_coef
86
+ self.losses = losses
87
+ if eos_coef > 0:
88
+
89
+ empty_weight = torch.ones(self.num_classes + 1)
90
+
91
+ empty_weight[-1] = self.eos_coef
92
+ self.register_buffer("empty_weight", empty_weight)
93
+ self.use_ignore_idx = False
94
+ else:
95
+ self.use_ignore_idx = True
96
+ self.cur_target = []
97
+
98
+ def loss_labels(self, outputs, targets, indices, num_masks):
99
+ """Classification loss (NLL)
100
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
101
+ """
102
+ assert "pred_logits" in outputs
103
+ src_logits = outputs["pred_logits"]
104
+
105
+ idx = self._get_src_permutation_idx(indices)
106
+ target_classes_o = torch.cat(
107
+ [t["labels"][J] for t, (_, J) in zip(targets, indices)]
108
+ )
109
+ target_classes = torch.full(
110
+ src_logits.shape[:2],
111
+ self.num_classes,
112
+ dtype=torch.int64,
113
+ device=src_logits.device,
114
+ )
115
+ target_classes[idx] = target_classes_o
116
+ if self.use_ignore_idx:
117
+ loss_ce = F.cross_entropy(
118
+ src_logits.transpose(1, 2),
119
+ target_classes,
120
+ ignore_index=self.num_classes,
121
+ )
122
+ else:
123
+ if "empty_weight" in outputs:
124
+ empty_weight = torch.cat(
125
+ [outputs["empty_weight"], self.empty_weight[-1:]]
126
+ ).detach()
127
+ else:
128
+ empty_weight = self.empty_weight
129
+ loss_ce = F.cross_entropy(
130
+ src_logits.transpose(1, 2), target_classes, empty_weight
131
+ )
132
+ losses = {"loss_ce": loss_ce}
133
+ return losses
134
+
135
+ def loss_masks(self, outputs, targets, indices, num_masks):
136
+ """Compute the losses related to the masks: the focal loss and the dice loss.
137
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
138
+ """
139
+ assert "pred_masks" in outputs
140
+
141
+ src_idx = self._get_src_permutation_idx(indices)
142
+ tgt_idx = self._get_tgt_permutation_idx(indices)
143
+ src_masks = outputs["pred_masks"]
144
+ src_masks = src_masks[src_idx]
145
+ masks = [t["masks"] for t in targets]
146
+ # TODO use valid to mask invalid areas due to padding in loss
147
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
148
+ target_masks = target_masks.to(src_masks)
149
+ target_masks = target_masks[tgt_idx]
150
+
151
+ # upsample predictions to the target size
152
+ src_masks = F.interpolate(
153
+ src_masks[:, None],
154
+ size=target_masks.shape[-2:],
155
+ mode="bilinear",
156
+ align_corners=False,
157
+ )
158
+ src_masks = src_masks[:, 0].flatten(1)
159
+
160
+ target_masks = target_masks.flatten(1)
161
+ target_masks = target_masks.view(src_masks.shape)
162
+ losses = {
163
+ "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
164
+ "loss_dice": dice_loss(src_masks, target_masks, num_masks),
165
+ }
166
+ return losses
167
+
168
+ def _get_src_permutation_idx(self, indices):
169
+ # permute predictions following indices
170
+ batch_idx = torch.cat(
171
+ [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
172
+ )
173
+ src_idx = torch.cat([src for (src, _) in indices])
174
+ return batch_idx, src_idx
175
+
176
+ def _get_tgt_permutation_idx(self, indices):
177
+ # permute targets following indices
178
+ batch_idx = torch.cat(
179
+ [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
180
+ )
181
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
182
+ return batch_idx, tgt_idx
183
+
184
+ def get_loss(self, loss, outputs, targets, indices, num_masks):
185
+ loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
186
+ assert loss in loss_map, f"do you really want to compute {loss} loss?"
187
+ return loss_map[loss](outputs, targets, indices, num_masks)
188
+
189
+ def forward(self, outputs, targets):
190
+ """This performs the loss computation.
191
+ Parameters:
192
+ outputs: dict of tensors, see the output specification of the model for the format
193
+ targets: list of dicts, such that len(targets) == batch_size.
194
+ The expected keys in each dict depends on the losses applied, see each loss' doc
195
+ """
196
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
197
+
198
+ # Retrieve the matching between the outputs of the last layer and the targets
199
+ indices = self.matcher(outputs_without_aux, targets)
200
+
201
+ # Compute the average number of target masks across all nodes, for normalization purposes
202
+ num_masks = sum(len(t["labels"]) for t in targets)
203
+ num_masks = torch.as_tensor(
204
+ [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
205
+ )
206
+ if is_dist_avail_and_initialized():
207
+ torch.distributed.all_reduce(num_masks)
208
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
209
+
210
+ # Compute all the requested losses
211
+ losses = {}
212
+ for loss in self.losses:
213
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
214
+
215
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
216
+ if "aux_outputs" in outputs:
217
+ for i, aux_outputs in enumerate(outputs["aux_outputs"]):
218
+ indices = self.matcher(aux_outputs, targets)
219
+ for loss in self.losses:
220
+ l_dict = self.get_loss(
221
+ loss, aux_outputs, targets, indices, num_masks
222
+ )
223
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
224
+ losses.update(l_dict)
225
+
226
+ return losses
227
+
228
+ def clean_buffer(self):
229
+ self.cur_target = []
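As a quick sanity check of the two mask losses defined above, a confident and correct prediction should drive both the dice and focal terms close to zero. A small sketch, assuming the open_vocab_seg package added by this commit (and its detectron2 dependency) is importable:

import torch
from open_vocab_seg.modeling.criterion import dice_loss, sigmoid_focal_loss

pred = torch.tensor([[6.0, -6.0, 6.0, -6.0]])  # logits for one mask with 4 pixels
gt = torch.tensor([[1.0, 0.0, 1.0, 0.0]])      # matching binary target
print(dice_loss(pred, gt, num_masks=1))          # ~0.002
print(sigmoid_focal_loss(pred, gt, num_masks=1)) # ~0.0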
open_vocab_seg/modeling/heads/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
open_vocab_seg/modeling/heads/mask_former_head.py ADDED
@@ -0,0 +1,135 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import logging
5
+ from copy import deepcopy
6
+ from typing import Callable, Dict, List, Optional, Tuple, Union
7
+
8
+ import fvcore.nn.weight_init as weight_init
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ from detectron2.config import configurable
13
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
14
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
15
+
16
+ from ..transformer.transformer_predictor import TransformerPredictor
17
+ from .pixel_decoder import build_pixel_decoder
18
+
19
+
20
+ @SEM_SEG_HEADS_REGISTRY.register()
21
+ class MaskFormerHead(nn.Module):
22
+
23
+ _version = 2
24
+
25
+ def _load_from_state_dict(
26
+ self,
27
+ state_dict,
28
+ prefix,
29
+ local_metadata,
30
+ strict,
31
+ missing_keys,
32
+ unexpected_keys,
33
+ error_msgs,
34
+ ):
35
+ version = local_metadata.get("version", None)
36
+ if version is None or version < 2:
37
+ # Do not warn if train from scratch
38
+ scratch = True
39
+ logger = logging.getLogger(__name__)
40
+ for k in list(state_dict.keys()):
41
+ newk = k
42
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
43
+ newk = k.replace(prefix, prefix + "pixel_decoder.")
44
+ # logger.debug(f"{k} ==> {newk}")
45
+ if newk != k:
46
+ state_dict[newk] = state_dict[k]
47
+ del state_dict[k]
48
+ scratch = False
49
+
50
+ if not scratch:
51
+ logger.warning(
52
+ f"Weight format of {self.__class__.__name__} has changed! "
53
+ "Please upgrade your models. Applying automatic conversion now ..."
54
+ )
55
+
56
+ @configurable
57
+ def __init__(
58
+ self,
59
+ input_shape: Dict[str, ShapeSpec],
60
+ *,
61
+ num_classes: int,
62
+ pixel_decoder: nn.Module,
63
+ loss_weight: float = 1.0,
64
+ ignore_value: int = -1,
65
+ # extra parameters
66
+ transformer_predictor: nn.Module,
67
+ transformer_in_feature: str,
68
+ ):
69
+ """
70
+ NOTE: this interface is experimental.
71
+ Args:
72
+ input_shape: shapes (channels and stride) of the input features
73
+ num_classes: number of classes to predict
74
+ pixel_decoder: the pixel decoder module
75
+ loss_weight: loss weight
76
+ ignore_value: category id to be ignored during training.
77
+ transformer_predictor: the transformer decoder that makes prediction
78
+ transformer_in_feature: input feature name to the transformer_predictor
79
+ """
80
+ super().__init__()
81
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
82
+ self.in_features = [k for k, v in input_shape]
83
+ feature_strides = [v.stride for k, v in input_shape]
84
+ feature_channels = [v.channels for k, v in input_shape]
85
+
86
+ self.ignore_value = ignore_value
87
+ self.common_stride = 4
88
+ self.loss_weight = loss_weight
89
+
90
+ self.pixel_decoder = pixel_decoder
91
+ self.predictor = transformer_predictor
92
+ self.transformer_in_feature = transformer_in_feature
93
+
94
+ self.num_classes = num_classes
95
+
96
+ @classmethod
97
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
98
+ return {
99
+ "input_shape": {
100
+ k: v
101
+ for k, v in input_shape.items()
102
+ if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
103
+ },
104
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
105
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
106
+ "pixel_decoder": build_pixel_decoder(cfg, input_shape),
107
+ "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
108
+ "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
109
+ "transformer_predictor": TransformerPredictor(
110
+ cfg,
111
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
112
+ if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
113
+ else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
114
+ mask_classification=True,
115
+ ),
116
+ }
117
+
118
+ def forward(self, features):
119
+ return self.layers(features)
120
+
121
+ def layers(self, features):
122
+ (
123
+ mask_features,
124
+ transformer_encoder_features,
125
+ ) = self.pixel_decoder.forward_features(features)
126
+ if self.transformer_in_feature == "transformer_encoder":
127
+ assert (
128
+ transformer_encoder_features is not None
129
+ ), "Please use the TransformerEncoderPixelDecoder."
130
+ predictions = self.predictor(transformer_encoder_features, mask_features)
131
+ else:
132
+ predictions = self.predictor(
133
+ features[self.transformer_in_feature], mask_features
134
+ )
135
+ return predictions
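MaskFormerHead sorts input_shape by stride, so self.in_features always runs from the highest-resolution feature map (res2) to the lowest (res5) before the pixel decoder and predictor are wired together. A tiny sketch with hypothetical channel counts:

from detectron2.layers import ShapeSpec

input_shape = {
    "res4": ShapeSpec(channels=384, stride=16),
    "res2": ShapeSpec(channels=96, stride=4),
    "res5": ShapeSpec(channels=768, stride=32),
    "res3": ShapeSpec(channels=192, stride=8),
}
ordered = sorted(input_shape.items(), key=lambda x: x[1].stride)
print([k for k, _ in ordered])  # ['res2', 'res3', 'res4', 'res5']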
open_vocab_seg/modeling/heads/open_vocab_mask_former_head.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+ # Modified by Feng Liang from
4
+ # https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/heads/zero_shot_mask_former_head.py
5
+
6
+ import logging
7
+ from copy import deepcopy
8
+ from typing import Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import fvcore.nn.weight_init as weight_init
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+
14
+ from detectron2.config import configurable
15
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
16
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
17
+
18
+ from ..transformer.open_vocab_transformer_predictor import OpenVocabTransformerPredictor
19
+ from .pixel_decoder import build_pixel_decoder
20
+
21
+
22
+ @SEM_SEG_HEADS_REGISTRY.register()
23
+ class OpenVocabMaskFormerHead(nn.Module):
24
+
25
+ _version = 2
26
+
27
+ def _load_from_state_dict(
28
+ self,
29
+ state_dict,
30
+ prefix,
31
+ local_metadata,
32
+ strict,
33
+ missing_keys,
34
+ unexpected_keys,
35
+ error_msgs,
36
+ ):
37
+ version = local_metadata.get("version", None)
38
+ if version is None or version < 2:
39
+ # Do not warn if train from scratch
40
+ scratch = True
41
+ logger = logging.getLogger(__name__)
42
+ for k in list(state_dict.keys()):
43
+ newk = k
44
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
45
+ newk = k.replace(prefix, prefix + "pixel_decoder.")
46
+ # logger.debug(f"{k} ==> {newk}")
47
+ if newk != k:
48
+ state_dict[newk] = state_dict[k]
49
+ del state_dict[k]
50
+ scratch = False
51
+
52
+ if not scratch:
53
+ logger.warning(
54
+ f"Weight format of {self.__class__.__name__} has changed! "
55
+ "Please upgrade your models. Applying automatic conversion now ..."
56
+ )
57
+
58
+ @configurable
59
+ def __init__(
60
+ self,
61
+ input_shape: Dict[str, ShapeSpec],
62
+ *,
63
+ num_classes: int,
64
+ pixel_decoder: nn.Module,
65
+ loss_weight: float = 1.0,
66
+ ignore_value: int = -1,
67
+ # extra parameters
68
+ transformer_predictor: nn.Module,
69
+ transformer_in_feature: str,
70
+ ):
71
+ """
72
+ NOTE: this interface is experimental.
73
+ Args:
74
+ input_shape: shapes (channels and stride) of the input features
75
+ num_classes: number of classes to predict
76
+ pixel_decoder: the pixel decoder module
77
+ loss_weight: loss weight
78
+ ignore_value: category id to be ignored during training.
79
+ transformer_predictor: the transformer decoder that makes prediction
80
+ transformer_in_feature: input feature name to the transformer_predictor
81
+ """
82
+ super().__init__()
83
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
84
+ self.in_features = [k for k, v in input_shape]
85
+ feature_strides = [v.stride for k, v in input_shape]
86
+ feature_channels = [v.channels for k, v in input_shape]
87
+
88
+ self.ignore_value = ignore_value
89
+ self.common_stride = 4
90
+ self.loss_weight = loss_weight
91
+
92
+ self.pixel_decoder = pixel_decoder
93
+ self.predictor = transformer_predictor
94
+ self.transformer_in_feature = transformer_in_feature
95
+
96
+ self.num_classes = num_classes
97
+
98
+ @classmethod
99
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
100
+ return {
101
+ "input_shape": {
102
+ k: v
103
+ for k, v in input_shape.items()
104
+ if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
105
+ },
106
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
107
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
108
+ "pixel_decoder": build_pixel_decoder(cfg, input_shape),
109
+ "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
110
+ "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
111
+ "transformer_predictor": OpenVocabTransformerPredictor(
112
+ cfg,
113
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
114
+ if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
115
+ else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
116
+ mask_classification=True,
117
+ ),
118
+ }
119
+
120
+ def forward(self, features):
121
+ return self.layers(features)
122
+
123
+ def layers(self, features):
124
+ (
125
+ mask_features,
126
+ transformer_encoder_features,
127
+ ) = self.pixel_decoder.forward_features(features)
128
+ if self.transformer_in_feature == "transformer_encoder":
129
+ assert (
130
+ transformer_encoder_features is not None
131
+ ), "Please use the TransformerEncoderPixelDecoder."
132
+ predictions = self.predictor(transformer_encoder_features, mask_features)
133
+ else:
134
+ predictions = self.predictor(
135
+ features[self.transformer_in_feature], mask_features
136
+ )
137
+ return predictions
138
+
139
+ def freeze_pretrained(self):
140
+ for name, module in self.named_children():
141
+ if name not in ["predictor"]:
142
+ for param in module.parameters():
143
+ param.requires_grad = False
144
+ else:
145
+ module.freeze_pretrained()
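freeze_pretrained above freezes everything except the predictor branch, which then applies its own freeze_pretrained (defined later in this diff). The same named_children pattern on a toy module, as a standalone sketch:

from torch import nn

class ToyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.pixel_decoder = nn.Conv2d(3, 8, kernel_size=1)
        self.predictor = nn.Linear(8, 4)

    def freeze_pretrained(self):
        # Freeze every child except the predictor branch.
        for name, module in self.named_children():
            if name not in ["predictor"]:
                for param in module.parameters():
                    param.requires_grad = False

head = ToyHead()
head.freeze_pretrained()
print([n for n, p in head.named_parameters() if p.requires_grad])
# ['predictor.weight', 'predictor.bias']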
open_vocab_seg/modeling/heads/pixel_decoder.py ADDED
@@ -0,0 +1,308 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import logging
5
+ from typing import Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import fvcore.nn.weight_init as weight_init
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from detectron2.config import configurable
12
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
13
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
14
+
15
+ from ..transformer.position_encoding import PositionEmbeddingSine
16
+ from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer
17
+
18
+
19
+ def build_pixel_decoder(cfg, input_shape):
20
+ """
21
+ Build a pixel decoder from `cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME`.
22
+ """
23
+ name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
24
+ model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
25
+ forward_features = getattr(model, "forward_features", None)
26
+ if not callable(forward_features):
27
+ raise ValueError(
28
+ "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
29
+ f"Please implement forward_features for {name} to only return mask features."
30
+ )
31
+ return model
32
+
33
+
34
+ @SEM_SEG_HEADS_REGISTRY.register()
35
+ class BasePixelDecoder(nn.Module):
36
+ @configurable
37
+ def __init__(
38
+ self,
39
+ input_shape: Dict[str, ShapeSpec],
40
+ *,
41
+ conv_dim: int,
42
+ mask_dim: int,
43
+ norm: Optional[Union[str, Callable]] = None,
44
+ ):
45
+ """
46
+ NOTE: this interface is experimental.
47
+ Args:
48
+ input_shape: shapes (channels and stride) of the input features
49
+ conv_dim: number of output channels for the intermediate conv layers.
50
+ mask_dim: number of output channels for the final conv layer.
51
+ norm (str or callable): normalization for all conv layers
52
+ """
53
+ super().__init__()
54
+
55
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
56
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
57
+ feature_channels = [v.channels for k, v in input_shape]
58
+
59
+ lateral_convs = []
60
+ output_convs = []
61
+
62
+ use_bias = norm == ""
63
+ for idx, in_channels in enumerate(feature_channels):
64
+ if idx == len(self.in_features) - 1:
65
+ output_norm = get_norm(norm, conv_dim)
66
+ output_conv = Conv2d(
67
+ in_channels,
68
+ conv_dim,
69
+ kernel_size=3,
70
+ stride=1,
71
+ padding=1,
72
+ bias=use_bias,
73
+ norm=output_norm,
74
+ activation=F.relu,
75
+ )
76
+ weight_init.c2_xavier_fill(output_conv)
77
+ self.add_module("layer_{}".format(idx + 1), output_conv)
78
+
79
+ lateral_convs.append(None)
80
+ output_convs.append(output_conv)
81
+ else:
82
+ lateral_norm = get_norm(norm, conv_dim)
83
+ output_norm = get_norm(norm, conv_dim)
84
+
85
+ lateral_conv = Conv2d(
86
+ in_channels,
87
+ conv_dim,
88
+ kernel_size=1,
89
+ bias=use_bias,
90
+ norm=lateral_norm,
91
+ )
92
+ output_conv = Conv2d(
93
+ conv_dim,
94
+ conv_dim,
95
+ kernel_size=3,
96
+ stride=1,
97
+ padding=1,
98
+ bias=use_bias,
99
+ norm=output_norm,
100
+ activation=F.relu,
101
+ )
102
+ weight_init.c2_xavier_fill(lateral_conv)
103
+ weight_init.c2_xavier_fill(output_conv)
104
+ self.add_module("adapter_{}".format(idx + 1), lateral_conv)
105
+ self.add_module("layer_{}".format(idx + 1), output_conv)
106
+
107
+ lateral_convs.append(lateral_conv)
108
+ output_convs.append(output_conv)
109
+ # Place convs into top-down order (from low to high resolution)
110
+ # to make the top-down computation in forward clearer.
111
+ self.lateral_convs = lateral_convs[::-1]
112
+ self.output_convs = output_convs[::-1]
113
+
114
+ self.mask_dim = mask_dim
115
+ self.mask_features = Conv2d(
116
+ conv_dim,
117
+ mask_dim,
118
+ kernel_size=3,
119
+ stride=1,
120
+ padding=1,
121
+ )
122
+ weight_init.c2_xavier_fill(self.mask_features)
123
+
124
+ @classmethod
125
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
126
+ ret = {}
127
+ ret["input_shape"] = {
128
+ k: v
129
+ for k, v in input_shape.items()
130
+ if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
131
+ }
132
+ ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
133
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
134
+ ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
135
+ return ret
136
+
137
+ def forward_features(self, features):
138
+ # Reverse feature maps into top-down order (from low to high resolution)
139
+ for idx, f in enumerate(self.in_features[::-1]):
140
+ x = features[f]
141
+ lateral_conv = self.lateral_convs[idx]
142
+ output_conv = self.output_convs[idx]
143
+ if lateral_conv is None:
144
+ y = output_conv(x)
145
+ else:
146
+ cur_fpn = lateral_conv(x)
147
+ # Following FPN implementation, we use nearest upsampling here
148
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
149
+ y = output_conv(y)
150
+ return self.mask_features(y), None
151
+
152
+ def forward(self, features, targets=None):
153
+ logger = logging.getLogger(__name__)
154
+ logger.warning(
155
+ "Calling forward() may cause unpredictable behavior of PixelDecoder module."
156
+ )
157
+ return self.forward_features(features)
158
+
159
+
160
+ class TransformerEncoderOnly(nn.Module):
161
+ def __init__(
162
+ self,
163
+ d_model=512,
164
+ nhead=8,
165
+ num_encoder_layers=6,
166
+ dim_feedforward=2048,
167
+ dropout=0.1,
168
+ activation="relu",
169
+ normalize_before=False,
170
+ ):
171
+ super().__init__()
172
+
173
+ encoder_layer = TransformerEncoderLayer(
174
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
175
+ )
176
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
177
+ self.encoder = TransformerEncoder(
178
+ encoder_layer, num_encoder_layers, encoder_norm
179
+ )
180
+
181
+ self._reset_parameters()
182
+
183
+ self.d_model = d_model
184
+ self.nhead = nhead
185
+
186
+ def _reset_parameters(self):
187
+ for p in self.parameters():
188
+ if p.dim() > 1:
189
+ nn.init.xavier_uniform_(p)
190
+
191
+ def forward(self, src, mask, pos_embed):
192
+ # flatten NxCxHxW to HWxNxC
193
+ bs, c, h, w = src.shape
194
+ src = src.flatten(2).permute(2, 0, 1)
195
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
196
+ if mask is not None:
197
+ mask = mask.flatten(1)
198
+
199
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
200
+ return memory.permute(1, 2, 0).view(bs, c, h, w)
201
+
202
+
203
+ @SEM_SEG_HEADS_REGISTRY.register()
204
+ class TransformerEncoderPixelDecoder(BasePixelDecoder):
205
+ @configurable
206
+ def __init__(
207
+ self,
208
+ input_shape: Dict[str, ShapeSpec],
209
+ *,
210
+ transformer_dropout: float,
211
+ transformer_nheads: int,
212
+ transformer_dim_feedforward: int,
213
+ transformer_enc_layers: int,
214
+ transformer_pre_norm: bool,
215
+ conv_dim: int,
216
+ mask_dim: int,
217
+ norm: Optional[Union[str, Callable]] = None,
218
+ ):
219
+ """
220
+ NOTE: this interface is experimental.
221
+ Args:
222
+ input_shape: shapes (channels and stride) of the input features
223
+ transformer_dropout: dropout probability in transformer
224
+ transformer_nheads: number of heads in transformer
225
+ transformer_dim_feedforward: dimension of feedforward network
226
+ transformer_enc_layers: number of transformer encoder layers
227
+ transformer_pre_norm: whether to use pre-layernorm or not
228
+ conv_dim: number of output channels for the intermediate conv layers.
229
+ mask_dim: number of output channels for the final conv layer.
230
+ norm (str or callable): normalization for all conv layers
231
+ """
232
+ super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm)
233
+
234
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
235
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
236
+ feature_strides = [v.stride for k, v in input_shape]
237
+ feature_channels = [v.channels for k, v in input_shape]
238
+
239
+ in_channels = feature_channels[len(self.in_features) - 1]
240
+ self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
241
+ weight_init.c2_xavier_fill(self.input_proj)
242
+ self.transformer = TransformerEncoderOnly(
243
+ d_model=conv_dim,
244
+ dropout=transformer_dropout,
245
+ nhead=transformer_nheads,
246
+ dim_feedforward=transformer_dim_feedforward,
247
+ num_encoder_layers=transformer_enc_layers,
248
+ normalize_before=transformer_pre_norm,
249
+ )
250
+ N_steps = conv_dim // 2
251
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
252
+
253
+ # update layer
254
+ use_bias = norm == ""
255
+ output_norm = get_norm(norm, conv_dim)
256
+ output_conv = Conv2d(
257
+ conv_dim,
258
+ conv_dim,
259
+ kernel_size=3,
260
+ stride=1,
261
+ padding=1,
262
+ bias=use_bias,
263
+ norm=output_norm,
264
+ activation=F.relu,
265
+ )
266
+ weight_init.c2_xavier_fill(output_conv)
267
+ delattr(self, "layer_{}".format(len(self.in_features)))
268
+ self.add_module("layer_{}".format(len(self.in_features)), output_conv)
269
+ self.output_convs[0] = output_conv
270
+
271
+ @classmethod
272
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
273
+ ret = super().from_config(cfg, input_shape)
274
+ ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
275
+ ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
276
+ ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
277
+ ret[
278
+ "transformer_enc_layers"
279
+ ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config
280
+ ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
281
+ return ret
282
+
283
+ def forward_features(self, features):
284
+ # Reverse feature maps into top-down order (from low to high resolution)
285
+ for idx, f in enumerate(self.in_features[::-1]):
286
+ x = features[f]
287
+ lateral_conv = self.lateral_convs[idx]
288
+ output_conv = self.output_convs[idx]
289
+ if lateral_conv is None:
290
+ transformer = self.input_proj(x)
291
+ pos = self.pe_layer(x)
292
+ transformer = self.transformer(transformer, None, pos)
293
+ y = output_conv(transformer)
294
+ # save intermediate feature as input to Transformer decoder
295
+ transformer_encoder_features = transformer
296
+ else:
297
+ cur_fpn = lateral_conv(x)
298
+ # Following FPN implementation, we use nearest upsampling here
299
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
300
+ y = output_conv(y)
301
+ return self.mask_features(y), transformer_encoder_features
302
+
303
+ def forward(self, features, targets=None):
304
+ logger = logging.getLogger(__name__)
305
+ logger.warning(
306
+ "Calling forward() may cause unpredictable behavior of PixelDecoder module."
307
+ )
308
+ return self.forward_features(features)
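BasePixelDecoder fuses the backbone features top-down, FPN-style, and returns mask features at the stride-4 resolution. A shape-only sketch with a hypothetical channel/stride layout (requires detectron2 plus the package added by this commit):

import torch
from detectron2.layers import ShapeSpec
from open_vocab_seg.modeling.heads.pixel_decoder import BasePixelDecoder

shapes = {
    "res2": ShapeSpec(channels=96, stride=4),
    "res3": ShapeSpec(channels=192, stride=8),
    "res4": ShapeSpec(channels=384, stride=16),
    "res5": ShapeSpec(channels=768, stride=32),
}
decoder = BasePixelDecoder(shapes, conv_dim=256, mask_dim=256, norm="GN")
features = {
    k: torch.rand(1, v.channels, 256 // v.stride, 256 // v.stride)
    for k, v in shapes.items()
}
mask_features, _ = decoder.forward_features(features)
print(mask_features.shape)  # torch.Size([1, 256, 64, 64]) -- stride-4 resolution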
open_vocab_seg/modeling/matcher.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
3
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
4
+
5
+ """
6
+ Modules to compute the matching cost and solve the corresponding LSAP.
7
+ """
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from scipy.optimize import linear_sum_assignment
11
+ from torch import nn
12
+
13
+
14
+ def batch_dice_loss(inputs, targets):
15
+ """
16
+ Compute the DICE loss, similar to generalized IOU for masks
17
+ Args:
18
+ inputs: A float tensor of arbitrary shape.
19
+ The predictions for each example.
20
+ targets: A float tensor with the same shape as inputs. Stores the binary
21
+ classification label for each element in inputs
22
+ (0 for the negative class and 1 for the positive class).
23
+ """
24
+ inputs = inputs.sigmoid()
25
+ inputs = inputs.flatten(1)
26
+ numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
27
+ denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
28
+ loss = 1 - (numerator + 1) / (denominator + 1)
29
+ return loss
30
+
31
+
32
+ def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
33
+ """
34
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
35
+ Args:
36
+ inputs: A float tensor of arbitrary shape.
37
+ The predictions for each example.
38
+ targets: A float tensor with the same shape as inputs. Stores the binary
39
+ classification label for each element in inputs
40
+ (0 for the negative class and 1 for the positive class).
41
+ alpha: (optional) Weighting factor in range (0,1) to balance
42
+ positive vs negative examples. Default = 0.25.
43
+ gamma: Exponent of the modulating factor (1 - p_t) to
44
+ balance easy vs hard examples.
45
+ Returns:
46
+ Loss tensor
47
+ """
48
+ hw = inputs.shape[1]
49
+
50
+ prob = inputs.sigmoid()
51
+ focal_pos = ((1 - prob) ** gamma) * F.binary_cross_entropy_with_logits(
52
+ inputs, torch.ones_like(inputs), reduction="none"
53
+ )
54
+ focal_neg = (prob ** gamma) * F.binary_cross_entropy_with_logits(
55
+ inputs, torch.zeros_like(inputs), reduction="none"
56
+ )
57
+ if alpha >= 0:
58
+ focal_pos = focal_pos * alpha
59
+ focal_neg = focal_neg * (1 - alpha)
60
+
61
+ loss = torch.einsum("nc,mc->nm", focal_pos, targets) + torch.einsum(
62
+ "nc,mc->nm", focal_neg, (1 - targets)
63
+ )
64
+
65
+ return loss / hw
66
+
67
+
68
+ class HungarianMatcher(nn.Module):
69
+ """This class computes an assignment between the targets and the predictions of the network
70
+
71
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
72
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
73
+ while the others are un-matched (and thus treated as non-objects).
74
+ """
75
+
76
+ def __init__(
77
+ self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1
78
+ ):
79
+ """Creates the matcher
80
+
81
+ Params:
82
+ cost_class: This is the relative weight of the classification error in the matching cost
83
+ cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
84
+ cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
85
+ """
86
+ super().__init__()
87
+ self.cost_class = cost_class
88
+ self.cost_mask = cost_mask
89
+ self.cost_dice = cost_dice
90
+ assert (
91
+ cost_class != 0 or cost_mask != 0 or cost_dice != 0
92
+ ), "all costs can't be 0"
93
+
94
+ @torch.no_grad()
95
+ def memory_efficient_forward(self, outputs, targets):
96
+ """More memory-friendly matching"""
97
+ bs, num_queries = outputs["pred_logits"].shape[:2]
98
+
99
+ # Work out the mask padding size
100
+ masks = [v["masks"] for v in targets]
101
+ h_max = max([m.shape[1] for m in masks])
102
+ w_max = max([m.shape[2] for m in masks])
103
+
104
+ indices = []
105
+
106
+ # Iterate through batch size
107
+ for b in range(bs):
108
+
109
+ out_prob = outputs["pred_logits"][b].softmax(
110
+ -1
111
+ ) # [num_queries, num_classes]
112
+ out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
113
+
114
+ tgt_ids = targets[b]["labels"]
115
+ # gt masks are already padded when preparing target
116
+ tgt_mask = targets[b]["masks"].to(out_mask)
117
+
118
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
119
+ # but approximate it in 1 - proba[target class].
120
+ # The 1 is a constant that doesn't change the matching, so it can be omitted.
121
+ cost_class = -out_prob[:, tgt_ids]
122
+
123
+ # Downsample gt masks to save memory
124
+ tgt_mask = F.interpolate(
125
+ tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest"
126
+ )
127
+
128
+ # Flatten spatial dimension
129
+ out_mask = out_mask.flatten(1) # [num_queries, H*W]
130
+ tgt_mask = tgt_mask[:, 0].flatten(1) # [num_targets, H*W]
131
+
132
+ # Compute the focal loss between masks
133
+ cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask)
134
+
135
+ # Compute the dice loss between masks
136
+ cost_dice = batch_dice_loss(out_mask, tgt_mask)
137
+
138
+ # Final cost matrix
139
+ C = (
140
+ self.cost_mask * cost_mask
141
+ + self.cost_class * cost_class
142
+ + self.cost_dice * cost_dice
143
+ )
144
+ C = C.reshape(num_queries, -1).cpu()
145
+
146
+ indices.append(linear_sum_assignment(C))
147
+ return [
148
+ (
149
+ torch.as_tensor(i, dtype=torch.int64),
150
+ torch.as_tensor(j, dtype=torch.int64),
151
+ )
152
+ for i, j in indices
153
+ ]
154
+
155
+ @torch.no_grad()
156
+ def forward(self, outputs, targets):
157
+ """Performs the matching
158
+
159
+ Params:
160
+ outputs: This is a dict that contains at least these entries:
161
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
162
+ "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
163
+
164
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
165
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
166
+ objects in the target) containing the class labels
167
+ "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
168
+
169
+ Returns:
170
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
171
+ - index_i is the indices of the selected predictions (in order)
172
+ - index_j is the indices of the corresponding selected targets (in order)
173
+ For each batch element, it holds:
174
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
175
+ """
176
+ return self.memory_efficient_forward(outputs, targets)
177
+
178
+ def __repr__(self):
179
+ head = "Matcher " + self.__class__.__name__
180
+ body = [
181
+ "cost_class: {}".format(self.cost_class),
182
+ "cost_mask: {}".format(self.cost_mask),
183
+ "cost_dice: {}".format(self.cost_dice),
184
+ ]
185
+ _repr_indent = 4
186
+ lines = [head] + [" " * _repr_indent + line for line in body]
187
+ return "\n".join(lines)
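Each per-image cost matrix C assembled above is handed to scipy's linear_sum_assignment, which returns the one-to-one pairing of queries and targets with the lowest total cost; unmatched queries are later supervised as no-object by the criterion. A toy illustration with a hypothetical 3-query, 2-target cost matrix:

import torch
from scipy.optimize import linear_sum_assignment

C = torch.tensor([[0.9, 0.1],
                  [0.2, 0.8],
                  [0.5, 0.6]])  # lower cost = better match
rows, cols = linear_sum_assignment(C.numpy())
print([(int(r), int(c)) for r, c in zip(rows, cols)])  # [(0, 1), (1, 0)]; query 2 stays unmatched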
open_vocab_seg/modeling/transformer/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
open_vocab_seg/modeling/transformer/open_vocab_transformer_predictor.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
4
+
5
+ from torch import nn
6
+ from detectron2.config import configurable
7
+ from .transformer_predictor import TransformerPredictor, MLP
8
+
9
+
10
+ class OpenVocabTransformerPredictor(TransformerPredictor):
11
+ @configurable
12
+ def __init__(
13
+ self,
14
+ in_channels,
15
+ mask_classification=True,
16
+ *,
17
+ embedding_dim: int,
18
+ embed_hidden_dim: int,
19
+ embed_layers: int,
20
+ hidden_dim: int,
21
+ num_queries: int,
22
+ nheads: int,
23
+ dropout: float,
24
+ dim_feedforward: int,
25
+ enc_layers: int,
26
+ dec_layers: int,
27
+ pre_norm: bool,
28
+ deep_supervision: bool,
29
+ mask_dim: int,
30
+ enforce_input_project: bool,
31
+ ):
32
+ super().__init__(
33
+ in_channels,
34
+ False,
35
+ num_classes=embedding_dim,
36
+ hidden_dim=hidden_dim,
37
+ num_queries=num_queries,
38
+ nheads=nheads,
39
+ dropout=dropout,
40
+ dim_feedforward=dim_feedforward,
41
+ enc_layers=enc_layers,
42
+ dec_layers=dec_layers,
43
+ pre_norm=pre_norm,
44
+ deep_supervision=deep_supervision,
45
+ mask_dim=mask_dim,
46
+ enforce_input_project=enforce_input_project,
47
+ )
48
+ self.mask_classification = mask_classification
49
+ # output FFNs
50
+ if self.mask_classification:
51
+ self.class_embed = MLP(
52
+ hidden_dim, embed_hidden_dim, embedding_dim, embed_layers
53
+ )
54
+
55
+ def freeze_pretrained(self):
56
+ for name, module in self.named_children():
57
+ if name not in ["class_embed"]:
58
+ for param in module.parameters():
59
+ param.requires_grad = False
60
+
61
+ @classmethod
62
+ def from_config(cls, cfg, in_channels, mask_classification):
63
+ ret = {}
64
+ ret["in_channels"] = in_channels
65
+ ret["mask_classification"] = mask_classification
66
+
67
+ ret["embedding_dim"] = cfg.MODEL.SEM_SEG_HEAD.EMBEDDING_DIM
68
+ ret["embed_hidden_dim"] = cfg.MODEL.SEM_SEG_HEAD.EMBED_HIDDEN_DIM
69
+ ret["embed_layers"] = cfg.MODEL.SEM_SEG_HEAD.EMBED_LAYERS
70
+ ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
71
+ ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
72
+ # Transformer parameters:
73
+ ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
74
+ ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
75
+ ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
76
+ ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
77
+ ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
78
+ ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
79
+ ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
80
+ ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
81
+
82
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
83
+
84
+ return ret
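Relative to the base TransformerPredictor, class_embed here is an MLP that projects each query into an EMBEDDING_DIM-dimensional space (compared against CLIP text embeddings elsewhere in this package) instead of producing fixed per-class logits. A shape-only sketch with hypothetical dimensions:

import torch
from torch import nn

# Hypothetical dims: 256-d queries projected into a 512-d CLIP-style embedding space.
class_embed = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 512))
queries = torch.rand(2, 100, 256)       # [batch, num_queries, hidden_dim]
query_embeddings = class_embed(queries)
print(query_embeddings.shape)           # torch.Size([2, 100, 512])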
open_vocab_seg/modeling/transformer/position_encoding.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
3
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
4
+
5
+ """
6
+ Various positional encodings for the transformer.
7
+ """
8
+ import math
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+
14
+ class PositionEmbeddingSine(nn.Module):
15
+ """
16
+ This is a more standard version of the position embedding, very similar to the one
17
+ used by the Attention is all you need paper, generalized to work on images.
18
+ """
19
+
20
+ def __init__(
21
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
22
+ ):
23
+ super().__init__()
24
+ self.num_pos_feats = num_pos_feats
25
+ self.temperature = temperature
26
+ self.normalize = normalize
27
+ if scale is not None and normalize is False:
28
+ raise ValueError("normalize should be True if scale is passed")
29
+ if scale is None:
30
+ scale = 2 * math.pi
31
+ self.scale = scale
32
+
33
+ def forward(self, x, mask=None):
34
+ if mask is None:
35
+ mask = torch.zeros(
36
+ (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
37
+ )
38
+ not_mask = ~mask
39
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
40
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
41
+ if self.normalize:
42
+ eps = 1e-6
43
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
44
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
45
+
46
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
47
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
48
+
49
+ pos_x = x_embed[:, :, :, None] / dim_t
50
+ pos_y = y_embed[:, :, :, None] / dim_t
51
+ pos_x = torch.stack(
52
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
53
+ ).flatten(3)
54
+ pos_y = torch.stack(
55
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
56
+ ).flatten(3)
57
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
58
+ return pos
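PositionEmbeddingSine produces a dense N x (2 * num_pos_feats) x H x W encoding that is added to the flattened features inside the transformer. A quick shape check, assuming the package added by this commit is importable:

import torch
from open_vocab_seg.modeling.transformer.position_encoding import PositionEmbeddingSine

pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
feat = torch.rand(2, 256, 32, 32)  # N x C x H x W feature map
pos = pe(feat)
print(pos.shape)  # torch.Size([2, 256, 32, 32]); channel dim = 2 * num_pos_feats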
open_vocab_seg/modeling/transformer/transformer.py ADDED
@@ -0,0 +1,380 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
3
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
4
+
5
+ """
6
+ Transformer class.
7
+
8
+ Copy-paste from torch.nn.Transformer with modifications:
9
+ * positional encodings are passed in MHattention
10
+ * extra LN at the end of encoder is removed
11
+ * decoder returns a stack of activations from all decoding layers
12
+ """
13
+ import copy
14
+ from typing import List, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import Tensor, nn
19
+
20
+
21
+ class Transformer(nn.Module):
22
+ def __init__(
23
+ self,
24
+ d_model=512,
25
+ nhead=8,
26
+ num_encoder_layers=6,
27
+ num_decoder_layers=6,
28
+ dim_feedforward=2048,
29
+ dropout=0.1,
30
+ activation="relu",
31
+ normalize_before=False,
32
+ return_intermediate_dec=False,
33
+ ):
34
+ super().__init__()
35
+
36
+ encoder_layer = TransformerEncoderLayer(
37
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
38
+ )
39
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
40
+ self.encoder = TransformerEncoder(
41
+ encoder_layer, num_encoder_layers, encoder_norm
42
+ )
43
+
44
+ decoder_layer = TransformerDecoderLayer(
45
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
46
+ )
47
+ decoder_norm = nn.LayerNorm(d_model)
48
+ self.decoder = TransformerDecoder(
49
+ decoder_layer,
50
+ num_decoder_layers,
51
+ decoder_norm,
52
+ return_intermediate=return_intermediate_dec,
53
+ )
54
+
55
+ self._reset_parameters()
56
+
57
+ self.d_model = d_model
58
+ self.nhead = nhead
59
+
60
+ def _reset_parameters(self):
61
+ for p in self.parameters():
62
+ if p.dim() > 1:
63
+ nn.init.xavier_uniform_(p)
64
+
65
+ def forward(self, src, mask, query_embed, pos_embed):
66
+ # flatten NxCxHxW to HWxNxC
67
+ bs, c, h, w = src.shape
68
+ src = src.flatten(2).permute(2, 0, 1)
69
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
70
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
71
+ if mask is not None:
72
+ mask = mask.flatten(1)
73
+
74
+ tgt = torch.zeros_like(query_embed)
75
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
76
+ hs = self.decoder(
77
+ tgt,
78
+ memory,
79
+ memory_key_padding_mask=mask,
80
+ pos=pos_embed,
81
+ query_pos=query_embed,
82
+ )
83
+ return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
84
+
85
+
86
+ class TransformerEncoder(nn.Module):
87
+ def __init__(self, encoder_layer, num_layers, norm=None):
88
+ super().__init__()
89
+ self.layers = _get_clones(encoder_layer, num_layers)
90
+ self.num_layers = num_layers
91
+ self.norm = norm
92
+
93
+ def forward(
94
+ self,
95
+ src,
96
+ mask: Optional[Tensor] = None,
97
+ src_key_padding_mask: Optional[Tensor] = None,
98
+ pos: Optional[Tensor] = None,
99
+ ):
100
+ output = src
101
+
102
+ for layer in self.layers:
103
+ output = layer(
104
+ output,
105
+ src_mask=mask,
106
+ src_key_padding_mask=src_key_padding_mask,
107
+ pos=pos,
108
+ )
109
+
110
+ if self.norm is not None:
111
+ output = self.norm(output)
112
+
113
+ return output
114
+
115
+
116
+ class TransformerDecoder(nn.Module):
117
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
118
+ super().__init__()
119
+ self.layers = _get_clones(decoder_layer, num_layers)
120
+ self.num_layers = num_layers
121
+ self.norm = norm
122
+ self.return_intermediate = return_intermediate
123
+
124
+ def forward(
125
+ self,
126
+ tgt,
127
+ memory,
128
+ tgt_mask: Optional[Tensor] = None,
129
+ memory_mask: Optional[Tensor] = None,
130
+ tgt_key_padding_mask: Optional[Tensor] = None,
131
+ memory_key_padding_mask: Optional[Tensor] = None,
132
+ pos: Optional[Tensor] = None,
133
+ query_pos: Optional[Tensor] = None,
134
+ ):
135
+ output = tgt
136
+
137
+ intermediate = []
138
+
139
+ for layer in self.layers:
140
+ output = layer(
141
+ output,
142
+ memory,
143
+ tgt_mask=tgt_mask,
144
+ memory_mask=memory_mask,
145
+ tgt_key_padding_mask=tgt_key_padding_mask,
146
+ memory_key_padding_mask=memory_key_padding_mask,
147
+ pos=pos,
148
+ query_pos=query_pos,
149
+ )
150
+ if self.return_intermediate:
151
+ intermediate.append(self.norm(output))
152
+
153
+ if self.norm is not None:
154
+ output = self.norm(output)
155
+ if self.return_intermediate:
156
+ intermediate.pop()
157
+ intermediate.append(output)
158
+
159
+ if self.return_intermediate:
160
+ return torch.stack(intermediate)
161
+
162
+ return output.unsqueeze(0)
163
+
164
+
165
+ class TransformerEncoderLayer(nn.Module):
166
+ def __init__(
167
+ self,
168
+ d_model,
169
+ nhead,
170
+ dim_feedforward=2048,
171
+ dropout=0.1,
172
+ activation="relu",
173
+ normalize_before=False,
174
+ ):
175
+ super().__init__()
176
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
177
+ # Implementation of Feedforward model
178
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
179
+ self.dropout = nn.Dropout(dropout)
180
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
181
+
182
+ self.norm1 = nn.LayerNorm(d_model)
183
+ self.norm2 = nn.LayerNorm(d_model)
184
+ self.dropout1 = nn.Dropout(dropout)
185
+ self.dropout2 = nn.Dropout(dropout)
186
+
187
+ self.activation = _get_activation_fn(activation)
188
+ self.normalize_before = normalize_before
189
+
190
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
191
+ return tensor if pos is None else tensor + pos
192
+
193
+ def forward_post(
194
+ self,
195
+ src,
196
+ src_mask: Optional[Tensor] = None,
197
+ src_key_padding_mask: Optional[Tensor] = None,
198
+ pos: Optional[Tensor] = None,
199
+ ):
200
+ q = k = self.with_pos_embed(src, pos)
201
+ src2 = self.self_attn(
202
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
203
+ )[0]
204
+ src = src + self.dropout1(src2)
205
+ src = self.norm1(src)
206
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
207
+ src = src + self.dropout2(src2)
208
+ src = self.norm2(src)
209
+ return src
210
+
211
+ def forward_pre(
212
+ self,
213
+ src,
214
+ src_mask: Optional[Tensor] = None,
215
+ src_key_padding_mask: Optional[Tensor] = None,
216
+ pos: Optional[Tensor] = None,
217
+ ):
218
+ src2 = self.norm1(src)
219
+ q = k = self.with_pos_embed(src2, pos)
220
+ src2 = self.self_attn(
221
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
222
+ )[0]
223
+ src = src + self.dropout1(src2)
224
+ src2 = self.norm2(src)
225
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
226
+ src = src + self.dropout2(src2)
227
+ return src
228
+
229
+ def forward(
230
+ self,
231
+ src,
232
+ src_mask: Optional[Tensor] = None,
233
+ src_key_padding_mask: Optional[Tensor] = None,
234
+ pos: Optional[Tensor] = None,
235
+ ):
236
+ if self.normalize_before:
237
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
238
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
239
+
240
+
241
+ class TransformerDecoderLayer(nn.Module):
242
+ def __init__(
243
+ self,
244
+ d_model,
245
+ nhead,
246
+ dim_feedforward=2048,
247
+ dropout=0.1,
248
+ activation="relu",
249
+ normalize_before=False,
250
+ ):
251
+ super().__init__()
252
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
253
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
254
+ # Implementation of Feedforward model
255
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
256
+ self.dropout = nn.Dropout(dropout)
257
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
258
+
259
+ self.norm1 = nn.LayerNorm(d_model)
260
+ self.norm2 = nn.LayerNorm(d_model)
261
+ self.norm3 = nn.LayerNorm(d_model)
262
+ self.dropout1 = nn.Dropout(dropout)
263
+ self.dropout2 = nn.Dropout(dropout)
264
+ self.dropout3 = nn.Dropout(dropout)
265
+
266
+ self.activation = _get_activation_fn(activation)
267
+ self.normalize_before = normalize_before
268
+
269
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
270
+ return tensor if pos is None else tensor + pos
271
+
272
+ def forward_post(
273
+ self,
274
+ tgt,
275
+ memory,
276
+ tgt_mask: Optional[Tensor] = None,
277
+ memory_mask: Optional[Tensor] = None,
278
+ tgt_key_padding_mask: Optional[Tensor] = None,
279
+ memory_key_padding_mask: Optional[Tensor] = None,
280
+ pos: Optional[Tensor] = None,
281
+ query_pos: Optional[Tensor] = None,
282
+ ):
283
+ q = k = self.with_pos_embed(tgt, query_pos)
284
+ tgt2 = self.self_attn(
285
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
286
+ )[0]
287
+ tgt = tgt + self.dropout1(tgt2)
288
+ tgt = self.norm1(tgt)
289
+ tgt2 = self.multihead_attn(
290
+ query=self.with_pos_embed(tgt, query_pos),
291
+ key=self.with_pos_embed(memory, pos),
292
+ value=memory,
293
+ attn_mask=memory_mask,
294
+ key_padding_mask=memory_key_padding_mask,
295
+ )[0]
296
+ tgt = tgt + self.dropout2(tgt2)
297
+ tgt = self.norm2(tgt)
298
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
299
+ tgt = tgt + self.dropout3(tgt2)
300
+ tgt = self.norm3(tgt)
301
+ return tgt
302
+
303
+ def forward_pre(
304
+ self,
305
+ tgt,
306
+ memory,
307
+ tgt_mask: Optional[Tensor] = None,
308
+ memory_mask: Optional[Tensor] = None,
309
+ tgt_key_padding_mask: Optional[Tensor] = None,
310
+ memory_key_padding_mask: Optional[Tensor] = None,
311
+ pos: Optional[Tensor] = None,
312
+ query_pos: Optional[Tensor] = None,
313
+ ):
314
+ tgt2 = self.norm1(tgt)
315
+ q = k = self.with_pos_embed(tgt2, query_pos)
316
+ tgt2 = self.self_attn(
317
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
318
+ )[0]
319
+ tgt = tgt + self.dropout1(tgt2)
320
+ tgt2 = self.norm2(tgt)
321
+ tgt2 = self.multihead_attn(
322
+ query=self.with_pos_embed(tgt2, query_pos),
323
+ key=self.with_pos_embed(memory, pos),
324
+ value=memory,
325
+ attn_mask=memory_mask,
326
+ key_padding_mask=memory_key_padding_mask,
327
+ )[0]
328
+ tgt = tgt + self.dropout2(tgt2)
329
+ tgt2 = self.norm3(tgt)
330
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
331
+ tgt = tgt + self.dropout3(tgt2)
332
+ return tgt
333
+
334
+ def forward(
335
+ self,
336
+ tgt,
337
+ memory,
338
+ tgt_mask: Optional[Tensor] = None,
339
+ memory_mask: Optional[Tensor] = None,
340
+ tgt_key_padding_mask: Optional[Tensor] = None,
341
+ memory_key_padding_mask: Optional[Tensor] = None,
342
+ pos: Optional[Tensor] = None,
343
+ query_pos: Optional[Tensor] = None,
344
+ ):
345
+ if self.normalize_before:
346
+ return self.forward_pre(
347
+ tgt,
348
+ memory,
349
+ tgt_mask,
350
+ memory_mask,
351
+ tgt_key_padding_mask,
352
+ memory_key_padding_mask,
353
+ pos,
354
+ query_pos,
355
+ )
356
+ return self.forward_post(
357
+ tgt,
358
+ memory,
359
+ tgt_mask,
360
+ memory_mask,
361
+ tgt_key_padding_mask,
362
+ memory_key_padding_mask,
363
+ pos,
364
+ query_pos,
365
+ )
366
+
367
+
368
+ def _get_clones(module, N):
369
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
370
+
371
+
372
+ def _get_activation_fn(activation):
373
+ """Return an activation function given a string"""
374
+ if activation == "relu":
375
+ return F.relu
376
+ if activation == "gelu":
377
+ return F.gelu
378
+ if activation == "glu":
379
+ return F.glu
380
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
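A quick end-to-end shape check for the Transformer above (small layer counts and hypothetical sizes, assuming the package added by this commit is importable): with return_intermediate_dec=True the decoder returns one activation stack per layer, which is what feeds the deep-supervision path in the predictor.

import torch
from torch import nn
from open_vocab_seg.modeling.transformer.transformer import Transformer

t = Transformer(d_model=256, nhead=8, num_encoder_layers=1,
                num_decoder_layers=2, return_intermediate_dec=True)
src = torch.rand(2, 256, 16, 16)             # N x C x H x W memory features
pos = torch.rand(2, 256, 16, 16)             # matching positional encodings
query_embed = nn.Embedding(100, 256).weight  # 100 learned object queries
hs, memory = t(src, None, query_embed, pos)
print(hs.shape)      # torch.Size([2, 2, 100, 256]) -- one slice per decoder layer
print(memory.shape)  # torch.Size([2, 256, 16, 16])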
open_vocab_seg/modeling/transformer/transformer_predictor.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
4
+
5
+ import fvcore.nn.weight_init as weight_init
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.layers import Conv2d
12
+
13
+ from .position_encoding import PositionEmbeddingSine
14
+ from .transformer import Transformer
15
+
16
+
17
+ class TransformerPredictor(nn.Module):
18
+ @configurable
19
+ def __init__(
20
+ self,
21
+ in_channels,
22
+ mask_classification=True,
23
+ *,
24
+ num_classes: int,
25
+ hidden_dim: int,
26
+ num_queries: int,
27
+ nheads: int,
28
+ dropout: float,
29
+ dim_feedforward: int,
30
+ enc_layers: int,
31
+ dec_layers: int,
32
+ pre_norm: bool,
33
+ deep_supervision: bool,
34
+ mask_dim: int,
35
+ enforce_input_project: bool,
36
+ ):
37
+ """
38
+ NOTE: this interface is experimental.
39
+ Args:
40
+ in_channels: channels of the input features
41
+ mask_classification: whether to add mask classifier or not
42
+ num_classes: number of classes
43
+ hidden_dim: Transformer feature dimension
44
+ num_queries: number of queries
45
+ nheads: number of heads
46
+ dropout: dropout in Transformer
47
+ dim_feedforward: feature dimension in feedforward network
48
+ enc_layers: number of Transformer encoder layers
49
+ dec_layers: number of Transformer decoder layers
50
+ pre_norm: whether to use pre-LayerNorm or not
51
+ deep_supervision: whether to add supervision to every decoder layer
52
+ mask_dim: mask feature dimension
53
+ enforce_input_project: add input project 1x1 conv even if input
54
+ channels and hidden dim are identical
55
+ """
56
+ super().__init__()
57
+
58
+ self.mask_classification = mask_classification
59
+
60
+ # positional encoding
61
+ N_steps = hidden_dim // 2
62
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
63
+
64
+ transformer = Transformer(
65
+ d_model=hidden_dim,
66
+ dropout=dropout,
67
+ nhead=nheads,
68
+ dim_feedforward=dim_feedforward,
69
+ num_encoder_layers=enc_layers,
70
+ num_decoder_layers=dec_layers,
71
+ normalize_before=pre_norm,
72
+ return_intermediate_dec=deep_supervision,
73
+ )
74
+
75
+ self.num_queries = num_queries
76
+ self.transformer = transformer
77
+ hidden_dim = transformer.d_model
78
+
79
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
80
+
81
+ if in_channels != hidden_dim or enforce_input_project:
82
+ self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1)
83
+ weight_init.c2_xavier_fill(self.input_proj)
84
+ else:
85
+ self.input_proj = nn.Sequential()
86
+ self.aux_loss = deep_supervision
87
+
88
+ # output FFNs
89
+ if self.mask_classification:
90
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
91
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
92
+
93
+ @classmethod
94
+ def from_config(cls, cfg, in_channels, mask_classification):
95
+ ret = {}
96
+ ret["in_channels"] = in_channels
97
+ ret["mask_classification"] = mask_classification
98
+
99
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
100
+ ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
101
+ ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
102
+ # Transformer parameters:
103
+ ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
104
+ ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
105
+ ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
106
+ ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
107
+ ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
108
+ ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
109
+ ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
110
+ ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
111
+
112
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
113
+
114
+ return ret
115
+
116
+ def forward(self, x, mask_features):
117
+ pos = self.pe_layer(x)
118
+
119
+ src = x
120
+ mask = None
121
+ hs, memory = self.transformer(
122
+ self.input_proj(src), mask, self.query_embed.weight, pos
123
+ )
124
+
125
+ if self.mask_classification:
126
+ outputs_class = self.class_embed(hs)
127
+ out = {"pred_logits": outputs_class[-1]}
128
+ else:
129
+ out = {}
130
+
131
+ if self.aux_loss:
132
+ # [l, bs, queries, embed]
133
+ mask_embed = self.mask_embed(hs)
134
+ outputs_seg_masks = torch.einsum(
135
+ "lbqc,bchw->lbqhw", mask_embed, mask_features
136
+ )
137
+ out["pred_masks"] = outputs_seg_masks[-1]
138
+ out["aux_outputs"] = self._set_aux_loss(
139
+ outputs_class if self.mask_classification else None, outputs_seg_masks
140
+ )
141
+ else:
142
+ # FIXME h_boxes takes the last one computed, keep this in mind
143
+ # [bs, queries, embed]
144
+ mask_embed = self.mask_embed(hs[-1])
145
+ outputs_seg_masks = torch.einsum(
146
+ "bqc,bchw->bqhw", mask_embed, mask_features
147
+ )
148
+ out["pred_masks"] = outputs_seg_masks
149
+ return out
150
+
151
+ @torch.jit.unused
152
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks):
153
+ # this is a workaround to make torchscript happy, as torchscript
154
+ # doesn't support dictionary with non-homogeneous values, such
155
+ # as a dict having both a Tensor and a list.
156
+ if self.mask_classification:
157
+ return [
158
+ {"pred_logits": a, "pred_masks": b}
159
+ for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
160
+ ]
161
+ else:
162
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
163
+
164
+
165
+ class MLP(nn.Module):
166
+ """Very simple multi-layer perceptron (also called FFN)"""
167
+
168
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
169
+ super().__init__()
170
+ self.num_layers = num_layers
171
+ h = [hidden_dim] * (num_layers - 1)
172
+ self.layers = nn.ModuleList(
173
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
174
+ )
175
+
176
+ def forward(self, x):
177
+ for i, layer in enumerate(self.layers):
178
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
179
+ return x
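Editor's aside: the einsum in TransformerPredictor.forward is the core of the mask prediction; the shape check below is a minimal sketch with random tensors (all sizes are arbitrary, not taken from a config):

    import torch

    L, B, Q, C, H, W = 6, 2, 100, 256, 32, 32   # decoder layers, batch, queries, channels, mask height/width
    mask_embed = torch.randn(L, B, Q, C)        # MLP output applied to the decoder states hs
    mask_features = torch.randn(B, C, H, W)     # per-pixel features from the pixel decoder

    # each query embedding is dotted with every pixel feature -> one mask logit map per query
    seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
    assert seg_masks.shape == (L, B, Q, H, W)
    pred_masks = seg_masks[-1]                  # last decoder layer becomes "pred_masks"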
open_vocab_seg/ovseg_model.py ADDED
@@ -0,0 +1,460 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+ # Modified by Feng Liang from
4
+ # https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/zero_shot_mask_former_model.py
5
+
6
+ import logging
7
+ from typing import Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+
14
+ from detectron2.config import configurable
15
+ from detectron2.data import MetadataCatalog
16
+ from detectron2.modeling import META_ARCH_REGISTRY
17
+ from detectron2.modeling.backbone import Backbone
18
+ from detectron2.modeling.postprocessing import sem_seg_postprocess
19
+ from detectron2.structures import ImageList
20
+ from detectron2.utils.logger import log_first_n
21
+ from .modeling.clip_adapter import (
22
+ ClipAdapter,
23
+ MaskFormerClipAdapter,
24
+ build_text_prompt,
25
+ )
26
+ from .mask_former_model import MaskFormer
27
+ from .utils.misc import get_gt_binary_masks
28
+
29
+ @META_ARCH_REGISTRY.register()
30
+ class OVSeg(MaskFormer):
31
+ """
32
+ Main class for zero-shot mask classification semantic segmentation architectures.
33
+ """
34
+
35
+ @configurable
36
+ def __init__(
37
+ self,
38
+ *,
39
+ backbone: Backbone,
40
+ sem_seg_head: nn.Module,
41
+ clip_adapter: nn.Module,
42
+ criterion: nn.Module,
43
+ num_queries: int,
44
+ panoptic_on: bool,
45
+ object_mask_threshold: float,
46
+ overlap_threshold: float,
47
+ metadata,
48
+ size_divisibility: int,
49
+ sem_seg_postprocess_before_inference: bool,
50
+ clip_ensemble: bool,
51
+ clip_ensemble_weight: float,
52
+ pixel_mean: Tuple[float],
53
+ pixel_std: Tuple[float],
54
+ ):
55
+ """
56
+ Args:
57
+ backbone: a backbone module, must follow detectron2's backbone interface
58
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
59
+ criterion: a module that defines the loss
60
+ clip_adapter: adapter for clip-based mask classification
61
+ num_queries: int, number of queries
62
+ panoptic_on: bool, whether to output panoptic segmentation prediction
63
+ object_mask_threshold: float, threshold to filter queries based on classification score
64
+ for panoptic segmentation inference
65
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
66
+ metadata: dataset metadata, used to get `thing` and `stuff` category names for panoptic
67
+ segmentation inference
68
+ size_divisibility: Some backbones require the input height and width to be divisible by a
69
+ specific integer. We can use this to override such a requirement.
70
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
71
+ to original input size before semantic segmentation inference or after.
72
+ For high-resolution datasets like Mapillary, resizing predictions before
73
+ inference will cause an OOM error.
74
+ pixel_mean, pixel_std: list or tuple with #channels elements, representing
75
+ the per-channel mean and std to be used to normalize the input image
76
+ """
77
+ super().__init__(
78
+ backbone=backbone,
79
+ sem_seg_head=sem_seg_head,
80
+ criterion=criterion,
81
+ num_queries=num_queries,
82
+ panoptic_on=panoptic_on,
83
+ object_mask_threshold=object_mask_threshold,
84
+ overlap_threshold=overlap_threshold,
85
+ metadata=metadata,
86
+ size_divisibility=size_divisibility,
87
+ sem_seg_postprocess_before_inference=sem_seg_postprocess_before_inference,
88
+ pixel_mean=pixel_mean,
89
+ pixel_std=pixel_std,
90
+ )
91
+ self.clip_adapter: ClipAdapter = clip_adapter
92
+
93
+ self.clip_ensemble: bool = clip_ensemble
94
+ self.clip_ensemble_weight: float = clip_ensemble_weight
95
+
96
+ @classmethod
97
+ def from_config(cls, cfg):
98
+ init_kwargs = MaskFormer.from_config(cfg)
99
+ text_templates = build_text_prompt(cfg.MODEL.CLIP_ADAPTER)
100
+
101
+ clip_adapter = MaskFormerClipAdapter(
102
+ cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME,
103
+ text_templates,
104
+ mask_fill=cfg.MODEL.CLIP_ADAPTER.MASK_FILL,
105
+ mask_expand_ratio=cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO,
106
+ mask_thr=cfg.MODEL.CLIP_ADAPTER.MASK_THR,
107
+ mask_matting=cfg.MODEL.CLIP_ADAPTER.MASK_MATTING,
108
+ region_resized=cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED,
109
+ mask_prompt_depth=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH,
110
+ mask_prompt_fwd=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD,
111
+ )
112
+ init_kwargs["clip_adapter"] = clip_adapter
113
+ init_kwargs["clip_ensemble"] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE
114
+ init_kwargs[
115
+ "clip_ensemble_weight"
116
+ ] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT
117
+
118
+ return init_kwargs
119
+
120
+ def forward(self, batched_inputs):
121
+ """
122
+ Args:
123
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
124
+ Each item in the list contains the inputs for one image.
125
+ For now, each item in the list is a dict that contains:
126
+ * "image": Tensor, image in (C, H, W) format.
127
+ * "instances": per-region ground truth
128
+ * Other information that's included in the original dicts, such as:
129
+ "height", "width" (int): the output resolution of the model (may be different
130
+ from input resolution), used in inference.
131
+ Returns:
132
+ list[dict]:
133
+ each dict has the results for one image. The dict contains the following keys:
134
+
135
+ * "sem_seg":
136
+ A Tensor that represents the
137
+ per-pixel segmentation predicted by the head.
138
+ The prediction has shape KxHxW that represents the logits of
139
+ each class for each pixel.
140
+ * "panoptic_seg":
141
+ A tuple that represents the panoptic output
142
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
143
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
144
+ Each dict contains keys "id", "category_id", "isthing".
145
+ """
146
+ dataset_name = [x["meta"]["dataset_name"] for x in batched_inputs]
147
+ assert len(set(dataset_name)) == 1
148
+ dataset_name = dataset_name[0]
149
+
150
+ images = [x["image"].to(self.device) for x in batched_inputs]
151
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
152
+ images = ImageList.from_tensors(images, self.size_divisibility)
153
+
154
+ features = self.backbone(images.tensor)
155
+ outputs = self.sem_seg_head(features)
156
+ class_names = self.get_class_name_list(dataset_name)
157
+ text_features = self.clip_adapter.get_text_features(class_names)
158
+ outputs["pred_logits"] = self.clip_adapter.get_sim_logits(
159
+ text_features, self.clip_adapter.normalize_feature(outputs["pred_logits"])
160
+ )
161
+ if self.training:
162
+ if "aux_outputs" in outputs.keys():
163
+ for i in range(len(outputs["aux_outputs"])):
164
+ outputs["aux_outputs"][i][
165
+ "pred_logits"
166
+ ] = self.clip_adapter.get_sim_logits(
167
+ text_features,
168
+ self.clip_adapter.normalize_feature(
169
+ outputs["aux_outputs"][i]["pred_logits"]
170
+ ),
171
+ )
172
+ # mask classification target
173
+ if "instances" in batched_inputs[0]:
174
+ gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
175
+ targets = self.prepare_targets(gt_instances, images)
176
+ else:
177
+ targets = None
178
+
179
+ # bipartite matching-based loss
180
+ losses = self.criterion(outputs, targets)
181
+
182
+ for k in list(losses.keys()):
183
+ if k in self.criterion.weight_dict:
184
+ losses[k] *= self.criterion.weight_dict[k]
185
+ else:
186
+ # remove this loss if not specified in `weight_dict`
187
+ losses.pop(k)
188
+
189
+ return losses
190
+ else:
191
+ mask_cls_results = outputs["pred_logits"]
192
+ mask_pred_results = outputs["pred_masks"]
193
+ # upsample masks
194
+ mask_pred_results = F.interpolate(
195
+ mask_pred_results,
196
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
197
+ mode="bilinear",
198
+ align_corners=False,
199
+ )
200
+
201
+ processed_results = []
202
+ for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
203
+ mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
204
+ ):
205
+ height = image_size[0]
206
+ width = image_size[1]
207
+ mask_pred_result = sem_seg_postprocess(
208
+ mask_pred_result, image_size, height, width
209
+ )
210
+ image = input_per_image["image"].to(self.device)
211
+
212
+ r, regions = self.semantic_inference(
213
+ mask_cls_result, mask_pred_result, image, class_names
214
+ )
215
+
216
+ height = input_per_image.get("height", image_size[0])
217
+ width = input_per_image.get("width", image_size[1])
218
+ r = sem_seg_postprocess(r, image_size, height, width)
219
+ processed_results.append({"sem_seg": r})
220
+
221
+ # panoptic segmentation inference
222
+ if self.panoptic_on:
223
+ panoptic_r = self.panoptic_inference(
224
+ mask_cls_result, mask_pred_result
225
+ )
226
+ processed_results[-1]["panoptic_seg"] = panoptic_r
227
+
228
+ return processed_results
229
+
230
+
231
+ def semantic_inference(self, mask_cls, mask_pred, image, class_names):
232
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
233
+ mask_pred = mask_pred.sigmoid()
234
+
235
+ regions = None
236
+ if self.clip_ensemble:
237
+ clip_cls, regions, valid_flag = self.clip_adapter(
238
+ image, class_names, mask_pred, normalize=True
239
+ )
240
+ if clip_cls is None:
241
+ clip_cls = torch.empty(0, mask_cls.shape[-1] + 1, device=self.device)
242
+ # softmax before index or after?
243
+ clip_cls = F.softmax(clip_cls[:, :-1], dim=-1)
244
+ if self.clip_ensemble_weight > 0:
245
+ map_back_clip_cls = mask_cls.new_ones(mask_cls.shape)
246
+ map_back_clip_cls[valid_flag] = clip_cls
247
+ mask_cls = torch.pow(mask_cls, 1 - self.clip_ensemble_weight) * \
248
+ torch.pow(map_back_clip_cls, self.clip_ensemble_weight)
249
+
250
+
251
+ else:
252
+ # only clip model predictions are used
253
+ mask_cls = clip_cls
254
+ mask_pred = mask_pred[valid_flag]
255
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
256
+ return semseg, regions
257
+
258
+ def get_class_name_list(self, dataset_name):
259
+ class_names = [
260
+ c.strip() for c in MetadataCatalog.get(dataset_name).stuff_classes
261
+ ]
262
+ return class_names
263
+
264
+
265
+ @META_ARCH_REGISTRY.register()
266
+ class OVSegDEMO(MaskFormer):
267
+ """
268
+ Main class for zero-shot mask classification semantic segmentation architectures.
269
+ """
270
+
271
+ @configurable
272
+ def __init__(
273
+ self,
274
+ *,
275
+ backbone: Backbone,
276
+ sem_seg_head: nn.Module,
277
+ clip_adapter: nn.Module,
278
+ criterion: nn.Module,
279
+ num_queries: int,
280
+ panoptic_on: bool,
281
+ object_mask_threshold: float,
282
+ overlap_threshold: float,
283
+ metadata,
284
+ size_divisibility: int,
285
+ sem_seg_postprocess_before_inference: bool,
286
+ clip_ensemble: bool,
287
+ clip_ensemble_weight: float,
288
+ pixel_mean: Tuple[float],
289
+ pixel_std: Tuple[float],
290
+ ):
291
+ """
292
+ Args:
293
+ backbone: a backbone module, must follow detectron2's backbone interface
294
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
295
+ criterion: a module that defines the loss
296
+ clip_adapter: adapter for clip-based mask classification
297
+ num_queries: int, number of queries
298
+ panoptic_on: bool, whether to output panoptic segmentation prediction
299
+ object_mask_threshold: float, threshold to filter queries based on classification score
300
+ for panoptic segmentation inference
301
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
302
+ metadata: dataset metadata, used to get `thing` and `stuff` category names for panoptic
303
+ segmentation inference
304
+ size_divisibility: Some backbones require the input height and width to be divisible by a
305
+ specific integer. We can use this to override such a requirement.
306
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
307
+ to original input size before semantic segmentation inference or after.
308
+ For high-resolution datasets like Mapillary, resizing predictions before
309
+ inference will cause an OOM error.
310
+ pixel_mean, pixel_std: list or tuple with #channels elements, representing
311
+ the per-channel mean and std to be used to normalize the input image
312
+ """
313
+ super().__init__(
314
+ backbone=backbone,
315
+ sem_seg_head=sem_seg_head,
316
+ criterion=criterion,
317
+ num_queries=num_queries,
318
+ panoptic_on=panoptic_on,
319
+ object_mask_threshold=object_mask_threshold,
320
+ overlap_threshold=overlap_threshold,
321
+ metadata=metadata,
322
+ size_divisibility=size_divisibility,
323
+ sem_seg_postprocess_before_inference=sem_seg_postprocess_before_inference,
324
+ pixel_mean=pixel_mean,
325
+ pixel_std=pixel_std,
326
+ )
327
+ self.clip_adapter: ClipAdapter = clip_adapter
328
+
329
+ self.clip_ensemble: bool = clip_ensemble
330
+ self.clip_ensemble_weight: float = clip_ensemble_weight
331
+
332
+ @classmethod
333
+ def from_config(cls, cfg):
334
+ init_kwargs = MaskFormer.from_config(cfg)
335
+ text_templates = build_text_prompt(cfg.MODEL.CLIP_ADAPTER)
336
+
337
+ clip_adapter = MaskFormerClipAdapter(
338
+ cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME,
339
+ text_templates,
340
+ mask_fill=cfg.MODEL.CLIP_ADAPTER.MASK_FILL,
341
+ mask_expand_ratio=cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO,
342
+ mask_thr=cfg.MODEL.CLIP_ADAPTER.MASK_THR,
343
+ mask_matting=cfg.MODEL.CLIP_ADAPTER.MASK_MATTING,
344
+ region_resized=cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED,
345
+ mask_prompt_depth=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH,
346
+ mask_prompt_fwd=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD,
347
+ )
348
+ init_kwargs["clip_adapter"] = clip_adapter
349
+ init_kwargs["clip_ensemble"] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE
350
+ init_kwargs[
351
+ "clip_ensemble_weight"
352
+ ] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT
353
+
354
+ return init_kwargs
355
+
356
+ def forward(self, batched_inputs):
357
+ """
358
+ Args:
359
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
360
+ Each item in the list contains the inputs for one image.
361
+ For now, each item in the list is a dict that contains:
362
+ * "image": Tensor, image in (C, H, W) format.
363
+ * "instances": per-region ground truth
364
+ * Other information that's included in the original dicts, such as:
365
+ "height", "width" (int): the output resolution of the model (may be different
366
+ from input resolution), used in inference.
367
+ Returns:
368
+ list[dict]:
369
+ each dict has the results for one image. The dict contains the following keys:
370
+
371
+ * "sem_seg":
372
+ A Tensor that represents the
373
+ per-pixel segmentation predicted by the head.
374
+ The prediction has shape KxHxW that represents the logits of
375
+ each class for each pixel.
376
+ * "panoptic_seg":
377
+ A tuple that represents the panoptic output
378
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
379
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
380
+ Each dict contains keys "id", "category_id", "isthing".
381
+ """
382
+ images = [x["image"].to(self.device) for x in batched_inputs]
383
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
384
+ images = ImageList.from_tensors(images, self.size_divisibility)
385
+
386
+ features = self.backbone(images.tensor)
387
+ outputs = self.sem_seg_head(features)
388
+ class_names = batched_inputs[0]["class_names"]
389
+ if len(class_names) == 1:
390
+ # Because classification is performed in a 'contrastive' manner, append 'others' to represent other concepts
391
+ class_names.append('others')
392
+ text_features = self.clip_adapter.get_text_features(class_names)
393
+ outputs["pred_logits"] = self.clip_adapter.get_sim_logits(
394
+ text_features, self.clip_adapter.normalize_feature(outputs["pred_logits"])
395
+ )
396
+ mask_cls_results = outputs["pred_logits"]
397
+ mask_pred_results = outputs["pred_masks"]
398
+ # upsample masks
399
+ mask_pred_results = F.interpolate(
400
+ mask_pred_results,
401
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
402
+ mode="bilinear",
403
+ align_corners=False,
404
+ )
405
+
406
+ processed_results = []
407
+ for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
408
+ mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
409
+ ):
410
+ height = image_size[0]
411
+ width = image_size[1]
412
+ mask_pred_result = sem_seg_postprocess(
413
+ mask_pred_result, image_size, height, width
414
+ )
415
+ image = input_per_image["image"].to(self.device)
416
+
417
+ r, regions = self.demo_inference(mask_cls_result, mask_pred_result, image, class_names)
418
+
419
+ height = input_per_image.get("height", image_size[0])
420
+ width = input_per_image.get("width", image_size[1])
421
+ r = sem_seg_postprocess(r, image_size, height, width)
422
+ processed_results.append({"sem_seg": r})
423
+
424
+ return processed_results
425
+
426
+
427
+
428
+
429
+ def demo_inference(self, mask_cls, mask_pred, image, class_names):
430
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
431
+ mask_pred = mask_pred.sigmoid()
432
+
433
+ regions = None
434
+ if self.clip_ensemble:
435
+ clip_cls, regions, valid_flag = self.clip_adapter(
436
+ image, class_names, mask_pred, normalize=True
437
+ )
438
+ if clip_cls is None:
439
+ clip_cls = torch.empty(0, mask_cls.shape[-1] + 1, device=self.device)
440
+ # softmax before index or after?
441
+ clip_cls = F.softmax(clip_cls[:, :-1], dim=-1)
442
+ if self.clip_ensemble_weight > 0:
443
+ map_back_clip_cls = mask_cls.new_ones(mask_cls.shape)
444
+ map_back_clip_cls[valid_flag] = clip_cls
445
+ mask_cls = torch.pow(mask_cls, 1 - self.clip_ensemble_weight) * \
446
+ torch.pow(map_back_clip_cls, self.clip_ensemble_weight)
447
+
448
+ else:
449
+ # only clip model predictions are used
450
+ mask_cls = clip_cls
451
+ mask_pred = mask_pred[valid_flag]
452
+ bin_mask = mask_pred > self.clip_adapter.mask_thr
453
+ select_cls = torch.zeros(sum(valid_flag), mask_cls.shape[-1], device=self.device)
454
+ select_mask = torch.argmax(mask_cls, dim=0)
455
+ if len(class_names) == 2 and class_names[-1] == 'others':
456
+ select_mask = select_mask[:-1]
457
+ for idx in select_mask:
458
+ select_cls[idx] = mask_cls[idx]
459
+ semseg = torch.einsum("qc,qhw->chw", select_cls, bin_mask.float())
460
+ return semseg, regions
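Editor's aside: both semantic_inference and demo_inference fuse the MaskFormer class probabilities with CLIP's per-region scores geometrically, weighted by CLIP_ENSEMBLE_WEIGHT. A standalone sketch of that arithmetic with random inputs, ignoring the valid_flag bookkeeping (shapes and the weight 0.7 are assumptions, not taken from a config):

    import torch
    import torch.nn.functional as F

    Q, K, H, W = 100, 20, 64, 64                                    # queries, classes, mask resolution
    mask_cls = F.softmax(torch.randn(Q, K + 1), dim=-1)[..., :-1]   # drop the no-object column
    clip_cls = F.softmax(torch.randn(Q, K), dim=-1)                 # CLIP scores for the masked regions
    lam = 0.7                                                       # plays the role of clip_ensemble_weight

    fused = mask_cls.pow(1 - lam) * clip_cls.pow(lam)               # geometric ensemble of the two classifiers
    mask_pred = torch.rand(Q, H, W)                                 # sigmoid mask probabilities
    semseg = torch.einsum("qc,qhw->chw", fused, mask_pred)          # per-class, per-pixel scores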
open_vocab_seg/test_time_augmentation.py ADDED
@@ -0,0 +1,217 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import copy
5
+ from itertools import count
6
+ import math
7
+ import numpy as np
8
+ import torch
9
+ from fvcore.transforms import HFlipTransform
10
+ from torch import nn
11
+ from torch.nn.parallel import DistributedDataParallel
12
+
13
+ from detectron2.data.detection_utils import read_image
14
+ from detectron2.modeling import DatasetMapperTTA
15
+ from detectron2.modeling.postprocessing import sem_seg_postprocess
16
+ import logging
17
+ from detectron2.utils.logger import log_every_n, log_first_n
18
+
19
+ __all__ = [
20
+ "SemanticSegmentorWithTTA",
21
+ ]
22
+
23
+
24
+ class SemanticSegmentorWithTTA(nn.Module):
25
+ """
26
+ A SemanticSegmentor with test-time augmentation enabled.
27
+ Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.
28
+ """
29
+
30
+ def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
31
+ """
32
+ Args:
33
+ cfg (CfgNode):
34
+ model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
35
+ tta_mapper (callable): takes a dataset dict and returns a list of
36
+ augmented versions of the dataset dict. Defaults to
37
+ `DatasetMapperTTA(cfg)`.
38
+ batch_size (int): batch the augmented images into this batch size for inference.
39
+ """
40
+ super().__init__()
41
+ if isinstance(model, DistributedDataParallel):
42
+ model = model.module
43
+ self.cfg = cfg.clone()
44
+
45
+ self.model = model
46
+
47
+ if tta_mapper is None:
48
+ tta_mapper = DatasetMapperTTA(cfg)
49
+ self.tta_mapper = tta_mapper
50
+ self.batch_size = batch_size
51
+
52
+ def _inference_with_model(self, inputs):
53
+ if self.cfg.TEST.SLIDING_WINDOW:
54
+ log_first_n(logging.INFO, "Using sliding window to test")
55
+
56
+ outputs = []
57
+
58
+ for input in inputs:
59
+ image_size = input["image"].shape[1:] # h,w
60
+ if self.cfg.TEST.SLIDING_TILE_SIZE > 0:
61
+ tile_size = (
62
+ self.cfg.TEST.SLIDING_TILE_SIZE,
63
+ self.cfg.TEST.SLIDING_TILE_SIZE,
64
+ )
65
+ else:
66
+ selected_mapping = {256: 224, 512: 256, 768: 512, 896: 512}
67
+ tile_size = min(image_size)
68
+ tile_size = selected_mapping[tile_size]
69
+ tile_size = (tile_size, tile_size)
70
+ extra_info = {
71
+ k: v
72
+ for k, v in input.items()
73
+ if k not in ["image", "height", "width"]
74
+ }
75
+ log_every_n(
76
+ logging.INFO, "split {} to {}".format(image_size, tile_size)
77
+ )
78
+ overlap = self.cfg.TEST.SLIDING_OVERLAP
79
+ stride = math.ceil(tile_size[0] * (1 - overlap))
80
+ tile_rows = int(
81
+ math.ceil((image_size[0] - tile_size[0]) / stride) + 1
82
+ ) # strided convolution formula
83
+ tile_cols = int(math.ceil((image_size[1] - tile_size[1]) / stride) + 1)
84
+ full_probs = None
85
+ count_predictions = None
86
+ tile_counter = 0
87
+
88
+ for row in range(tile_rows):
89
+ for col in range(tile_cols):
90
+ x1 = int(col * stride)
91
+ y1 = int(row * stride)
92
+ x2 = min(x1 + tile_size[1], image_size[1])
93
+ y2 = min(y1 + tile_size[0], image_size[0])
94
+ x1 = max(
95
+ int(x2 - tile_size[1]), 0
96
+ ) # for portrait images the x1 underflows sometimes
97
+ y1 = max(
98
+ int(y2 - tile_size[0]), 0
99
+ ) # for very few rows y1 underflows
100
+
101
+ img = input["image"][:, y1:y2, x1:x2]
102
+ padded_img = nn.functional.pad(
103
+ img,
104
+ (
105
+ 0,
106
+ tile_size[1] - img.shape[-1],
107
+ 0,
108
+ tile_size[0] - img.shape[-2],
109
+ ),
110
+ )
111
+ tile_counter += 1
112
+ padded_input = {"image": padded_img}
113
+ padded_input.update(extra_info)
114
+ padded_prediction = self.model([padded_input])[0]["sem_seg"]
115
+ prediction = padded_prediction[
116
+ :, 0 : img.shape[1], 0 : img.shape[2]
117
+ ]
118
+ if full_probs is None:
119
+ full_probs = prediction.new_zeros(
120
+ prediction.shape[0], image_size[0], image_size[1]
121
+ )
122
+ if count_predictions is None:
123
+ count_predictions = prediction.new_zeros(
124
+ prediction.shape[0], image_size[0], image_size[1]
125
+ )
126
+ count_predictions[:, y1:y2, x1:x2] += 1
127
+ full_probs[
128
+ :, y1:y2, x1:x2
129
+ ] += prediction # accumulate the predictions also in the overlapping regions
130
+
131
+ full_probs /= count_predictions
132
+ full_probs = sem_seg_postprocess(
133
+ full_probs,
134
+ image_size,
135
+ input.get("height", image_size[0]),
136
+ input.get("width", image_size[1]),
137
+ )
138
+ outputs.append({"sem_seg": full_probs})
139
+
140
+ return outputs
141
+ else:
142
+ log_first_n(logging.INFO, "Using whole image to test")
143
+ return self.model(inputs)
144
+
145
+ def _batch_inference(self, batched_inputs):
146
+ """
147
+ Execute inference on a list of inputs,
148
+ using batch size = self.batch_size, instead of the length of the list.
149
+ Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward`
150
+ """
151
+ outputs = []
152
+ inputs = []
153
+ for idx, input in zip(count(), batched_inputs):
154
+ inputs.append(input)
155
+ if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1:
156
+ with torch.no_grad():
157
+ outputs.extend(self._inference_with_model(inputs))
158
+ inputs = []
159
+ return outputs
160
+
161
+ def __call__(self, batched_inputs):
162
+ """
163
+ Same input/output format as :meth:`SemanticSegmentor.forward`
164
+ """
165
+
166
+ def _maybe_read_image(dataset_dict):
167
+ ret = copy.copy(dataset_dict)
168
+ if "image" not in ret:
169
+ image = read_image(ret.pop("file_name"), self.model.input_format)
170
+ image = torch.from_numpy(
171
+ np.ascontiguousarray(image.transpose(2, 0, 1))
172
+ ) # CHW
173
+ ret["image"] = image
174
+ if "height" not in ret and "width" not in ret:
175
+ ret["height"] = image.shape[1]
176
+ ret["width"] = image.shape[2]
177
+ return ret
178
+
179
+ return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs]
180
+
181
+ def _inference_one_image(self, input):
182
+ """
183
+ Args:
184
+ input (dict): one dataset dict with "image" field being a CHW tensor
185
+ Returns:
186
+ dict: one output dict
187
+ """
188
+ augmented_inputs, tfms = self._get_augmented_inputs(input)
189
+ # 1: forward with all augmented images
190
+ outputs = self._batch_inference(augmented_inputs)
191
+ # Delete now useless variables to avoid being out of memory
192
+ del augmented_inputs
193
+ # 2: merge the results
194
+ # handle flip specially
195
+ # outputs = [output.detach() for output in outputs]
196
+ return self._merge_auged_output(outputs, tfms)
197
+
198
+ def _merge_auged_output(self, outputs, tfms):
199
+ new_outputs = []
200
+ for output, tfm in zip(outputs, tfms):
201
+ if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
202
+ new_outputs.append(output["sem_seg"].flip(dims=[2]))
203
+ else:
204
+ new_outputs.append(output["sem_seg"])
205
+ del outputs
206
+ # to avoid OOM with torch.stack
207
+ final_predictions = new_outputs[0]
208
+ for i in range(1, len(new_outputs)):
209
+ final_predictions += new_outputs[i]
210
+ final_predictions = final_predictions / len(new_outputs)
211
+ del new_outputs
212
+ return {"sem_seg": final_predictions}
213
+
214
+ def _get_augmented_inputs(self, input):
215
+ augmented_inputs = self.tta_mapper(input)
216
+ tfms = [x.pop("transforms") for x in augmented_inputs]
217
+ return augmented_inputs, tfms
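Editor's aside: the sliding-window branch of _inference_with_model derives its tile grid from TEST.SLIDING_OVERLAP using the strided-convolution formula. The grid arithmetic in isolation, as a sketch (square tiles assumed; the function and its arguments are illustrative, not part of the repo):

    import math

    def tile_boxes(image_hw, tile, overlap):
        """Return (x1, y1, x2, y2) crops that cover image_hw with the given overlap ratio."""
        H, W = image_hw
        stride = math.ceil(tile * (1 - overlap))
        rows = int(math.ceil((H - tile) / stride) + 1)   # strided-convolution formula
        cols = int(math.ceil((W - tile) / stride) + 1)
        boxes = []
        for r in range(rows):
            for c in range(cols):
                x1, y1 = c * stride, r * stride
                x2, y2 = min(x1 + tile, W), min(y1 + tile, H)
                x1, y1 = max(x2 - tile, 0), max(y2 - tile, 0)  # clamp so tiles stay inside the image
                boxes.append((x1, y1, x2, y2))
        return boxes

    print(len(tile_boxes((640, 480), tile=256, overlap=0.3)))  # 12 tiles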