Commit 7e8c559
Duplicate from facebook/ov-seg
Co-authored-by: Jeff Liang <JeffLiang@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.
- .gitattributes +36 -0
- .idea/vcs.xml +6 -0
- README.md +14 -0
- app.py +96 -0
- open_vocab_seg/.DS_Store +0 -0
- open_vocab_seg/__init__.py +9 -0
- open_vocab_seg/config.py +133 -0
- open_vocab_seg/data/.DS_Store +0 -0
- open_vocab_seg/data/__init__.py +9 -0
- open_vocab_seg/data/augmentations.py +202 -0
- open_vocab_seg/data/build.py +344 -0
- open_vocab_seg/data/dataset_mappers/__init__.py +4 -0
- open_vocab_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py +208 -0
- open_vocab_seg/data/datasets/__init__.py +5 -0
- open_vocab_seg/data/datasets/csv_data.py +459 -0
- open_vocab_seg/data/datasets/register_ade20k_full.py +995 -0
- open_vocab_seg/data/datasets/register_cc3m.py +457 -0
- open_vocab_seg/data/datasets/register_coco_stuff.py +250 -0
- open_vocab_seg/data/datasets/register_pascal_context.py +588 -0
- open_vocab_seg/data/datasets/register_voc_seg.py +62 -0
- open_vocab_seg/evaluation/__init__.py +4 -0
- open_vocab_seg/evaluation/generalized_sem_seg_evaluation.py +159 -0
- open_vocab_seg/mask_former_model.py +254 -0
- open_vocab_seg/modeling/.DS_Store +0 -0
- open_vocab_seg/modeling/__init__.py +8 -0
- open_vocab_seg/modeling/backbone/__init__.py +2 -0
- open_vocab_seg/modeling/backbone/clip_resnet.py +206 -0
- open_vocab_seg/modeling/backbone/swin.py +832 -0
- open_vocab_seg/modeling/clip_adapter/__init__.py +25 -0
- open_vocab_seg/modeling/clip_adapter/adapter.py +206 -0
- open_vocab_seg/modeling/clip_adapter/clip/__init__.py +1 -0
- open_vocab_seg/modeling/clip_adapter/clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- open_vocab_seg/modeling/clip_adapter/clip/clip.py +285 -0
- open_vocab_seg/modeling/clip_adapter/clip/model.py +613 -0
- open_vocab_seg/modeling/clip_adapter/clip/simple_tokenizer.py +150 -0
- open_vocab_seg/modeling/clip_adapter/text_template.py +156 -0
- open_vocab_seg/modeling/clip_adapter/utils.py +81 -0
- open_vocab_seg/modeling/criterion.py +229 -0
- open_vocab_seg/modeling/heads/__init__.py +2 -0
- open_vocab_seg/modeling/heads/mask_former_head.py +135 -0
- open_vocab_seg/modeling/heads/open_vocab_mask_former_head.py +145 -0
- open_vocab_seg/modeling/heads/pixel_decoder.py +308 -0
- open_vocab_seg/modeling/matcher.py +187 -0
- open_vocab_seg/modeling/transformer/__init__.py +2 -0
- open_vocab_seg/modeling/transformer/open_vocab_transformer_predictor.py +84 -0
- open_vocab_seg/modeling/transformer/position_encoding.py +58 -0
- open_vocab_seg/modeling/transformer/transformer.py +380 -0
- open_vocab_seg/modeling/transformer/transformer_predictor.py +179 -0
- open_vocab_seg/ovseg_model.py +460 -0
- open_vocab_seg/test_time_augmentation.py +217 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
README.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: Ov Seg
+emoji: 📊
+colorFrom: red
+colorTo: pink
+sdk: gradio
+sdk_version: 3.8.2
+app_file: app.py
+pinned: false
+license: cc-by-nc-4.0
+duplicated_from: facebook/ov-seg
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,96 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import multiprocessing as mp
+
+import numpy as np
+from PIL import Image
+
+
+try:
+    import detectron2
+except:
+    import os
+    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
+
+from detectron2.config import get_cfg
+
+from detectron2.projects.deeplab import add_deeplab_config
+from detectron2.data.detection_utils import read_image
+from open_vocab_seg import add_ovseg_config
+from open_vocab_seg.utils import VisualizationDemo, SAMVisualizationDemo
+
+import gradio as gr
+
+import gdown
+
+# ckpt_url = 'https://drive.google.com/uc?id=1cn-ohxgXDrDfkzC1QdO-fi8IjbjXmgKy'
+# output = './ovseg_swinbase_vitL14_ft_mpt.pth'
+# gdown.download(ckpt_url, output, quiet=False)
+
+def setup_cfg(config_file):
+    # load config from file and command-line arguments
+    cfg = get_cfg()
+    add_deeplab_config(cfg)
+    add_ovseg_config(cfg)
+    cfg.merge_from_file(config_file)
+    cfg.freeze()
+    return cfg
+
+
+def inference(class_names, proposal_gen, granularity, input_img):
+    mp.set_start_method("spawn", force=True)
+    config_file = './ovseg_swinB_vitL_demo.yaml'
+    cfg = setup_cfg(config_file)
+    if proposal_gen == 'MaskFormer':
+        demo = VisualizationDemo(cfg)
+    elif proposal_gen == 'Segment_Anything':
+        demo = SAMVisualizationDemo(cfg, granularity, './sam_vit_l_0b3195.pth', './ovseg_clip_l_9a1909.pth')
+    class_names = class_names.split(',')
+    img = read_image(input_img, format="BGR")
+    _, visualized_output = demo.run_on_image(img, class_names)
+
+    return Image.fromarray(np.uint8(visualized_output.get_image())).convert('RGB')
+
+
+examples = [['Saturn V, toys, desk, wall, sunflowers, white roses, chrysanthemums, carnations, green dianthus', 'Segment_Anything', 0.8, './resources/demo_samples/sample_01.jpeg'],
+            ['red bench, yellow bench, blue bench, brown bench, green bench, blue chair, yellow chair, green chair, brown chair, yellow square painting, barrel, buddha statue', 'Segment_Anything', 0.8, './resources/demo_samples/sample_04.png'],
+            ['pillow, pipe, sweater, shirt, jeans jacket, shoes, cabinet, handbag, photo frame', 'Segment_Anything', 0.8, './resources/demo_samples/sample_05.png'],
+            ['Saturn V, toys, blossom', 'MaskFormer', 1.0, './resources/demo_samples/sample_01.jpeg'],
+            ['Oculus, Ukulele', 'MaskFormer', 1.0, './resources/demo_samples/sample_03.jpeg'],
+            ['Golden gate, yacht', 'MaskFormer', 1.0, './resources/demo_samples/sample_02.jpeg'],]
+output_labels = ['segmentation map']
+
+title = 'OVSeg (+ Segment_Anything)'
+
+description = """
+[NEW!] We incorperate OVSeg CLIP w/ Segment_Anything, enabling SAM's text prompts.
+Gradio Demo for Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP. \n
+OVSeg could perform open vocabulary segmentation, you may input more classes (seperate by comma). You may click on of the examples or upload your own image. \n
+It might take some time to process. Cheers!
+<p>(Colab only supports MaskFormer proposal generator) Don't want to wait in queue? <a href="https://colab.research.google.com/drive/1O4Ain5uFZNcQYUmDTG92DpEGCatga8K5?usp=sharing"><img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://camo.githubusercontent.com/84f0493939e0c4de4e6dbe113251b4bfb5353e57134ffd9fcab6b8714514d4d1/68747470733a2f2f636f6c61622e72657365617263682e676f6f676c652e636f6d2f6173736574732f636f6c61622d62616467652e737667"></a></p>
+"""
+
+article = """
+<p style='text-align: center'>
+<a href='https://arxiv.org/abs/2210.04150' target='_blank'>
+Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP
+</a>
+|
+<a href='https://github.com/facebookresearch/ov-seg' target='_blank'>Github Repo</a></p>
+"""
+
+gr.Interface(
+    inference,
+    inputs=[
+        gr.Textbox(
+            lines=1, placeholder=None, default='', label='class names'),
+        gr.Radio(["Segment_Anything", "MaskFormer"], label="Proposal generator", default="Segment_Anything"),
+        gr.Slider(0, 1.0, 0.8, label="For Segment_Anything only, granularity of masks from 0 (most coarse) to 1 (most precise)"),
+        gr.Image(type='filepath'),
+    ],
+    outputs=gr.outputs.Image(label='segmentation map'),
+    title=title,
+    description=description,
+    article=article,
+    examples=examples).launch(enable_queue=True)
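For reference, the demo's entry point can also be driven outside Gradio. The sketch below assumes the Space's files (the ovseg_swinB_vitL_demo.yaml config and the checkpoints referenced above) are available locally; the class list and image path are illustrative, not part of the commit.

```python
# Minimal local driver for the inference() function defined in app.py above.
# Assumes the config and checkpoint files shipped with this Space are present.
from app import inference

seg_map = inference(
    class_names="Saturn V, toys, blossom",   # comma-separated open-vocabulary classes
    proposal_gen="MaskFormer",               # or "Segment_Anything" for the SAM branch
    granularity=1.0,                         # only used by the Segment_Anything branch
    input_img="./resources/demo_samples/sample_01.jpeg",
)
seg_map.save("segmentation_map.png")         # inference() returns a PIL.Image
```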
open_vocab_seg/.DS_Store
ADDED
Binary file (6.15 kB)
open_vocab_seg/__init__.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from . import data
+from . import modeling
+from .config import add_ovseg_config
+
+from .test_time_augmentation import SemanticSegmentorWithTTA
+from .ovseg_model import OVSeg, OVSegDEMO
open_vocab_seg/config.py
ADDED
@@ -0,0 +1,133 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from detectron2.config import CfgNode as CN
+
+
+def add_mask_former_default_config(cfg):
+    # data config
+    # select the dataset mapper
+    cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
+    # Color augmentation
+    cfg.INPUT.COLOR_AUG_SSD = False
+    # We retry random cropping until no single category in semantic segmentation GT occupies more
+    # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
+    cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
+    # Pad image and segmentation GT in dataset mapper.
+    cfg.INPUT.SIZE_DIVISIBILITY = -1
+
+    # solver config
+    # test batch size
+    cfg.SOLVER.TEST_IMS_PER_BATCH = 1
+    # weight decay on embedding
+    cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
+    # optimizer
+    cfg.SOLVER.OPTIMIZER = "ADAMW"
+    cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
+
+    # mask_former model config
+    cfg.MODEL.MASK_FORMER = CN()
+
+    # loss
+    cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
+    cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1
+    cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0
+    cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0
+
+    # transformer config
+    cfg.MODEL.MASK_FORMER.NHEADS = 8
+    cfg.MODEL.MASK_FORMER.DROPOUT = 0.1
+    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
+    cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0
+    cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6
+    cfg.MODEL.MASK_FORMER.PRE_NORM = False
+
+    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
+    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100
+
+    cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5"
+    cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False
+
+    # mask_former inference config
+    cfg.MODEL.MASK_FORMER.TEST = CN()
+    cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False
+    cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0
+    cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0
+    cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
+
+    # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
+    # you can use this config to override
+    cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
+
+    # pixel decoder config
+    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
+    # adding transformer in pixel decoder
+    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
+    # pixel decoder
+    cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
+
+    # swin transformer backbone
+    cfg.MODEL.SWIN = CN()
+    cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
+    cfg.MODEL.SWIN.PATCH_SIZE = 4
+    cfg.MODEL.SWIN.EMBED_DIM = 96
+    cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+    cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+    cfg.MODEL.SWIN.WINDOW_SIZE = 7
+    cfg.MODEL.SWIN.MLP_RATIO = 4.0
+    cfg.MODEL.SWIN.QKV_BIAS = True
+    cfg.MODEL.SWIN.QK_SCALE = None
+    cfg.MODEL.SWIN.NORM_INDICES = None
+    cfg.MODEL.SWIN.PROJECTION = False
+    cfg.MODEL.SWIN.PROJECT_DIM = 256
+    cfg.MODEL.SWIN.DROP_RATE = 0.0
+    cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
+    cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
+    cfg.MODEL.SWIN.APE = False
+    cfg.MODEL.SWIN.PATCH_NORM = True
+    cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
+
+
+def add_our_config(cfg):
+    cfg.TEST.SLIDING_WINDOW = False
+    cfg.TEST.SLIDING_TILE_SIZE = 224
+    cfg.TEST.SLIDING_OVERLAP = 2 / 3.0
+    # whether to use dense crf
+    cfg.TEST.DENSE_CRF = False
+    cfg.DATASETS.SAMPLE_PER_CLASS = -1
+    cfg.DATASETS.SAMPLE_SEED = 0
+    # embedding head
+    cfg.MODEL.SEM_SEG_HEAD.EMBEDDING_DIM = 512
+    cfg.MODEL.SEM_SEG_HEAD.EMBED_HIDDEN_DIM = 1024
+    cfg.MODEL.SEM_SEG_HEAD.EMBED_LAYERS = 2
+    # clip_adapter
+    cfg.MODEL.CLIP_ADAPTER = CN()
+    cfg.MODEL.CLIP_ADAPTER.TEXT_TEMPLATES = "vild"
+    # for predefined
+    cfg.MODEL.CLIP_ADAPTER.PREDEFINED_PROMPT_TEMPLATES = ["a photo of a {}."]
+    # for learnable prompt
+    cfg.MODEL.CLIP_ADAPTER.PROMPT_CHECKPOINT = ""
+    cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME = "ViT-B/16"
+    cfg.MODEL.CLIP_ADAPTER.MASK_FILL = "mean"
+    cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO = 1.0
+    cfg.MODEL.CLIP_ADAPTER.MASK_THR = 0.4
+    cfg.MODEL.CLIP_ADAPTER.MASK_MATTING = False
+    cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED = True
+    cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE = True
+    cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT = 0.7
+    # for mask prompt
+    cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH = 3
+    cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD = False
+
+    # wandb
+    cfg.WANDB = CN()
+    cfg.WANDB.PROJECT = "open_vocab_seg"
+    cfg.WANDB.NAME = None
+
+
+def add_ovseg_config(cfg):
+    """
+    Add config for open_vocab_seg.
+    """
+    add_mask_former_default_config(cfg)
+    add_our_config(cfg)
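The two helpers above are only reachable through add_ovseg_config, which app.py calls before merging the YAML. A minimal sketch of registering, overriding, and reading the added keys (the YAML path mirrors the demo config; the override value is illustrative):

```python
# Sketch: registering and reading the OVSeg config keys defined above.
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from open_vocab_seg import add_ovseg_config

cfg = get_cfg()
add_deeplab_config(cfg)
add_ovseg_config(cfg)                                   # adds MODEL.CLIP_ADAPTER, MODEL.SWIN, TEST.SLIDING_*, WANDB, ...
cfg.merge_from_file("./ovseg_swinB_vitL_demo.yaml")     # demo config shipped with this Space
cfg.merge_from_list(["MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT", "0.8"])  # CLI-style override
cfg.freeze()
print(cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME)           # "ViT-B/16" unless the YAML overrides it
```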
open_vocab_seg/data/.DS_Store
ADDED
Binary file (6.15 kB)
open_vocab_seg/data/__init__.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from .dataset_mappers import *
+from . import datasets
+from .build import (
+    build_detection_train_loader,
+    build_detection_test_loader,
+)
open_vocab_seg/data/augmentations.py
ADDED
@@ -0,0 +1,202 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import math
+import numbers
+import numpy as np
+from detectron2.data.transforms.augmentation import Augmentation
+from detectron2.data.transforms.transform import (
+    CropTransform,
+    ResizeTransform,
+    TransformList,
+)
+from PIL import Image
+from fvcore.transforms.transform import PadTransform
+
+
+def mask2box(mask: np.ndarray):
+    # use naive way
+    row = np.nonzero(mask.sum(axis=0))[0]
+    if len(row) == 0:
+        return None
+    x1 = row.min()
+    x2 = row.max()
+    col = np.nonzero(mask.sum(axis=1))[0]
+    y1 = col.min()
+    y2 = col.max()
+    return x1, y1, x2 + 1 - x1, y2 + 1 - y1
+
+
+def expand_box(x, y, w, h, expand_ratio=1.0, max_h=None, max_w=None):
+    cx = x + 0.5 * w
+    cy = y + 0.5 * h
+    w = w * expand_ratio
+    h = h * expand_ratio
+    box = [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h]
+    if max_h is not None:
+        box[1] = max(0, box[1])
+        box[3] = min(max_h - 1, box[3])
+    if max_w is not None:
+        box[0] = max(0, box[0])
+        box[2] = min(max_w - 1, box[2])
+    box[2] = box[2] - box[0]
+    box[3] = box[3] - box[1]
+
+    return [int(b) for b in box]
+
+
+class CropImageWithMask(Augmentation):
+    def __init__(self, expand_ratio=1.0, mode="choice"):
+        if isinstance(expand_ratio, numbers.Number):
+            expand_ratio = (expand_ratio, expand_ratio)
+        self.mode = mode
+        self.expand_ratio = expand_ratio
+        if self.mode == "range":
+            assert len(expand_ratio) == 2 and expand_ratio[0] < expand_ratio[1]
+
+    def get_transform(self, image, sem_seg, category_id):
+        input_size = image.shape[:2]
+        bin_mask = sem_seg == category_id
+        x, y, w, h = mask2box(bin_mask)
+        if self.mode == "choice":
+            expand_ratio = np.random.choice(self.expand_ratio)
+        else:
+            expand_ratio = np.random.uniform(self.expand_ratio[0], self.expand_ratio[1])
+        x, y, w, h = expand_box(x, y, w, h, expand_ratio, *input_size)
+        w = max(w, 1)
+        h = max(h, 1)
+        return CropTransform(x, y, w, h, input_size[1], input_size[0])
+
+
+class CropImageWithBox(Augmentation):
+    def __init__(self, expand_ratio=1.0, mode="choice"):
+        if isinstance(expand_ratio, numbers.Number):
+            expand_ratio = (expand_ratio, expand_ratio)
+        self.mode = mode
+        self.expand_ratio = expand_ratio
+        if self.mode == "range":
+            assert len(expand_ratio) == 2 and expand_ratio[0] < expand_ratio[1]
+
+    def get_transform(self, image, boxes):
+        input_size = image.shape[:2]
+        x, y, x2, y2 = boxes[0]
+        w = x2 - x + 1
+        h = y2 - y + 1
+        if self.mode == "choice":
+            expand_ratio = np.random.choice(self.expand_ratio)
+        else:
+            expand_ratio = np.random.uniform(self.expand_ratio[0], self.expand_ratio[1])
+        x, y, w, h = expand_box(x, y, w, h, expand_ratio, *input_size)
+        w = max(w, 1)
+        h = max(h, 1)
+        return CropTransform(x, y, w, h, input_size[1], input_size[0])
+
+
+class RandomResizedCrop(Augmentation):
+    def __init__(
+        self,
+        size,
+        scale=(0.08, 1.0),
+        ratio=(3.0 / 4.0, 4.0 / 3.0),
+        interpolation=Image.BILINEAR,
+    ):
+        if isinstance(size, int):
+            size = (size, size)
+        else:
+            assert isinstance(size, (tuple, list)) and len(size) == 2
+
+        self.size = size
+
+        self.scale = scale
+        self.ratio = ratio
+        self.interpolation = interpolation
+
+    def get_transform(self, image):
+        height, width = image.shape[:2]
+        area = height * width
+
+        log_ratio = np.log(np.array(self.ratio))
+        is_success = False
+        for _ in range(10):
+            target_area = area * np.random.uniform(self.scale[0], self.scale[1])
+            aspect_ratio = np.exp(np.random.uniform(log_ratio[0], log_ratio[1]))
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if 0 < w <= width and 0 < h <= height:
+                i = np.random.randint(0, width - w + 1)
+                j = np.random.randint(0, height - h + 1)
+
+                is_success = True
+                break
+
+        if not is_success:
+            # Fallback to central crop
+            in_ratio = float(width) / float(height)
+            if in_ratio < min(self.ratio):
+                w = width
+                h = int(round(w / min(self.ratio)))
+            elif in_ratio > max(self.ratio):
+                h = height
+                w = int(round(h * max(self.ratio)))
+            else:  # whole image
+                w = width
+                h = height
+            i = (width - w) // 2
+            j = (height - h) // 2
+        return TransformList(
+            [
+                CropTransform(i, j, w, h, width, height),
+                ResizeTransform(
+                    h, w, self.size[1], self.size[0], interp=self.interpolation
+                ),
+            ]
+        )
+
+
+class CenterCrop(Augmentation):
+    def __init__(self, size, seg_ignore_label):
+        if isinstance(size, numbers.Number):
+            size = (int(size), int(size))
+        elif isinstance(size, (tuple, list)) and len(size) == 1:
+            size = (size[0], size[0])
+        self.size = size
+        self.seg_ignore_label = seg_ignore_label
+
+    def get_transform(self, image):
+
+        image_height, image_width = image.shape[:2]
+        crop_height, crop_width = self.size
+
+        transforms = []
+        if crop_width > image_width or crop_height > image_height:
+            padding_ltrb = [
+                (crop_width - image_width) // 2 if crop_width > image_width else 0,
+                (crop_height - image_height) // 2 if crop_height > image_height else 0,
+                (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+                (crop_height - image_height + 1) // 2
+                if crop_height > image_height
+                else 0,
+            ]
+            transforms.append(
+                PadTransform(
+                    *padding_ltrb,
+                    orig_w=image_width,
+                    orig_h=image_height,
+                    seg_pad_value=self.seg_ignore_label
+                )
+            )
+            image_width, image_height = (
+                image_width + padding_ltrb[0] + padding_ltrb[2],
+                image_height + padding_ltrb[1] + padding_ltrb[3],
+            )
+
+        crop_top = int(round((image_height - crop_height) / 2.0))
+        crop_left = int(round((image_width - crop_width) / 2.0))
+        transforms.append(
+            CropTransform(
+                crop_left, crop_top, crop_width, crop_height, image_width, image_height
+            )
+        )
+        return TransformList(transforms)
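A short sketch of how the helpers above behave on a toy mask (the values are illustrative and only exercise mask2box, expand_box, and CenterCrop):

```python
# Sketch: exercising mask2box / expand_box / CenterCrop on synthetic data.
import numpy as np
from open_vocab_seg.data.augmentations import mask2box, expand_box, CenterCrop

sem_seg = np.zeros((240, 320), dtype=np.uint8)
sem_seg[60:120, 80:200] = 7                    # pretend class 7 occupies a 60x120 rectangle
x, y, w, h = mask2box(sem_seg == 7)            # -> (80, 60, 120, 60), XYWH
x, y, w, h = expand_box(x, y, w, h, expand_ratio=1.5, max_h=240, max_w=320)

center_crop = CenterCrop(224, seg_ignore_label=255)   # pads with the ignore label when the image is smaller
tfm = center_crop.get_transform(np.zeros((240, 320, 3), dtype=np.uint8))
patch = tfm.apply_image(np.zeros((240, 320, 3), dtype=np.uint8))
print(patch.shape)                             # (224, 224, 3)
```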
open_vocab_seg/data/build.py
ADDED
@@ -0,0 +1,344 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import itertools
+import logging
+import numpy as np
+from collections import Counter
+import torch.utils.data
+from tabulate import tabulate
+from termcolor import colored
+
+from detectron2.utils.logger import _log_api_usage, log_first_n
+from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
+import torch.utils.data
+from detectron2.config import configurable
+from detectron2.data.build import (
+    build_batch_data_loader,
+    trivial_batch_collator,
+    load_proposals_into_dataset,
+    filter_images_with_only_crowd_annotations,
+    filter_images_with_few_keypoints,
+    print_instances_class_histogram,
+)
+
+from detectron2.data.common import DatasetFromList, MapDataset
+from detectron2.data.dataset_mapper import DatasetMapper
+from detectron2.data.detection_utils import check_metadata_consistency
+from detectron2.data.samplers import (
+    InferenceSampler,
+    RandomSubsetTrainingSampler,
+    RepeatFactorTrainingSampler,
+    TrainingSampler,
+)
+
+"""
+This file contains the default logic to build a dataloader for training or testing.
+"""
+
+__all__ = [
+    "build_detection_train_loader",
+    "build_detection_test_loader",
+]
+
+
+def print_classification_instances_class_histogram(dataset_dicts, class_names):
+    """
+    Args:
+        dataset_dicts (list[dict]): list of dataset dicts.
+        class_names (list[str]): list of class names (zero-indexed).
+    """
+    num_classes = len(class_names)
+    hist_bins = np.arange(num_classes + 1)
+    histogram = np.zeros((num_classes,), dtype=np.int)
+    for entry in dataset_dicts:
+        classes = np.asarray([entry["category_id"]], dtype=np.int)
+        if len(classes):
+            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
+            assert (
+                classes.max() < num_classes
+            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
+        histogram += np.histogram(classes, bins=hist_bins)[0]
+
+    N_COLS = min(6, len(class_names) * 2)
+
+    def short_name(x):
+        # make long class names shorter. useful for lvis
+        if len(x) > 13:
+            return x[:11] + ".."
+        return x
+
+    data = list(
+        itertools.chain(
+            *[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)]
+        )
+    )
+    total_num_instances = sum(data[1::2])
+    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
+    if num_classes > 1:
+        data.extend(["total", total_num_instances])
+    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
+    table = tabulate(
+        data,
+        headers=["category", "#instances"] * (N_COLS // 2),
+        tablefmt="pipe",
+        numalign="left",
+        stralign="center",
+    )
+    log_first_n(
+        logging.INFO,
+        "Distribution of instances among all {} categories:\n".format(num_classes)
+        + colored(table, "cyan"),
+        key="message",
+    )
+
+
+def wrap_metas(dataset_dict, **kwargs):
+    def _assign_attr(data_dict: dict, **kwargs):
+        assert not any(
+            [key in data_dict for key in kwargs]
+        ), "Assigned attributes should not exist in the original sample."
+        data_dict.update(kwargs)
+        return data_dict
+
+    return [_assign_attr(sample, meta=kwargs) for sample in dataset_dict]
+
+
+def get_detection_dataset_dicts(
+    names, filter_empty=True, min_keypoints=0, proposal_files=None
+):
+    """
+    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
+
+    Args:
+        names (str or list[str]): a dataset name or a list of dataset names
+        filter_empty (bool): whether to filter out images without instance annotations
+        min_keypoints (int): filter out images with fewer keypoints than
+            `min_keypoints`. Set to 0 to do nothing.
+        proposal_files (list[str]): if given, a list of object proposal files
+            that match each dataset in `names`.
+
+    Returns:
+        list[dict]: a list of dicts following the standard dataset dict format.
+    """
+    if isinstance(names, str):
+        names = [names]
+    assert len(names), names
+    dataset_dicts = [
+        wrap_metas(DatasetCatalog.get(dataset_name), dataset_name=dataset_name)
+        for dataset_name in names
+    ]
+    for dataset_name, dicts in zip(names, dataset_dicts):
+        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
+
+    if proposal_files is not None:
+        assert len(names) == len(proposal_files)
+        # load precomputed proposals from proposal files
+        dataset_dicts = [
+            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
+            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
+        ]
+
+    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
+
+    has_instances = "annotations" in dataset_dicts[0]
+    if filter_empty and has_instances:
+        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
+    if min_keypoints > 0 and has_instances:
+        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
+
+    if has_instances:
+        try:
+            class_names = MetadataCatalog.get(names[0]).thing_classes
+            check_metadata_consistency("thing_classes", names)
+            print_instances_class_histogram(dataset_dicts, class_names)
+        except AttributeError:  # class names are not available for this dataset
+            pass
+
+    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
+    return dataset_dicts
+
+
+def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
+    if dataset is None:
+        dataset = get_detection_dataset_dicts(
+            cfg.DATASETS.TRAIN,
+            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
+            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
+            if cfg.MODEL.KEYPOINT_ON
+            else 0,
+            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
+            if cfg.MODEL.LOAD_PROPOSALS
+            else None,
+        )
+        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])
+
+    if mapper is None:
+        mapper = DatasetMapper(cfg, True)
+
+    if sampler is None:
+        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
+        logger = logging.getLogger(__name__)
+        logger.info("Using training sampler {}".format(sampler_name))
+        if sampler_name == "TrainingSampler":
+            sampler = TrainingSampler(len(dataset))
+        elif sampler_name == "RepeatFactorTrainingSampler":
+            repeat_factors = (
+                RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
+                    dataset, cfg.DATALOADER.REPEAT_THRESHOLD
+                )
+            )
+            sampler = RepeatFactorTrainingSampler(repeat_factors)
+        elif sampler_name == "RandomSubsetTrainingSampler":
+            sampler = RandomSubsetTrainingSampler(
+                len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
+            )
+        else:
+            raise ValueError("Unknown training sampler: {}".format(sampler_name))
+
+    return {
+        "dataset": dataset,
+        "sampler": sampler,
+        "mapper": mapper,
+        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
+        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
+        "num_workers": cfg.DATALOADER.NUM_WORKERS,
+    }
+
+
+# TODO can allow dataset as an iterable or IterableDataset to make this function more general
+@configurable(from_config=_train_loader_from_config)
+def build_detection_train_loader(
+    dataset,
+    *,
+    mapper,
+    sampler=None,
+    total_batch_size,
+    aspect_ratio_grouping=True,
+    num_workers=0,
+):
+    """
+    Build a dataloader for object detection with some default features.
+    This interface is experimental.
+
+    Args:
+        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
+            or a map-style pytorch dataset. They can be obtained by using
+            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+        mapper (callable): a callable which takes a sample (dict) from dataset and
+            returns the format to be consumed by the model.
+            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
+        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
+            indices to be applied on ``dataset``. Default to :class:`TrainingSampler`,
+            which coordinates an infinite random shuffle sequence across all workers.
+        total_batch_size (int): total batch size across all workers. Batching
+            simply puts data into a list.
+        aspect_ratio_grouping (bool): whether to group images with similar
+            aspect ratio for efficiency. When enabled, it requires each
+            element in dataset be a dict with keys "width" and "height".
+        num_workers (int): number of parallel data loading workers
+
+    Returns:
+        torch.utils.data.DataLoader:
+            a dataloader. Each output from it is a ``list[mapped_element]`` of length
+            ``total_batch_size / num_workers``, where ``mapped_element`` is produced
+            by the ``mapper``.
+    """
+    if isinstance(dataset, list):
+        dataset = DatasetFromList(dataset, copy=False)
+    if mapper is not None:
+        dataset = MapDataset(dataset, mapper)
+    if sampler is None:
+        sampler = TrainingSampler(len(dataset))
+    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
+    return build_batch_data_loader(
+        dataset,
+        sampler,
+        total_batch_size,
+        aspect_ratio_grouping=aspect_ratio_grouping,
+        num_workers=num_workers,
+    )
+
+
+def _test_loader_from_config(cfg, dataset_name, mapper=None):
+    """
+    Uses the given `dataset_name` argument (instead of the names in cfg), because the
+    standard practice is to evaluate each test set individually (not combining them).
+    """
+    if isinstance(dataset_name, str):
+        dataset_name = [dataset_name]
+
+    dataset = get_detection_dataset_dicts(
+        dataset_name,
+        filter_empty=False,
+        proposal_files=[
+            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)]
+            for x in dataset_name
+        ]
+        if cfg.MODEL.LOAD_PROPOSALS
+        else None,
+    )
+    if mapper is None:
+        mapper = DatasetMapper(cfg, False)
+    return {
+        "dataset": dataset,
+        "mapper": mapper,
+        "num_workers": 0,
+        "samples_per_gpu": cfg.SOLVER.TEST_IMS_PER_BATCH,
+    }
+
+
+@configurable(from_config=_test_loader_from_config)
+def build_detection_test_loader(
+    dataset, *, mapper, sampler=None, num_workers=0, samples_per_gpu=1
+):
+    """
+    Similar to `build_detection_train_loader`, but uses a batch size of 1,
+    and :class:`InferenceSampler`. This sampler coordinates all workers to
+    produce the exact set of all samples.
+    This interface is experimental.
+
+    Args:
+        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
+            or a map-style pytorch dataset. They can be obtained by using
+            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+        mapper (callable): a callable which takes a sample (dict) from dataset
+            and returns the format to be consumed by the model.
+            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
+        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
+            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
+            which splits the dataset across all workers.
+        num_workers (int): number of parallel data loading workers
+
+    Returns:
+        DataLoader: a torch DataLoader, that loads the given detection
+        dataset, with test-time transformation and batching.
+
+    Examples:
+    ::
+        data_loader = build_detection_test_loader(
+            DatasetRegistry.get("my_test"),
+            mapper=DatasetMapper(...))
+
+        # or, instantiate with a CfgNode:
+        data_loader = build_detection_test_loader(cfg, "my_test")
+    """
+    if isinstance(dataset, list):
+        dataset = DatasetFromList(dataset, copy=False)
+    if mapper is not None:
+        dataset = MapDataset(dataset, mapper)
+    if sampler is None:
+        sampler = InferenceSampler(len(dataset))
+    # Always use 1 image per worker during inference since this is the
+    # standard when reporting inference time in papers.
+    batch_sampler = torch.utils.data.sampler.BatchSampler(
+        sampler, samples_per_gpu, drop_last=False
+    )
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=num_workers,
+        batch_sampler=batch_sampler,
+        collate_fn=trivial_batch_collator,
+    )
+    return data_loader
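A minimal sketch of the test-loader path above, under the assumption that a semantic-segmentation dataset has already been registered and a cfg has been built via add_ovseg_config (the dataset name below is a placeholder):

```python
# Sketch: building a test loader with the wrappers defined above.
from open_vocab_seg.data import build_detection_test_loader
from open_vocab_seg.data.dataset_mappers import MaskFormerSemanticDatasetMapper

mapper = MaskFormerSemanticDatasetMapper(cfg, is_train=False)   # cfg built via add_ovseg_config
data_loader = build_detection_test_loader(cfg, "my_sem_seg_val", mapper=mapper)  # dataset name is a placeholder

for batch in data_loader:
    # each batch is a list of mapped dicts, cfg.SOLVER.TEST_IMS_PER_BATCH items long
    break
```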
open_vocab_seg/data/dataset_mappers/__init__.py
ADDED
@@ -0,0 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper
open_vocab_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py
ADDED
@@ -0,0 +1,208 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import copy
+import logging
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.data import MetadataCatalog
+from detectron2.data import detection_utils as utils
+from detectron2.data import transforms as T
+from detectron2.projects.point_rend import ColorAugSSDTransform
+from detectron2.structures import BitMasks, Instances
+
+__all__ = ["MaskFormerSemanticDatasetMapper"]
+
+
+class MaskFormerSemanticDatasetMapper:
+    """
+    A callable which takes a dataset dict in Detectron2 Dataset format,
+    and map it into a format used by MaskFormer for semantic segmentation.
+
+    The callable currently does the following:
+
+    1. Read the image from "file_name"
+    2. Applies geometric transforms to the image and annotation
+    3. Find and applies suitable cropping to the image and annotation
+    4. Prepare image and annotation to Tensors
+    """
+
+    @configurable
+    def __init__(
+        self,
+        is_train=True,
+        *,
+        augmentations,
+        image_format,
+        ignore_label,
+        size_divisibility,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            is_train: for training or inference
+            augmentations: a list of augmentations or deterministic transforms to apply
+            image_format: an image format supported by :func:`detection_utils.read_image`.
+            ignore_label: the label that is ignored to evaluation
+            size_divisibility: pad image size to be divisible by this value
+        """
+        self.is_train = is_train
+        self.tfm_gens = augmentations
+        self.img_format = image_format
+        self.ignore_label = ignore_label
+        self.size_divisibility = size_divisibility
+
+        logger = logging.getLogger(__name__)
+        mode = "training" if is_train else "inference"
+        logger.info(
+            f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}"
+        )
+
+    @classmethod
+    def from_config(cls, cfg, is_train=True):
+        # Build augmentation
+        if is_train:
+            augs = [
+                T.ResizeShortestEdge(
+                    cfg.INPUT.MIN_SIZE_TRAIN,
+                    cfg.INPUT.MAX_SIZE_TRAIN,
+                    cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
+                )
+            ]
+            if cfg.INPUT.CROP.ENABLED:
+                augs.append(
+                    T.RandomCrop_CategoryAreaConstraint(
+                        cfg.INPUT.CROP.TYPE,
+                        cfg.INPUT.CROP.SIZE,
+                        cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
+                        cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+                    )
+                )
+            if cfg.INPUT.COLOR_AUG_SSD:
+                augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
+            augs.append(T.RandomFlip())
+
+            # Assume always applies to the training set.
+            dataset_names = cfg.DATASETS.TRAIN
+        else:
+            min_size = cfg.INPUT.MIN_SIZE_TEST
+            max_size = cfg.INPUT.MAX_SIZE_TEST
+            sample_style = "choice"
+            augs = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
+            dataset_names = cfg.DATASETS.TEST
+        meta = MetadataCatalog.get(dataset_names[0])
+        ignore_label = meta.ignore_label
+
+        ret = {
+            "is_train": is_train,
+            "augmentations": augs,
+            "image_format": cfg.INPUT.FORMAT,
+            "ignore_label": ignore_label,
+            "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY if is_train else -1,
+        }
+        return ret
+
+    def __call__(self, dataset_dict):
+        """
+        Args:
+            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+        Returns:
+            dict: a format that builtin models in detectron2 accept
+        """
+        # assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!"
+
+        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
+        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
+        utils.check_image_size(dataset_dict, image)
+
+        if "sem_seg_file_name" in dataset_dict:
+            # PyTorch transformation not implemented for uint16, so converting it to double first
+            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype(
+                "double"
+            )
+        else:
+            sem_seg_gt = None
+
+        if sem_seg_gt is None:
+            raise ValueError(
+                "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
+                    dataset_dict["file_name"]
+                )
+            )
+
+        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
+        aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
+        image = aug_input.image
+        sem_seg_gt = aug_input.sem_seg
+
+        # Pad image and segmentation label here!
+        image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+        if sem_seg_gt is not None:
+            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
+
+        if self.size_divisibility > 0:
+            image_size = (image.shape[-2], image.shape[-1])
+            padding_size = [
+                0,
+                self.size_divisibility - image_size[1],
+                0,
+                self.size_divisibility - image_size[0],
+            ]
+            image = F.pad(image, padding_size, value=128).contiguous()
+            if sem_seg_gt is not None:
+                sem_seg_gt = F.pad(
+                    sem_seg_gt, padding_size, value=self.ignore_label
+                ).contiguous()
+
+        image_shape = (image.shape[-2], image.shape[-1])  # h, w
+
+        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
+        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
+        # Therefore it's important to use torch.Tensor.
+        dataset_dict["image"] = image
+
+        if sem_seg_gt is not None:
+            dataset_dict["sem_seg"] = sem_seg_gt.long()
+
+        if "annotations" in dataset_dict:
+            raise ValueError(
+                "Semantic segmentation dataset should not have 'annotations'."
+            )
+
+        # Prepare per-category binary masks
+        if sem_seg_gt is not None:
+            sem_seg_gt = sem_seg_gt.numpy()
+            instances = Instances(image_shape)
+            classes = np.unique(sem_seg_gt)
+            # remove ignored region
+            classes = classes[classes != self.ignore_label]
+            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
+
+            masks = []
+            for class_id in classes:
+                masks.append(sem_seg_gt == class_id)
+
+            if len(masks) == 0:
+                # Some image does not have annotation (all ignored)
+                instances.gt_masks = torch.zeros(
+                    (0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1])
+                )
+            else:
+                masks = BitMasks(
+                    torch.stack(
+                        [
+                            torch.from_numpy(np.ascontiguousarray(x.copy()))
+                            for x in masks
+                        ]
+                    )
+                )
+                instances.gt_masks = masks.tensor
+
+            dataset_dict["instances"] = instances
+
+        return dataset_dict
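To make the mapper's input/output contract concrete, a sketch with a hand-built dataset dict (the file paths are placeholders, and "height"/"width" must match the actual image on disk):

```python
# Sketch: running MaskFormerSemanticDatasetMapper on a single dataset dict.
mapper = MaskFormerSemanticDatasetMapper(
    is_train=True,
    augmentations=[],          # no resize/crop/flip for this illustration
    image_format="BGR",
    ignore_label=255,
    size_divisibility=-1,
)
sample = {
    "file_name": "datasets/example/img_0001.jpg",          # placeholder path
    "sem_seg_file_name": "datasets/example/gt_0001.png",   # placeholder path
    "height": 480,
    "width": 640,
}
out = mapper(sample)
# out["image"]: CHW tensor, out["sem_seg"]: HxW long tensor,
# out["instances"].gt_masks: one binary mask per category present in the GT
```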
open_vocab_seg/data/datasets/__init__.py
ADDED
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from . import register_coco_stuff, register_voc_seg
+from . import register_cc3m
+from . import register_ade20k_full
+from . import register_pascal_context
open_vocab_seg/data/datasets/csv_data.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
import ast
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import sys
|
9 |
+
import time
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from multiprocessing import Value
|
12 |
+
|
13 |
+
import braceexpand
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
import torch
|
17 |
+
import torchvision.datasets as datasets
|
18 |
+
import webdataset as wds
|
19 |
+
from PIL import Image
|
20 |
+
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
|
21 |
+
from torch.utils.data.distributed import DistributedSampler
|
22 |
+
from webdataset.filters import _shuffle
|
23 |
+
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
|
24 |
+
|
25 |
+
try:
|
26 |
+
import horovod.torch as hvd
|
27 |
+
except ImportError:
|
28 |
+
hvd = None
|
29 |
+
|
30 |
+
from clip import tokenize
|
31 |
+
|
32 |
+
|
33 |
+
class CsvDataset(Dataset):
|
34 |
+
def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
|
35 |
+
logging.debug(f'Loading csv data from {input_filename}.')
|
36 |
+
df = pd.read_csv(input_filename, sep=sep)
|
37 |
+
|
38 |
+
self.images = df[img_key].tolist()
|
39 |
+
self.captions = df[caption_key].tolist()
|
40 |
+
self.transforms = transforms
|
41 |
+
logging.debug('Done loading data.')
|
42 |
+
|
43 |
+
def __len__(self):
|
44 |
+
return len(self.captions)
|
45 |
+
|
46 |
+
def __getitem__(self, idx):
|
47 |
+
images = self.transforms(Image.open(str(self.images[idx])))
|
48 |
+
texts = tokenize([str(self.captions[idx])])[0]
|
49 |
+
return images, texts
|
50 |
+
|
51 |
+
|
52 |
+
class SharedEpoch:
|
53 |
+
def __init__(self, epoch: int = 0):
|
54 |
+
self.shared_epoch = Value('i', epoch)
|
55 |
+
|
56 |
+
def set_value(self, epoch):
|
57 |
+
self.shared_epoch.value = epoch
|
58 |
+
|
59 |
+
def get_value(self):
|
60 |
+
return self.shared_epoch.value
|
61 |
+
|
62 |
+
|
63 |
+
@dataclass
|
64 |
+
class DataInfo:
|
65 |
+
dataloader: DataLoader
|
66 |
+
sampler: DistributedSampler = None
|
67 |
+
shared_epoch: SharedEpoch = None
|
68 |
+
|
69 |
+
def set_epoch(self, epoch):
|
70 |
+
if self.shared_epoch is not None:
|
71 |
+
self.shared_epoch.set_value(epoch)
|
72 |
+
if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
|
73 |
+
self.sampler.set_epoch(epoch)
|
74 |
+
|
75 |
+
|
76 |
+
def preprocess_txt(text):
|
77 |
+
return tokenize([str(text)])[0]
|
78 |
+
|
79 |
+
|
80 |
+
def get_dataset_size(shards):
|
81 |
+
shards_list = list(braceexpand.braceexpand(shards))
|
82 |
+
dir_path = os.path.dirname(shards)
|
83 |
+
sizes_filename = os.path.join(dir_path, 'sizes.json')
|
84 |
+
len_filename = os.path.join(dir_path, '__len__')
|
85 |
+
if os.path.exists(sizes_filename):
|
86 |
+
sizes = json.load(open(sizes_filename, 'r'))
|
87 |
+
total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
|
88 |
+
elif os.path.exists(len_filename):
|
89 |
+
# FIXME this used to be eval(open(...)) but that seemed rather unsafe
|
90 |
+
total_size = ast.literal_eval(open(len_filename, 'r').read())
|
91 |
+
else:
|
92 |
+
total_size = None # num samples undefined
|
93 |
+
# some common dataset sizes (at time of authors last download)
|
94 |
+
# CC3M (train): 2905954
|
95 |
+
# CC12M: 10968539
|
96 |
+
# LAION-400M: 407332084
|
97 |
+
# LAION-2B (english): 2170337258
|
98 |
+
num_shards = len(shards_list)
|
99 |
+
return total_size, num_shards
|
100 |
+
|
101 |
+
|
def get_imagenet(args, preprocess_fns, split):
    assert split in ["train", "val", "v2"]
    is_train = split == "train"
    preprocess_train, preprocess_val = preprocess_fns

    if split == "v2":
        from imagenetv2_pytorch import ImageNetV2Dataset
        dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
    else:
        if is_train:
            data_path = args.imagenet_train
            preprocess_fn = preprocess_train
        else:
            data_path = args.imagenet_val
            preprocess_fn = preprocess_val
        assert data_path

        dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)

    if is_train:
        idxs = np.zeros(len(dataset.targets))
        target_array = np.array(dataset.targets)
        k = 50
        for c in range(1000):
            m = target_array == c
            n = len(idxs[m])
            arr = np.zeros(n)
            arr[:k] = 1
            np.random.shuffle(arr)
            idxs[m] = arr

        idxs = idxs.astype('int')
        sampler = SubsetRandomSampler(np.where(idxs)[0])
    else:
        sampler = None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        sampler=sampler,
    )

    return DataInfo(dataloader=dataloader, sampler=sampler)

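# Illustrative sketch, not part of the original file: the training branch above keeps at
# most k=50 randomly chosen images per ImageNet class and feeds the surviving indices to
# SubsetRandomSampler; the helper below isolates that subsampling step on a target list.
def _example_per_class_subsample(targets, k=50, num_classes=1000):
    targets = np.array(targets)
    keep = np.zeros(len(targets))
    for c in range(num_classes):
        mask = targets == c
        flags = np.zeros(mask.sum())
        flags[:k] = 1  # the slice clamps when a class has fewer than k images
        np.random.shuffle(flags)
        keep[mask] = flags
    return np.where(keep.astype('int'))[0]
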
def count_samples(dataloader):
    os.environ["WDS_EPOCH"] = "0"
    n_elements, n_batches = 0, 0
    for images, texts in dataloader:
        n_batches += 1
        n_elements += len(images)
        assert len(images) == len(texts)
    return n_elements, n_batches


def filter_no_caption(sample):
    return 'txt' in sample


def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True


def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
        # begins, rare, but can happen since prefixes aren't unique across tar files in that dataset
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample


def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples

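# Illustrative sketch, not part of the original file: group_by_keys_nothrow merges tar
# members that share a base name (as split by base_plus_ext) into one sample dict. The
# file names and payloads below are hypothetical.
def _example_group_by_keys():
    files = [
        {"fname": "000000.jpg", "data": b"<jpeg bytes>", "__url__": "shard-000.tar"},
        {"fname": "000000.txt", "data": b"a photo of a dog", "__url__": "shard-000.tar"},
    ]
    samples = list(group_by_keys_nothrow(files))
    # -> [{"__key__": "000000", "__url__": "shard-000.tar", "jpg": b"...", "txt": b"..."}]
    return samples
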
def pytorch_worker_seed():
    """get dataloader worker seed from pytorch"""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour the seed already created for pytorch dataloader workers if it exists
        return worker_info.seed
    # fallback to wds rank based seed
    return wds.utils.pytorch_worker_seed()


_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000

class detshuffle2(wds.PipelineStage):
    def __init__(
        self,
        bufsize=1000,
        initial=100,
        seed=0,
        epoch=-1,
    ):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        if self.seed < 0:
            seed = pytorch_worker_seed() + epoch
        else:
            seed = self.seed + epoch
        rng.seed(seed)
        return _shuffle(src, self.bufsize, self.initial, rng)

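# Illustrative sketch, not part of the original file: with args.seed >= 0, detshuffle2
# seeds its shard shuffle with (seed + epoch), so every rank and worker shuffles the shard
# list identically for a given epoch before split_by_node/split_by_worker partitions it.
def _example_detshuffle_seed(seed=0, epoch=3):
    rng = random.Random()
    rng.seed(seed + epoch)  # mirrors the seed >= 0 branch of detshuffle2.run
    return rng.random()
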
class ResampledShards2(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls = wds.shardlists.expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed
            self.rng.seed(self.worker_seed() + epoch)
        for _ in range(self.nshards):
            yield dict(url=self.rng.choice(self.urls))


def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False):
    input_shards = args.train_data if is_train else args.val_data
    assert input_shards is not None
    resampled = getattr(args, 'dataset_resampled', False) and is_train

    num_samples, num_shards = get_dataset_size(input_shards)
    if not num_samples:
        if is_train:
            num_samples = args.train_num_samples
            if not num_samples:
                raise RuntimeError(
                    'Currently, number of dataset samples must be specified for training dataset. '
                    'Please specify via `--train-num-samples` if no dataset length info present.')
        else:
            num_samples = args.val_num_samples or 0  # eval will just exhaust the iterator if not specified

    shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc
    if resampled:
        pipeline = [ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)]
    else:
        pipeline = [wds.SimpleShardList(input_shards)]

    # at this point we have an iterator over all the shards
    if is_train:
        if not resampled:
            pipeline.extend([
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                    epoch=shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ])
        pipeline.extend([
            # at this point, we have an iterator over the shards assigned to each worker at each node
            tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),
            wds.shuffle(
                bufsize=_SAMPLE_SHUFFLE_SIZE,
                initial=_SAMPLE_SHUFFLE_INITIAL,
            ),
        ])
    else:
        pipeline.extend([
            wds.split_by_worker,
            # at this point, we have an iterator over the shards assigned to each worker
            wds.tarfile_to_samples(handler=log_and_continue),
        ])
    pipeline.extend([
        wds.select(filter_no_caption),
        wds.decode("pilrgb", handler=log_and_continue),
        wds.rename(image="jpg;png", text="txt"),
        wds.map_dict(image=preprocess_img, text=preprocess_txt),
        wds.to_tuple("image", "text"),
        wds.batched(args.batch_size, partial=not is_train),
    ])

    dataset = wds.DataPipeline(*pipeline)
    if is_train:
        if not resampled:
            assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
        # roll over and repeat a few samples to get same number of full batches on each node
        round_fn = math.floor if floor else math.ceil
        global_batch_size = args.batch_size * args.world_size
        num_batches = round_fn(num_samples / global_batch_size)
        num_workers = max(1, args.workers)
        num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size
        dataset = dataset.with_epoch(num_worker_batches)  # each worker is iterating over this
    else:
        # last batches are partial, eval is done on single (master) node
        num_batches = math.ceil(num_samples / args.batch_size)

    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=True,
    )

    # FIXME not clear which approach is better, with_epoch before vs after dataloader?
    # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
    # if is_train:
    #     # roll over and repeat a few samples to get same number of full batches on each node
    #     global_batch_size = args.batch_size * args.world_size
    #     num_batches = math.ceil(num_samples / global_batch_size)
    #     num_workers = max(1, args.workers)
    #     num_batches = math.ceil(num_batches / num_workers) * num_workers
    #     num_samples = num_batches * global_batch_size
    #     dataloader = dataloader.with_epoch(num_batches)
    # else:
    #     # last batches are partial, eval is done on single (master) node
    #     num_batches = math.ceil(num_samples / args.batch_size)

    # add meta-data to dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)

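# Illustrative usage sketch, not part of the original file: the attribute names mirror the
# ones get_wds_dataset reads; the shard pattern and sample counts are hypothetical.
def _example_get_wds_dataset(preprocess_img):
    from argparse import Namespace

    args = Namespace(
        train_data="/data/laion/{00000..00099}.tar",
        val_data=None,
        dataset_resampled=False,
        train_num_samples=1_000_000,  # fallback when no sizes.json or __len__ file exists
        val_num_samples=None,
        batch_size=256,
        workers=4,
        world_size=1,
        seed=0,
    )
    info = get_wds_dataset(args, preprocess_img, is_train=True, epoch=0)
    return info.dataloader.num_batches, info.dataloader.num_samples
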
def get_csv_dataset(args, preprocess_fn, is_train, epoch=0):
    input_filename = args.train_data if is_train else args.val_data
    assert input_filename
    dataset = CsvDataset(
        input_filename,
        preprocess_fn,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep=args.csv_separator)
    num_samples = len(dataset)
    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and sampler is None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)


def get_dataset_fn(data_path, dataset_type):
    if dataset_type == "webdataset":
        return get_wds_dataset
    elif dataset_type == "csv":
        return get_csv_dataset
    elif dataset_type == "auto":
        ext = data_path.split('.')[-1]
        if ext in ['csv', 'tsv']:
            return get_csv_dataset
        elif ext in ['tar']:
            return get_wds_dataset
        else:
            raise ValueError(
                f"Tried to figure out dataset type, but failed for extension {ext}.")
    else:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")


def get_data(args, preprocess_fns, epoch=0):
    preprocess_train, preprocess_val = preprocess_fns
    data = {}

    if args.train_data:
        data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
            args, preprocess_train, is_train=True, epoch=epoch)

    if args.val_data:
        data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
            args, preprocess_val, is_train=False)

    if args.imagenet_val is not None:
        data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")

    if args.imagenet_v2 is not None:
        data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")

    return data
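# Illustrative usage sketch, not part of the original file: get_data dispatches on
# args.dataset_type ("csv", "webdataset", or "auto" by file extension) and returns a dict
# of DataInfo objects keyed by split. The argument values below are hypothetical.
def _example_get_data(preprocess_train, preprocess_val):
    from argparse import Namespace

    args = Namespace(
        train_data="cc3m_train.tsv",
        val_data=None,
        dataset_type="auto",       # ".tsv" resolves to get_csv_dataset
        csv_img_key="filepath",
        csv_caption_key="title",
        csv_separator="\t",
        batch_size=128,
        workers=4,
        distributed=False,
        imagenet_val=None,
        imagenet_v2=None,
    )
    data = get_data(args, (preprocess_train, preprocess_val), epoch=0)
    return data["train"].dataloader.num_batches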
open_vocab_seg/data/datasets/register_ade20k_full.py
ADDED
@@ -0,0 +1,995 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import os

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import load_sem_seg

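# Illustrative sketch, not part of the original register_ade20k_full.py: a minimal example
# of how a category table like ADE20K_SEM_SEG_FULL_CATEGORIES (defined below) is typically
# exposed through Detectron2's catalogs. The dataset key, directory layout, and file
# extensions here are hypothetical.
def _example_register_ade20k_full(image_dir, gt_dir):
    stuff_classes = [c["name"] for c in ADE20K_SEM_SEG_FULL_CATEGORIES]
    DatasetCatalog.register(
        "example_ade20k_full_sem_seg",
        lambda: load_sem_seg(gt_dir, image_dir, gt_ext="tif", image_ext="jpg"),
    )
    MetadataCatalog.get("example_ade20k_full_sem_seg").set(
        stuff_classes=stuff_classes,
        evaluator_type="sem_seg",
        ignore_label=65535,  # assumed: ADE20K-full ground truth uses uint16 labels
    )
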
7 |
+
ADE20K_SEM_SEG_FULL_CATEGORIES = [
|
8 |
+
{"name": "wall", "id": 2978, "trainId": 0},
|
9 |
+
{"name": "building, edifice", "id": 312, "trainId": 1},
|
10 |
+
{"name": "sky", "id": 2420, "trainId": 2},
|
11 |
+
{"name": "tree", "id": 2855, "trainId": 3},
|
12 |
+
{"name": "road, route", "id": 2131, "trainId": 4},
|
13 |
+
{"name": "floor, flooring", "id": 976, "trainId": 5},
|
14 |
+
{"name": "ceiling", "id": 447, "trainId": 6},
|
15 |
+
{"name": "bed", "id": 165, "trainId": 7},
|
16 |
+
{"name": "sidewalk, pavement", "id": 2377, "trainId": 8},
|
17 |
+
{"name": "earth, ground", "id": 838, "trainId": 9},
|
18 |
+
{"name": "cabinet", "id": 350, "trainId": 10},
|
19 |
+
{
|
20 |
+
"name": "person, individual, someone, somebody, mortal, soul",
|
21 |
+
"id": 1831,
|
22 |
+
"trainId": 11,
|
23 |
+
},
|
24 |
+
{"name": "grass", "id": 1125, "trainId": 12},
|
25 |
+
{"name": "windowpane, window", "id": 3055, "trainId": 13},
|
26 |
+
{"name": "car, auto, automobile, machine, motorcar", "id": 401, "trainId": 14},
|
27 |
+
{"name": "mountain, mount", "id": 1610, "trainId": 15},
|
28 |
+
{"name": "plant, flora, plant life", "id": 1910, "trainId": 16},
|
29 |
+
{"name": "table", "id": 2684, "trainId": 17},
|
30 |
+
{"name": "chair", "id": 471, "trainId": 18},
|
31 |
+
{"name": "curtain, drape, drapery, mantle, pall", "id": 687, "trainId": 19},
|
32 |
+
{"name": "door", "id": 774, "trainId": 20},
|
33 |
+
{"name": "sofa, couch, lounge", "id": 2473, "trainId": 21},
|
34 |
+
{"name": "sea", "id": 2264, "trainId": 22},
|
35 |
+
{"name": "painting, picture", "id": 1735, "trainId": 23},
|
36 |
+
{"name": "water", "id": 2994, "trainId": 24},
|
37 |
+
{"name": "mirror", "id": 1564, "trainId": 25},
|
38 |
+
{"name": "house", "id": 1276, "trainId": 26},
|
39 |
+
{"name": "rug, carpet, carpeting", "id": 2178, "trainId": 27},
|
40 |
+
{"name": "shelf", "id": 2329, "trainId": 28},
|
41 |
+
{"name": "armchair", "id": 57, "trainId": 29},
|
42 |
+
{"name": "fence, fencing", "id": 907, "trainId": 30},
|
43 |
+
{"name": "field", "id": 913, "trainId": 31},
|
44 |
+
{"name": "lamp", "id": 1395, "trainId": 32},
|
45 |
+
{"name": "rock, stone", "id": 2138, "trainId": 33},
|
46 |
+
{"name": "seat", "id": 2272, "trainId": 34},
|
47 |
+
{"name": "river", "id": 2128, "trainId": 35},
|
48 |
+
{"name": "desk", "id": 724, "trainId": 36},
|
49 |
+
{"name": "bathtub, bathing tub, bath, tub", "id": 155, "trainId": 37},
|
50 |
+
{"name": "railing, rail", "id": 2053, "trainId": 38},
|
51 |
+
{"name": "signboard, sign", "id": 2380, "trainId": 39},
|
52 |
+
{"name": "cushion", "id": 689, "trainId": 40},
|
53 |
+
{"name": "path", "id": 1788, "trainId": 41},
|
54 |
+
{"name": "work surface", "id": 3087, "trainId": 42},
|
55 |
+
{"name": "stairs, steps", "id": 2530, "trainId": 43},
|
56 |
+
{"name": "column, pillar", "id": 581, "trainId": 44},
|
57 |
+
{"name": "sink", "id": 2388, "trainId": 45},
|
58 |
+
{"name": "wardrobe, closet, press", "id": 2985, "trainId": 46},
|
59 |
+
{"name": "snow", "id": 2454, "trainId": 47},
|
60 |
+
{"name": "refrigerator, icebox", "id": 2096, "trainId": 48},
|
61 |
+
{"name": "base, pedestal, stand", "id": 137, "trainId": 49},
|
62 |
+
{"name": "bridge, span", "id": 294, "trainId": 50},
|
63 |
+
{"name": "blind, screen", "id": 212, "trainId": 51},
|
64 |
+
{"name": "runway", "id": 2185, "trainId": 52},
|
65 |
+
{"name": "cliff, drop, drop-off", "id": 524, "trainId": 53},
|
66 |
+
{"name": "sand", "id": 2212, "trainId": 54},
|
67 |
+
{"name": "fireplace, hearth, open fireplace", "id": 943, "trainId": 55},
|
68 |
+
{"name": "pillow", "id": 1869, "trainId": 56},
|
69 |
+
{"name": "screen door, screen", "id": 2251, "trainId": 57},
|
70 |
+
{
|
71 |
+
"name": "toilet, can, commode, crapper, pot, potty, stool, throne",
|
72 |
+
"id": 2793,
|
73 |
+
"trainId": 58,
|
74 |
+
},
|
75 |
+
{"name": "skyscraper", "id": 2423, "trainId": 59},
|
76 |
+
{"name": "grandstand, covered stand", "id": 1121, "trainId": 60},
|
77 |
+
{"name": "box", "id": 266, "trainId": 61},
|
78 |
+
{"name": "pool table, billiard table, snooker table", "id": 1948, "trainId": 62},
|
79 |
+
{"name": "palm, palm tree", "id": 1744, "trainId": 63},
|
80 |
+
{"name": "double door", "id": 783, "trainId": 64},
|
81 |
+
{"name": "coffee table, cocktail table", "id": 571, "trainId": 65},
|
82 |
+
{"name": "counter", "id": 627, "trainId": 66},
|
83 |
+
{"name": "countertop", "id": 629, "trainId": 67},
|
84 |
+
{"name": "chest of drawers, chest, bureau, dresser", "id": 491, "trainId": 68},
|
85 |
+
{"name": "kitchen island", "id": 1374, "trainId": 69},
|
86 |
+
{"name": "boat", "id": 223, "trainId": 70},
|
87 |
+
{"name": "waterfall, falls", "id": 3016, "trainId": 71},
|
88 |
+
{
|
89 |
+
"name": "stove, kitchen stove, range, kitchen range, cooking stove",
|
90 |
+
"id": 2598,
|
91 |
+
"trainId": 72,
|
92 |
+
},
|
93 |
+
{"name": "flower", "id": 978, "trainId": 73},
|
94 |
+
{"name": "bookcase", "id": 239, "trainId": 74},
|
95 |
+
{"name": "controls", "id": 608, "trainId": 75},
|
96 |
+
{"name": "book", "id": 236, "trainId": 76},
|
97 |
+
{"name": "stairway, staircase", "id": 2531, "trainId": 77},
|
98 |
+
{"name": "streetlight, street lamp", "id": 2616, "trainId": 78},
|
99 |
+
{
|
100 |
+
"name": "computer, computing machine, computing device, data processor, electronic computer, information processing system",
|
101 |
+
"id": 591,
|
102 |
+
"trainId": 79,
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"name": "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle",
|
106 |
+
"id": 327,
|
107 |
+
"trainId": 80,
|
108 |
+
},
|
109 |
+
{"name": "swivel chair", "id": 2679, "trainId": 81},
|
110 |
+
{"name": "light, light source", "id": 1451, "trainId": 82},
|
111 |
+
{"name": "bench", "id": 181, "trainId": 83},
|
112 |
+
{"name": "case, display case, showcase, vitrine", "id": 420, "trainId": 84},
|
113 |
+
{"name": "towel", "id": 2821, "trainId": 85},
|
114 |
+
{"name": "fountain", "id": 1023, "trainId": 86},
|
115 |
+
{"name": "embankment", "id": 855, "trainId": 87},
|
116 |
+
{
|
117 |
+
"name": "television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box",
|
118 |
+
"id": 2733,
|
119 |
+
"trainId": 88,
|
120 |
+
},
|
121 |
+
{"name": "van", "id": 2928, "trainId": 89},
|
122 |
+
{"name": "hill", "id": 1240, "trainId": 90},
|
123 |
+
{"name": "awning, sunshade, sunblind", "id": 77, "trainId": 91},
|
124 |
+
{"name": "poster, posting, placard, notice, bill, card", "id": 1969, "trainId": 92},
|
125 |
+
{"name": "truck, motortruck", "id": 2880, "trainId": 93},
|
126 |
+
{"name": "airplane, aeroplane, plane", "id": 14, "trainId": 94},
|
127 |
+
{"name": "pole", "id": 1936, "trainId": 95},
|
128 |
+
{"name": "tower", "id": 2828, "trainId": 96},
|
129 |
+
{"name": "court", "id": 631, "trainId": 97},
|
130 |
+
{"name": "ball", "id": 103, "trainId": 98},
|
131 |
+
{
|
132 |
+
"name": "aircraft carrier, carrier, flattop, attack aircraft carrier",
|
133 |
+
"id": 3144,
|
134 |
+
"trainId": 99,
|
135 |
+
},
|
136 |
+
{"name": "buffet, counter, sideboard", "id": 308, "trainId": 100},
|
137 |
+
{"name": "hovel, hut, hutch, shack, shanty", "id": 1282, "trainId": 101},
|
138 |
+
{"name": "apparel, wearing apparel, dress, clothes", "id": 38, "trainId": 102},
|
139 |
+
{"name": "minibike, motorbike", "id": 1563, "trainId": 103},
|
140 |
+
{
|
141 |
+
"name": "animal, animate being, beast, brute, creature, fauna",
|
142 |
+
"id": 29,
|
143 |
+
"trainId": 104,
|
144 |
+
},
|
145 |
+
{"name": "chandelier, pendant, pendent", "id": 480, "trainId": 105},
|
146 |
+
{"name": "step, stair", "id": 2569, "trainId": 106},
|
147 |
+
{"name": "booth, cubicle, stall, kiosk", "id": 247, "trainId": 107},
|
148 |
+
{"name": "bicycle, bike, wheel, cycle", "id": 187, "trainId": 108},
|
149 |
+
{"name": "doorframe, doorcase", "id": 778, "trainId": 109},
|
150 |
+
{"name": "sconce", "id": 2243, "trainId": 110},
|
151 |
+
{"name": "pond", "id": 1941, "trainId": 111},
|
152 |
+
{"name": "trade name, brand name, brand, marque", "id": 2833, "trainId": 112},
|
153 |
+
{
|
154 |
+
"name": "bannister, banister, balustrade, balusters, handrail",
|
155 |
+
"id": 120,
|
156 |
+
"trainId": 113,
|
157 |
+
},
|
158 |
+
{"name": "bag", "id": 95, "trainId": 114},
|
159 |
+
{"name": "traffic light, traffic signal, stoplight", "id": 2836, "trainId": 115},
|
160 |
+
{"name": "gazebo", "id": 1087, "trainId": 116},
|
161 |
+
{"name": "escalator, moving staircase, moving stairway", "id": 868, "trainId": 117},
|
162 |
+
{"name": "land, ground, soil", "id": 1401, "trainId": 118},
|
163 |
+
{"name": "board, plank", "id": 220, "trainId": 119},
|
164 |
+
{"name": "arcade machine", "id": 47, "trainId": 120},
|
165 |
+
{"name": "eiderdown, duvet, continental quilt", "id": 843, "trainId": 121},
|
166 |
+
{"name": "bar", "id": 123, "trainId": 122},
|
167 |
+
{"name": "stall, stand, sales booth", "id": 2537, "trainId": 123},
|
168 |
+
{"name": "playground", "id": 1927, "trainId": 124},
|
169 |
+
{"name": "ship", "id": 2337, "trainId": 125},
|
170 |
+
{"name": "ottoman, pouf, pouffe, puff, hassock", "id": 1702, "trainId": 126},
|
171 |
+
{
|
172 |
+
"name": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
|
173 |
+
"id": 64,
|
174 |
+
"trainId": 127,
|
175 |
+
},
|
176 |
+
{"name": "bottle", "id": 249, "trainId": 128},
|
177 |
+
{"name": "cradle", "id": 642, "trainId": 129},
|
178 |
+
{"name": "pot, flowerpot", "id": 1981, "trainId": 130},
|
179 |
+
{
|
180 |
+
"name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
|
181 |
+
"id": 609,
|
182 |
+
"trainId": 131,
|
183 |
+
},
|
184 |
+
{"name": "train, railroad train", "id": 2840, "trainId": 132},
|
185 |
+
{"name": "stool", "id": 2586, "trainId": 133},
|
186 |
+
{"name": "lake", "id": 1393, "trainId": 134},
|
187 |
+
{"name": "tank, storage tank", "id": 2704, "trainId": 135},
|
188 |
+
{"name": "ice, water ice", "id": 1304, "trainId": 136},
|
189 |
+
{"name": "basket, handbasket", "id": 146, "trainId": 137},
|
190 |
+
{"name": "manhole", "id": 1494, "trainId": 138},
|
191 |
+
{"name": "tent, collapsible shelter", "id": 2739, "trainId": 139},
|
192 |
+
{"name": "canopy", "id": 389, "trainId": 140},
|
193 |
+
{"name": "microwave, microwave oven", "id": 1551, "trainId": 141},
|
194 |
+
{"name": "barrel, cask", "id": 131, "trainId": 142},
|
195 |
+
{"name": "dirt track", "id": 738, "trainId": 143},
|
196 |
+
{"name": "beam", "id": 161, "trainId": 144},
|
197 |
+
{"name": "dishwasher, dish washer, dishwashing machine", "id": 747, "trainId": 145},
|
198 |
+
{"name": "plate", "id": 1919, "trainId": 146},
|
199 |
+
{"name": "screen, crt screen", "id": 3109, "trainId": 147},
|
200 |
+
{"name": "ruins", "id": 2179, "trainId": 148},
|
201 |
+
{"name": "washer, automatic washer, washing machine", "id": 2989, "trainId": 149},
|
202 |
+
{"name": "blanket, cover", "id": 206, "trainId": 150},
|
203 |
+
{"name": "plaything, toy", "id": 1930, "trainId": 151},
|
204 |
+
{"name": "food, solid food", "id": 1002, "trainId": 152},
|
205 |
+
{"name": "screen, silver screen, projection screen", "id": 2254, "trainId": 153},
|
206 |
+
{"name": "oven", "id": 1708, "trainId": 154},
|
207 |
+
{"name": "stage", "id": 2526, "trainId": 155},
|
208 |
+
{"name": "beacon, lighthouse, beacon light, pharos", "id": 160, "trainId": 156},
|
209 |
+
{"name": "umbrella", "id": 2901, "trainId": 157},
|
210 |
+
{"name": "sculpture", "id": 2262, "trainId": 158},
|
211 |
+
{"name": "aqueduct", "id": 44, "trainId": 159},
|
212 |
+
{"name": "container", "id": 597, "trainId": 160},
|
213 |
+
{"name": "scaffolding, staging", "id": 2235, "trainId": 161},
|
214 |
+
{"name": "hood, exhaust hood", "id": 1260, "trainId": 162},
|
215 |
+
{"name": "curb, curbing, kerb", "id": 682, "trainId": 163},
|
216 |
+
{"name": "roller coaster", "id": 2151, "trainId": 164},
|
217 |
+
{"name": "horse, equus caballus", "id": 3107, "trainId": 165},
|
218 |
+
{"name": "catwalk", "id": 432, "trainId": 166},
|
219 |
+
{"name": "glass, drinking glass", "id": 1098, "trainId": 167},
|
220 |
+
{"name": "vase", "id": 2932, "trainId": 168},
|
221 |
+
{"name": "central reservation", "id": 461, "trainId": 169},
|
222 |
+
{"name": "carousel", "id": 410, "trainId": 170},
|
223 |
+
{"name": "radiator", "id": 2046, "trainId": 171},
|
224 |
+
{"name": "closet", "id": 533, "trainId": 172},
|
225 |
+
{"name": "machine", "id": 1481, "trainId": 173},
|
226 |
+
{"name": "pier, wharf, wharfage, dock", "id": 1858, "trainId": 174},
|
227 |
+
{"name": "fan", "id": 894, "trainId": 175},
|
228 |
+
{"name": "inflatable bounce game", "id": 1322, "trainId": 176},
|
229 |
+
{"name": "pitch", "id": 1891, "trainId": 177},
|
230 |
+
{"name": "paper", "id": 1756, "trainId": 178},
|
231 |
+
{"name": "arcade, colonnade", "id": 49, "trainId": 179},
|
232 |
+
{"name": "hot tub", "id": 1272, "trainId": 180},
|
233 |
+
{"name": "helicopter", "id": 1229, "trainId": 181},
|
234 |
+
{"name": "tray", "id": 2850, "trainId": 182},
|
235 |
+
{"name": "partition, divider", "id": 1784, "trainId": 183},
|
236 |
+
{"name": "vineyard", "id": 2962, "trainId": 184},
|
237 |
+
{"name": "bowl", "id": 259, "trainId": 185},
|
238 |
+
{"name": "bullring", "id": 319, "trainId": 186},
|
239 |
+
{"name": "flag", "id": 954, "trainId": 187},
|
240 |
+
{"name": "pot", "id": 1974, "trainId": 188},
|
241 |
+
{"name": "footbridge, overcrossing, pedestrian bridge", "id": 1013, "trainId": 189},
|
242 |
+
{"name": "shower", "id": 2356, "trainId": 190},
|
243 |
+
{
|
244 |
+
"name": "bag, traveling bag, travelling bag, grip, suitcase",
|
245 |
+
"id": 97,
|
246 |
+
"trainId": 191,
|
247 |
+
},
|
248 |
+
{"name": "bulletin board, notice board", "id": 318, "trainId": 192},
|
249 |
+
{"name": "confessional booth", "id": 592, "trainId": 193},
|
250 |
+
{"name": "trunk, tree trunk, bole", "id": 2885, "trainId": 194},
|
251 |
+
{"name": "forest", "id": 1017, "trainId": 195},
|
252 |
+
{"name": "elevator door", "id": 851, "trainId": 196},
|
253 |
+
{"name": "laptop, laptop computer", "id": 1407, "trainId": 197},
|
254 |
+
{"name": "instrument panel", "id": 1332, "trainId": 198},
|
255 |
+
{"name": "bucket, pail", "id": 303, "trainId": 199},
|
256 |
+
{"name": "tapestry, tapis", "id": 2714, "trainId": 200},
|
257 |
+
{"name": "platform", "id": 1924, "trainId": 201},
|
258 |
+
{"name": "jacket", "id": 1346, "trainId": 202},
|
259 |
+
{"name": "gate", "id": 1081, "trainId": 203},
|
260 |
+
{"name": "monitor, monitoring device", "id": 1583, "trainId": 204},
|
261 |
+
{
|
262 |
+
"name": "telephone booth, phone booth, call box, telephone box, telephone kiosk",
|
263 |
+
"id": 2727,
|
264 |
+
"trainId": 205,
|
265 |
+
},
|
266 |
+
{"name": "spotlight, spot", "id": 2509, "trainId": 206},
|
267 |
+
{"name": "ring", "id": 2123, "trainId": 207},
|
268 |
+
{"name": "control panel", "id": 602, "trainId": 208},
|
269 |
+
{"name": "blackboard, chalkboard", "id": 202, "trainId": 209},
|
270 |
+
{"name": "air conditioner, air conditioning", "id": 10, "trainId": 210},
|
271 |
+
{"name": "chest", "id": 490, "trainId": 211},
|
272 |
+
{"name": "clock", "id": 530, "trainId": 212},
|
273 |
+
{"name": "sand dune", "id": 2213, "trainId": 213},
|
274 |
+
{"name": "pipe, pipage, piping", "id": 1884, "trainId": 214},
|
275 |
+
{"name": "vault", "id": 2934, "trainId": 215},
|
276 |
+
{"name": "table football", "id": 2687, "trainId": 216},
|
277 |
+
{"name": "cannon", "id": 387, "trainId": 217},
|
278 |
+
{"name": "swimming pool, swimming bath, natatorium", "id": 2668, "trainId": 218},
|
279 |
+
{"name": "fluorescent, fluorescent fixture", "id": 982, "trainId": 219},
|
280 |
+
{"name": "statue", "id": 2547, "trainId": 220},
|
281 |
+
{
|
282 |
+
"name": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
|
283 |
+
"id": 1474,
|
284 |
+
"trainId": 221,
|
285 |
+
},
|
286 |
+
{"name": "exhibitor", "id": 877, "trainId": 222},
|
287 |
+
{"name": "ladder", "id": 1391, "trainId": 223},
|
288 |
+
{"name": "carport", "id": 414, "trainId": 224},
|
289 |
+
{"name": "dam", "id": 698, "trainId": 225},
|
290 |
+
{"name": "pulpit", "id": 2019, "trainId": 226},
|
291 |
+
{"name": "skylight, fanlight", "id": 2422, "trainId": 227},
|
292 |
+
{"name": "water tower", "id": 3010, "trainId": 228},
|
293 |
+
{"name": "grill, grille, grillwork", "id": 1139, "trainId": 229},
|
294 |
+
{"name": "display board", "id": 753, "trainId": 230},
|
295 |
+
{"name": "pane, pane of glass, window glass", "id": 1747, "trainId": 231},
|
296 |
+
{"name": "rubbish, trash, scrap", "id": 2175, "trainId": 232},
|
297 |
+
{"name": "ice rink", "id": 1301, "trainId": 233},
|
298 |
+
{"name": "fruit", "id": 1033, "trainId": 234},
|
299 |
+
{"name": "patio", "id": 1789, "trainId": 235},
|
300 |
+
{"name": "vending machine", "id": 2939, "trainId": 236},
|
301 |
+
{"name": "telephone, phone, telephone set", "id": 2730, "trainId": 237},
|
302 |
+
{"name": "net", "id": 1652, "trainId": 238},
|
303 |
+
{
|
304 |
+
"name": "backpack, back pack, knapsack, packsack, rucksack, haversack",
|
305 |
+
"id": 90,
|
306 |
+
"trainId": 239,
|
307 |
+
},
|
308 |
+
{"name": "jar", "id": 1349, "trainId": 240},
|
309 |
+
{"name": "track", "id": 2830, "trainId": 241},
|
310 |
+
{"name": "magazine", "id": 1485, "trainId": 242},
|
311 |
+
{"name": "shutter", "id": 2370, "trainId": 243},
|
312 |
+
{"name": "roof", "id": 2155, "trainId": 244},
|
313 |
+
{"name": "banner, streamer", "id": 118, "trainId": 245},
|
314 |
+
{"name": "landfill", "id": 1402, "trainId": 246},
|
315 |
+
{"name": "post", "id": 1957, "trainId": 247},
|
316 |
+
{"name": "altarpiece, reredos", "id": 3130, "trainId": 248},
|
317 |
+
{"name": "hat, chapeau, lid", "id": 1197, "trainId": 249},
|
318 |
+
{"name": "arch, archway", "id": 52, "trainId": 250},
|
319 |
+
{"name": "table game", "id": 2688, "trainId": 251},
|
320 |
+
{"name": "bag, handbag, pocketbook, purse", "id": 96, "trainId": 252},
|
321 |
+
{"name": "document, written document, papers", "id": 762, "trainId": 253},
|
322 |
+
{"name": "dome", "id": 772, "trainId": 254},
|
323 |
+
{"name": "pier", "id": 1857, "trainId": 255},
|
324 |
+
{"name": "shanties", "id": 2315, "trainId": 256},
|
325 |
+
{"name": "forecourt", "id": 1016, "trainId": 257},
|
326 |
+
{"name": "crane", "id": 643, "trainId": 258},
|
327 |
+
{"name": "dog, domestic dog, canis familiaris", "id": 3105, "trainId": 259},
|
328 |
+
{"name": "piano, pianoforte, forte-piano", "id": 1849, "trainId": 260},
|
329 |
+
{"name": "drawing", "id": 791, "trainId": 261},
|
330 |
+
{"name": "cabin", "id": 349, "trainId": 262},
|
331 |
+
{
|
332 |
+
"name": "ad, advertisement, advertizement, advertising, advertizing, advert",
|
333 |
+
"id": 6,
|
334 |
+
"trainId": 263,
|
335 |
+
},
|
336 |
+
{"name": "amphitheater, amphitheatre, coliseum", "id": 3114, "trainId": 264},
|
337 |
+
{"name": "monument", "id": 1587, "trainId": 265},
|
338 |
+
{"name": "henhouse", "id": 1233, "trainId": 266},
|
339 |
+
{"name": "cockpit", "id": 559, "trainId": 267},
|
340 |
+
{"name": "heater, warmer", "id": 1223, "trainId": 268},
|
341 |
+
{"name": "windmill, aerogenerator, wind generator", "id": 3049, "trainId": 269},
|
342 |
+
{"name": "pool", "id": 1943, "trainId": 270},
|
343 |
+
{"name": "elevator, lift", "id": 853, "trainId": 271},
|
344 |
+
{"name": "decoration, ornament, ornamentation", "id": 709, "trainId": 272},
|
345 |
+
{"name": "labyrinth", "id": 1390, "trainId": 273},
|
346 |
+
{"name": "text, textual matter", "id": 2748, "trainId": 274},
|
347 |
+
{"name": "printer", "id": 2007, "trainId": 275},
|
348 |
+
{"name": "mezzanine, first balcony", "id": 1546, "trainId": 276},
|
349 |
+
{"name": "mattress", "id": 1513, "trainId": 277},
|
350 |
+
{"name": "straw", "id": 2600, "trainId": 278},
|
351 |
+
{"name": "stalls", "id": 2538, "trainId": 279},
|
352 |
+
{"name": "patio, terrace", "id": 1790, "trainId": 280},
|
353 |
+
{"name": "billboard, hoarding", "id": 194, "trainId": 281},
|
354 |
+
{"name": "bus stop", "id": 326, "trainId": 282},
|
355 |
+
{"name": "trouser, pant", "id": 2877, "trainId": 283},
|
356 |
+
{"name": "console table, console", "id": 594, "trainId": 284},
|
357 |
+
{"name": "rack", "id": 2036, "trainId": 285},
|
358 |
+
{"name": "notebook", "id": 1662, "trainId": 286},
|
359 |
+
{"name": "shrine", "id": 2366, "trainId": 287},
|
360 |
+
{"name": "pantry", "id": 1754, "trainId": 288},
|
361 |
+
{"name": "cart", "id": 418, "trainId": 289},
|
362 |
+
{"name": "steam shovel", "id": 2553, "trainId": 290},
|
363 |
+
{"name": "porch", "id": 1951, "trainId": 291},
|
364 |
+
{"name": "postbox, mailbox, letter box", "id": 1963, "trainId": 292},
|
365 |
+
{"name": "figurine, statuette", "id": 918, "trainId": 293},
|
366 |
+
{"name": "recycling bin", "id": 2086, "trainId": 294},
|
367 |
+
{"name": "folding screen", "id": 997, "trainId": 295},
|
368 |
+
{"name": "telescope", "id": 2731, "trainId": 296},
|
369 |
+
{"name": "deck chair, beach chair", "id": 704, "trainId": 297},
|
370 |
+
{"name": "kennel", "id": 1365, "trainId": 298},
|
371 |
+
{"name": "coffee maker", "id": 569, "trainId": 299},
|
372 |
+
{"name": "altar, communion table, lord's table", "id": 3108, "trainId": 300},
|
373 |
+
{"name": "fish", "id": 948, "trainId": 301},
|
374 |
+
{"name": "easel", "id": 839, "trainId": 302},
|
375 |
+
{"name": "artificial golf green", "id": 63, "trainId": 303},
|
376 |
+
{"name": "iceberg", "id": 1305, "trainId": 304},
|
377 |
+
{"name": "candlestick, candle holder", "id": 378, "trainId": 305},
|
378 |
+
{"name": "shower stall, shower bath", "id": 2362, "trainId": 306},
|
379 |
+
{"name": "television stand", "id": 2734, "trainId": 307},
|
380 |
+
{
|
381 |
+
"name": "wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle",
|
382 |
+
"id": 2982,
|
383 |
+
"trainId": 308,
|
384 |
+
},
|
385 |
+
{"name": "skeleton", "id": 2398, "trainId": 309},
|
386 |
+
{"name": "grand piano, grand", "id": 1119, "trainId": 310},
|
387 |
+
{"name": "candy, confect", "id": 382, "trainId": 311},
|
388 |
+
{"name": "grille door", "id": 1141, "trainId": 312},
|
389 |
+
{"name": "pedestal, plinth, footstall", "id": 1805, "trainId": 313},
|
390 |
+
{"name": "jersey, t-shirt, tee shirt", "id": 3102, "trainId": 314},
|
391 |
+
{"name": "shoe", "id": 2341, "trainId": 315},
|
392 |
+
{"name": "gravestone, headstone, tombstone", "id": 1131, "trainId": 316},
|
393 |
+
{"name": "shanty", "id": 2316, "trainId": 317},
|
394 |
+
{"name": "structure", "id": 2626, "trainId": 318},
|
395 |
+
{"name": "rocking chair, rocker", "id": 3104, "trainId": 319},
|
396 |
+
{"name": "bird", "id": 198, "trainId": 320},
|
397 |
+
{"name": "place mat", "id": 1896, "trainId": 321},
|
398 |
+
{"name": "tomb", "id": 2800, "trainId": 322},
|
399 |
+
{"name": "big top", "id": 190, "trainId": 323},
|
400 |
+
{
|
401 |
+
"name": "gas pump, gasoline pump, petrol pump, island dispenser",
|
402 |
+
"id": 3131,
|
403 |
+
"trainId": 324,
|
404 |
+
},
|
405 |
+
{"name": "lockers", "id": 1463, "trainId": 325},
|
406 |
+
{"name": "cage", "id": 357, "trainId": 326},
|
407 |
+
{"name": "finger", "id": 929, "trainId": 327},
|
408 |
+
{"name": "bleachers", "id": 209, "trainId": 328},
|
409 |
+
{"name": "ferris wheel", "id": 912, "trainId": 329},
|
410 |
+
{"name": "hairdresser chair", "id": 1164, "trainId": 330},
|
411 |
+
{"name": "mat", "id": 1509, "trainId": 331},
|
412 |
+
{"name": "stands", "id": 2539, "trainId": 332},
|
413 |
+
{"name": "aquarium, fish tank, marine museum", "id": 3116, "trainId": 333},
|
414 |
+
{
|
415 |
+
"name": "streetcar, tram, tramcar, trolley, trolley car",
|
416 |
+
"id": 2615,
|
417 |
+
"trainId": 334,
|
418 |
+
},
|
419 |
+
{"name": "napkin, table napkin, serviette", "id": 1644, "trainId": 335},
|
420 |
+
{"name": "dummy", "id": 818, "trainId": 336},
|
421 |
+
{"name": "booklet, brochure, folder, leaflet, pamphlet", "id": 242, "trainId": 337},
|
422 |
+
{"name": "sand trap", "id": 2217, "trainId": 338},
|
423 |
+
{"name": "shop, store", "id": 2347, "trainId": 339},
|
424 |
+
{"name": "table cloth", "id": 2686, "trainId": 340},
|
425 |
+
{"name": "service station", "id": 2300, "trainId": 341},
|
426 |
+
{"name": "coffin", "id": 572, "trainId": 342},
|
427 |
+
{"name": "drawer", "id": 789, "trainId": 343},
|
428 |
+
{"name": "cages", "id": 358, "trainId": 344},
|
429 |
+
{"name": "slot machine, coin machine", "id": 2443, "trainId": 345},
|
430 |
+
{"name": "balcony", "id": 101, "trainId": 346},
|
431 |
+
{"name": "volleyball court", "id": 2969, "trainId": 347},
|
432 |
+
{"name": "table tennis", "id": 2692, "trainId": 348},
|
433 |
+
{"name": "control table", "id": 606, "trainId": 349},
|
434 |
+
{"name": "shirt", "id": 2339, "trainId": 350},
|
435 |
+
{"name": "merchandise, ware, product", "id": 1533, "trainId": 351},
|
436 |
+
{"name": "railway", "id": 2060, "trainId": 352},
|
437 |
+
{"name": "parterre", "id": 1782, "trainId": 353},
|
438 |
+
{"name": "chimney", "id": 495, "trainId": 354},
|
439 |
+
{"name": "can, tin, tin can", "id": 371, "trainId": 355},
|
440 |
+
{"name": "tanks", "id": 2707, "trainId": 356},
|
441 |
+
{"name": "fabric, cloth, material, textile", "id": 889, "trainId": 357},
|
442 |
+
{"name": "alga, algae", "id": 3156, "trainId": 358},
|
443 |
+
{"name": "system", "id": 2683, "trainId": 359},
|
444 |
+
{"name": "map", "id": 1499, "trainId": 360},
|
445 |
+
{"name": "greenhouse", "id": 1135, "trainId": 361},
|
446 |
+
{"name": "mug", "id": 1619, "trainId": 362},
|
447 |
+
{"name": "barbecue", "id": 125, "trainId": 363},
|
448 |
+
{"name": "trailer", "id": 2838, "trainId": 364},
|
449 |
+
{
|
450 |
+
"name": "toilet tissue, toilet paper, bathroom tissue",
|
451 |
+
"id": 2792,
|
452 |
+
"trainId": 365,
|
453 |
+
},
|
454 |
+
{"name": "organ", "id": 1695, "trainId": 366},
|
455 |
+
{"name": "dishrag, dishcloth", "id": 746, "trainId": 367},
|
456 |
+
{"name": "island", "id": 1343, "trainId": 368},
|
457 |
+
{"name": "keyboard", "id": 1370, "trainId": 369},
|
458 |
+
{"name": "trench", "id": 2858, "trainId": 370},
|
459 |
+
{"name": "basket, basketball hoop, hoop", "id": 145, "trainId": 371},
|
460 |
+
{"name": "steering wheel, wheel", "id": 2565, "trainId": 372},
|
461 |
+
{"name": "pitcher, ewer", "id": 1892, "trainId": 373},
|
462 |
+
{"name": "goal", "id": 1103, "trainId": 374},
|
463 |
+
{"name": "bread, breadstuff, staff of life", "id": 286, "trainId": 375},
|
464 |
+
{"name": "beds", "id": 170, "trainId": 376},
|
465 |
+
{"name": "wood", "id": 3073, "trainId": 377},
|
466 |
+
{"name": "file cabinet", "id": 922, "trainId": 378},
|
467 |
+
{"name": "newspaper, paper", "id": 1655, "trainId": 379},
|
468 |
+
{"name": "motorboat", "id": 1602, "trainId": 380},
|
469 |
+
{"name": "rope", "id": 2160, "trainId": 381},
|
470 |
+
{"name": "guitar", "id": 1151, "trainId": 382},
|
471 |
+
{"name": "rubble", "id": 2176, "trainId": 383},
|
472 |
+
{"name": "scarf", "id": 2239, "trainId": 384},
|
473 |
+
{"name": "barrels", "id": 132, "trainId": 385},
|
474 |
+
{"name": "cap", "id": 394, "trainId": 386},
|
475 |
+
{"name": "leaves", "id": 1424, "trainId": 387},
|
476 |
+
{"name": "control tower", "id": 607, "trainId": 388},
|
477 |
+
{"name": "dashboard", "id": 700, "trainId": 389},
|
478 |
+
{"name": "bandstand", "id": 116, "trainId": 390},
|
479 |
+
{"name": "lectern", "id": 1425, "trainId": 391},
|
480 |
+
{"name": "switch, electric switch, electrical switch", "id": 2676, "trainId": 392},
|
481 |
+
{"name": "baseboard, mopboard, skirting board", "id": 141, "trainId": 393},
|
482 |
+
{"name": "shower room", "id": 2360, "trainId": 394},
|
483 |
+
{"name": "smoke", "id": 2449, "trainId": 395},
|
484 |
+
{"name": "faucet, spigot", "id": 897, "trainId": 396},
|
485 |
+
{"name": "bulldozer", "id": 317, "trainId": 397},
|
486 |
+
{"name": "saucepan", "id": 2228, "trainId": 398},
|
487 |
+
{"name": "shops", "id": 2351, "trainId": 399},
|
488 |
+
{"name": "meter", "id": 1543, "trainId": 400},
|
489 |
+
{"name": "crevasse", "id": 656, "trainId": 401},
|
490 |
+
{"name": "gear", "id": 1088, "trainId": 402},
|
491 |
+
{"name": "candelabrum, candelabra", "id": 373, "trainId": 403},
|
492 |
+
{"name": "sofa bed", "id": 2472, "trainId": 404},
|
493 |
+
{"name": "tunnel", "id": 2892, "trainId": 405},
|
494 |
+
{"name": "pallet", "id": 1740, "trainId": 406},
|
495 |
+
{"name": "wire, conducting wire", "id": 3067, "trainId": 407},
|
496 |
+
{"name": "kettle, boiler", "id": 1367, "trainId": 408},
|
497 |
+
{"name": "bidet", "id": 188, "trainId": 409},
|
498 |
+
{
|
499 |
+
"name": "baby buggy, baby carriage, carriage, perambulator, pram, stroller, go-cart, pushchair, pusher",
|
500 |
+
"id": 79,
|
501 |
+
"trainId": 410,
|
502 |
+
},
|
503 |
+
{"name": "music stand", "id": 1633, "trainId": 411},
|
504 |
+
{"name": "pipe, tube", "id": 1885, "trainId": 412},
|
505 |
+
{"name": "cup", "id": 677, "trainId": 413},
|
506 |
+
{"name": "parking meter", "id": 1779, "trainId": 414},
|
507 |
+
{"name": "ice hockey rink", "id": 1297, "trainId": 415},
|
508 |
+
{"name": "shelter", "id": 2334, "trainId": 416},
|
509 |
+
{"name": "weeds", "id": 3027, "trainId": 417},
|
510 |
+
{"name": "temple", "id": 2735, "trainId": 418},
|
511 |
+
{"name": "patty, cake", "id": 1791, "trainId": 419},
|
512 |
+
{"name": "ski slope", "id": 2405, "trainId": 420},
|
513 |
+
{"name": "panel", "id": 1748, "trainId": 421},
|
514 |
+
{"name": "wallet", "id": 2983, "trainId": 422},
|
515 |
+
{"name": "wheel", "id": 3035, "trainId": 423},
|
516 |
+
{"name": "towel rack, towel horse", "id": 2824, "trainId": 424},
|
517 |
+
{"name": "roundabout", "id": 2168, "trainId": 425},
|
518 |
+
{"name": "canister, cannister, tin", "id": 385, "trainId": 426},
|
519 |
+
{"name": "rod", "id": 2148, "trainId": 427},
|
520 |
+
{"name": "soap dispenser", "id": 2465, "trainId": 428},
|
521 |
+
{"name": "bell", "id": 175, "trainId": 429},
|
522 |
+
{"name": "canvas", "id": 390, "trainId": 430},
|
523 |
+
{"name": "box office, ticket office, ticket booth", "id": 268, "trainId": 431},
|
524 |
+
{"name": "teacup", "id": 2722, "trainId": 432},
|
525 |
+
{"name": "trellis", "id": 2857, "trainId": 433},
|
526 |
+
{"name": "workbench", "id": 3088, "trainId": 434},
|
527 |
+
{"name": "valley, vale", "id": 2926, "trainId": 435},
|
528 |
+
{"name": "toaster", "id": 2782, "trainId": 436},
|
529 |
+
{"name": "knife", "id": 1378, "trainId": 437},
|
530 |
+
{"name": "podium", "id": 1934, "trainId": 438},
|
531 |
+
{"name": "ramp", "id": 2072, "trainId": 439},
|
532 |
+
{"name": "tumble dryer", "id": 2889, "trainId": 440},
|
533 |
+
{"name": "fireplug, fire hydrant, plug", "id": 944, "trainId": 441},
|
534 |
+
{"name": "gym shoe, sneaker, tennis shoe", "id": 1158, "trainId": 442},
|
535 |
+
{"name": "lab bench", "id": 1383, "trainId": 443},
|
536 |
+
{"name": "equipment", "id": 867, "trainId": 444},
|
537 |
+
{"name": "rocky formation", "id": 2145, "trainId": 445},
|
538 |
+
{"name": "plastic", "id": 1915, "trainId": 446},
|
539 |
+
{"name": "calendar", "id": 361, "trainId": 447},
|
540 |
+
{"name": "caravan", "id": 402, "trainId": 448},
|
541 |
+
{"name": "check-in-desk", "id": 482, "trainId": 449},
|
542 |
+
{"name": "ticket counter", "id": 2761, "trainId": 450},
|
543 |
+
{"name": "brush", "id": 300, "trainId": 451},
|
544 |
+
{"name": "mill", "id": 1554, "trainId": 452},
|
545 |
+
{"name": "covered bridge", "id": 636, "trainId": 453},
|
546 |
+
{"name": "bowling alley", "id": 260, "trainId": 454},
|
547 |
+
{"name": "hanger", "id": 1186, "trainId": 455},
|
548 |
+
{"name": "excavator", "id": 871, "trainId": 456},
|
549 |
+
{"name": "trestle", "id": 2859, "trainId": 457},
|
550 |
+
{"name": "revolving door", "id": 2103, "trainId": 458},
|
551 |
+
{"name": "blast furnace", "id": 208, "trainId": 459},
|
552 |
+
{"name": "scale, weighing machine", "id": 2236, "trainId": 460},
|
553 |
+
{"name": "projector", "id": 2012, "trainId": 461},
|
554 |
+
{"name": "soap", "id": 2462, "trainId": 462},
|
555 |
+
{"name": "locker", "id": 1462, "trainId": 463},
|
556 |
+
{"name": "tractor", "id": 2832, "trainId": 464},
|
557 |
+
{"name": "stretcher", "id": 2617, "trainId": 465},
|
558 |
+
{"name": "frame", "id": 1024, "trainId": 466},
|
559 |
+
{"name": "grating", "id": 1129, "trainId": 467},
|
560 |
+
{"name": "alembic", "id": 18, "trainId": 468},
|
561 |
+
{"name": "candle, taper, wax light", "id": 376, "trainId": 469},
|
562 |
+
{"name": "barrier", "id": 134, "trainId": 470},
|
563 |
+
{"name": "cardboard", "id": 407, "trainId": 471},
|
564 |
+
{"name": "cave", "id": 434, "trainId": 472},
|
565 |
+
{"name": "puddle", "id": 2017, "trainId": 473},
|
566 |
+
{"name": "tarp", "id": 2717, "trainId": 474},
|
567 |
+
{"name": "price tag", "id": 2005, "trainId": 475},
|
568 |
+
{"name": "watchtower", "id": 2993, "trainId": 476},
|
569 |
+
{"name": "meters", "id": 1545, "trainId": 477},
|
570 |
+
{
|
571 |
+
"name": "light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb",
|
572 |
+
"id": 1445,
|
573 |
+
"trainId": 478,
|
574 |
+
},
|
575 |
+
{"name": "tracks", "id": 2831, "trainId": 479},
|
576 |
+
{"name": "hair dryer", "id": 1161, "trainId": 480},
|
577 |
+
{"name": "skirt", "id": 2411, "trainId": 481},
|
578 |
+
{"name": "viaduct", "id": 2949, "trainId": 482},
|
579 |
+
{"name": "paper towel", "id": 1769, "trainId": 483},
|
580 |
+
{"name": "coat", "id": 552, "trainId": 484},
|
581 |
+
{"name": "sheet", "id": 2327, "trainId": 485},
|
582 |
+
{"name": "fire extinguisher, extinguisher, asphyxiator", "id": 939, "trainId": 486},
|
583 |
+
{"name": "water wheel", "id": 3013, "trainId": 487},
|
584 |
+
{"name": "pottery, clayware", "id": 1986, "trainId": 488},
|
585 |
+
{"name": "magazine rack", "id": 1486, "trainId": 489},
|
586 |
+
{"name": "teapot", "id": 2723, "trainId": 490},
|
587 |
+
{"name": "microphone, mike", "id": 1549, "trainId": 491},
|
588 |
+
{"name": "support", "id": 2649, "trainId": 492},
|
589 |
+
{"name": "forklift", "id": 1020, "trainId": 493},
|
590 |
+
{"name": "canyon", "id": 392, "trainId": 494},
|
591 |
+
{"name": "cash register, register", "id": 422, "trainId": 495},
|
592 |
+
{"name": "leaf, leafage, foliage", "id": 1419, "trainId": 496},
|
593 |
+
{"name": "remote control, remote", "id": 2099, "trainId": 497},
|
594 |
+
{"name": "soap dish", "id": 2464, "trainId": 498},
|
595 |
+
{"name": "windshield, windscreen", "id": 3058, "trainId": 499},
|
596 |
+
{"name": "cat", "id": 430, "trainId": 500},
|
597 |
+
{"name": "cue, cue stick, pool cue, pool stick", "id": 675, "trainId": 501},
|
598 |
+
{"name": "vent, venthole, vent-hole, blowhole", "id": 2941, "trainId": 502},
|
599 |
+
{"name": "videos", "id": 2955, "trainId": 503},
|
600 |
+
{"name": "shovel", "id": 2355, "trainId": 504},
|
601 |
+
{"name": "eaves", "id": 840, "trainId": 505},
|
602 |
+
{"name": "antenna, aerial, transmitting aerial", "id": 32, "trainId": 506},
|
603 |
+
{"name": "shipyard", "id": 2338, "trainId": 507},
|
604 |
+
{"name": "hen, biddy", "id": 1232, "trainId": 508},
|
605 |
+
{"name": "traffic cone", "id": 2834, "trainId": 509},
|
606 |
+
{"name": "washing machines", "id": 2991, "trainId": 510},
|
607 |
+
{"name": "truck crane", "id": 2879, "trainId": 511},
|
608 |
+
{"name": "cds", "id": 444, "trainId": 512},
|
609 |
+
{"name": "niche", "id": 1657, "trainId": 513},
|
610 |
+
{"name": "scoreboard", "id": 2246, "trainId": 514},
|
611 |
+
{"name": "briefcase", "id": 296, "trainId": 515},
|
612 |
+
{"name": "boot", "id": 245, "trainId": 516},
|
613 |
+
{"name": "sweater, jumper", "id": 2661, "trainId": 517},
|
614 |
+
{"name": "hay", "id": 1202, "trainId": 518},
|
615 |
+
{"name": "pack", "id": 1714, "trainId": 519},
|
616 |
+
{"name": "bottle rack", "id": 251, "trainId": 520},
|
617 |
+
{"name": "glacier", "id": 1095, "trainId": 521},
|
618 |
+
{"name": "pergola", "id": 1828, "trainId": 522},
|
619 |
+
{"name": "building materials", "id": 311, "trainId": 523},
|
620 |
+
{"name": "television camera", "id": 2732, "trainId": 524},
|
621 |
+
{"name": "first floor", "id": 947, "trainId": 525},
|
622 |
+
{"name": "rifle", "id": 2115, "trainId": 526},
|
623 |
+
{"name": "tennis table", "id": 2738, "trainId": 527},
|
624 |
+
{"name": "stadium", "id": 2525, "trainId": 528},
|
625 |
+
{"name": "safety belt", "id": 2194, "trainId": 529},
|
626 |
+
{"name": "cover", "id": 634, "trainId": 530},
|
627 |
+
{"name": "dish rack", "id": 740, "trainId": 531},
|
628 |
+
{"name": "synthesizer", "id": 2682, "trainId": 532},
|
629 |
+
{"name": "pumpkin", "id": 2020, "trainId": 533},
|
630 |
+
{"name": "gutter", "id": 1156, "trainId": 534},
|
631 |
+
{"name": "fruit stand", "id": 1036, "trainId": 535},
|
632 |
+
{"name": "ice floe, floe", "id": 1295, "trainId": 536},
|
633 |
+
{"name": "handle, grip, handgrip, hold", "id": 1181, "trainId": 537},
|
634 |
+
{"name": "wheelchair", "id": 3037, "trainId": 538},
|
635 |
+
{"name": "mousepad, mouse mat", "id": 1614, "trainId": 539},
|
636 |
+
{"name": "diploma", "id": 736, "trainId": 540},
|
637 |
+
{"name": "fairground ride", "id": 893, "trainId": 541},
|
638 |
+
{"name": "radio", "id": 2047, "trainId": 542},
|
639 |
+
{"name": "hotplate", "id": 1274, "trainId": 543},
|
640 |
+
{"name": "junk", "id": 1361, "trainId": 544},
|
641 |
+
{"name": "wheelbarrow", "id": 3036, "trainId": 545},
|
642 |
+
{"name": "stream", "id": 2606, "trainId": 546},
|
643 |
+
{"name": "toll plaza", "id": 2797, "trainId": 547},
|
644 |
+
{"name": "punching bag", "id": 2022, "trainId": 548},
|
645 |
+
{"name": "trough", "id": 2876, "trainId": 549},
|
646 |
+
{"name": "throne", "id": 2758, "trainId": 550},
|
647 |
+
{"name": "chair desk", "id": 472, "trainId": 551},
|
648 |
+
{"name": "weighbridge", "id": 3028, "trainId": 552},
|
649 |
+
{"name": "extractor fan", "id": 882, "trainId": 553},
|
650 |
+
{"name": "hanging clothes", "id": 1189, "trainId": 554},
|
651 |
+
{"name": "dish, dish aerial, dish antenna, saucer", "id": 743, "trainId": 555},
|
652 |
+
{"name": "alarm clock, alarm", "id": 3122, "trainId": 556},
|
653 |
+
{"name": "ski lift", "id": 2401, "trainId": 557},
|
654 |
+
{"name": "chain", "id": 468, "trainId": 558},
|
655 |
+
{"name": "garage", "id": 1061, "trainId": 559},
|
656 |
+
{"name": "mechanical shovel", "id": 1523, "trainId": 560},
|
657 |
+
{"name": "wine rack", "id": 3059, "trainId": 561},
|
658 |
+
{"name": "tramway", "id": 2843, "trainId": 562},
|
659 |
+
{"name": "treadmill", "id": 2853, "trainId": 563},
|
660 |
+
{"name": "menu", "id": 1529, "trainId": 564},
|
661 |
+
{"name": "block", "id": 214, "trainId": 565},
|
662 |
+
{"name": "well", "id": 3032, "trainId": 566},
|
663 |
+
{"name": "witness stand", "id": 3071, "trainId": 567},
|
664 |
+
{"name": "branch", "id": 277, "trainId": 568},
|
665 |
+
{"name": "duck", "id": 813, "trainId": 569},
|
666 |
+
{"name": "casserole", "id": 426, "trainId": 570},
|
667 |
+
{"name": "frying pan", "id": 1039, "trainId": 571},
|
668 |
+
{"name": "desk organizer", "id": 727, "trainId": 572},
|
669 |
+
{"name": "mast", "id": 1508, "trainId": 573},
|
670 |
+
{"name": "spectacles, specs, eyeglasses, glasses", "id": 2490, "trainId": 574},
|
671 |
+
{"name": "service elevator", "id": 2299, "trainId": 575},
|
672 |
+
{"name": "dollhouse", "id": 768, "trainId": 576},
|
673 |
+
{"name": "hammock", "id": 1172, "trainId": 577},
|
674 |
+
{"name": "clothes hanging", "id": 537, "trainId": 578},
|
675 |
+
{"name": "photocopier", "id": 1847, "trainId": 579},
|
676 |
+
{"name": "notepad", "id": 1664, "trainId": 580},
|
677 |
+
{"name": "golf cart", "id": 1110, "trainId": 581},
|
678 |
+
{"name": "footpath", "id": 1014, "trainId": 582},
|
679 |
+
{"name": "cross", "id": 662, "trainId": 583},
|
680 |
+
{"name": "baptismal font", "id": 121, "trainId": 584},
|
681 |
+
{"name": "boiler", "id": 227, "trainId": 585},
|
682 |
+
{"name": "skip", "id": 2410, "trainId": 586},
|
683 |
+
{"name": "rotisserie", "id": 2165, "trainId": 587},
|
684 |
+
{"name": "tables", "id": 2696, "trainId": 588},
|
685 |
+
{"name": "water mill", "id": 3005, "trainId": 589},
|
686 |
+
{"name": "helmet", "id": 1231, "trainId": 590},
|
687 |
+
{"name": "cover curtain", "id": 635, "trainId": 591},
|
688 |
+
{"name": "brick", "id": 292, "trainId": 592},
|
689 |
+
{"name": "table runner", "id": 2690, "trainId": 593},
|
690 |
+
{"name": "ashtray", "id": 65, "trainId": 594},
|
691 |
+
{"name": "street box", "id": 2607, "trainId": 595},
|
692 |
+
{"name": "stick", "id": 2574, "trainId": 596},
|
693 |
+
{"name": "hangers", "id": 1188, "trainId": 597},
|
694 |
+
{"name": "cells", "id": 456, "trainId": 598},
|
695 |
+
{"name": "urinal", "id": 2913, "trainId": 599},
|
696 |
+
{"name": "centerpiece", "id": 459, "trainId": 600},
|
697 |
+
{"name": "portable fridge", "id": 1955, "trainId": 601},
|
698 |
+
{"name": "dvds", "id": 827, "trainId": 602},
|
699 |
+
{"name": "golf club", "id": 1111, "trainId": 603},
|
700 |
+
{"name": "skirting board", "id": 2412, "trainId": 604},
|
701 |
+
{"name": "water cooler", "id": 2997, "trainId": 605},
|
702 |
+
{"name": "clipboard", "id": 528, "trainId": 606},
|
703 |
+
{"name": "camera, photographic camera", "id": 366, "trainId": 607},
|
704 |
+
{"name": "pigeonhole", "id": 1863, "trainId": 608},
|
705 |
+
{"name": "chips", "id": 500, "trainId": 609},
|
706 |
+
{"name": "food processor", "id": 1001, "trainId": 610},
|
707 |
+
{"name": "post box", "id": 1958, "trainId": 611},
|
708 |
+
{"name": "lid", "id": 1441, "trainId": 612},
|
709 |
+
{"name": "drum", "id": 809, "trainId": 613},
|
710 |
+
{"name": "blender", "id": 210, "trainId": 614},
|
711 |
+
{"name": "cave entrance", "id": 435, "trainId": 615},
|
712 |
+
{"name": "dental chair", "id": 718, "trainId": 616},
|
713 |
+
{"name": "obelisk", "id": 1674, "trainId": 617},
|
714 |
+
{"name": "canoe", "id": 388, "trainId": 618},
|
715 |
+
{"name": "mobile", "id": 1572, "trainId": 619},
|
716 |
+
{"name": "monitors", "id": 1584, "trainId": 620},
|
717 |
+
{"name": "pool ball", "id": 1944, "trainId": 621},
|
718 |
+
{"name": "cue rack", "id": 674, "trainId": 622},
|
719 |
+
{"name": "baggage carts", "id": 99, "trainId": 623},
|
720 |
+
{"name": "shore", "id": 2352, "trainId": 624},
|
721 |
+
{"name": "fork", "id": 1019, "trainId": 625},
|
722 |
+
{"name": "paper filer", "id": 1763, "trainId": 626},
|
723 |
+
{"name": "bicycle rack", "id": 185, "trainId": 627},
|
724 |
+
{"name": "coat rack", "id": 554, "trainId": 628},
|
725 |
+
{"name": "garland", "id": 1066, "trainId": 629},
|
726 |
+
{"name": "sports bag", "id": 2508, "trainId": 630},
|
727 |
+
{"name": "fish tank", "id": 951, "trainId": 631},
|
728 |
+
{"name": "towel dispenser", "id": 2822, "trainId": 632},
|
729 |
+
{"name": "carriage", "id": 415, "trainId": 633},
|
730 |
+
{"name": "brochure", "id": 297, "trainId": 634},
|
731 |
+
{"name": "plaque", "id": 1914, "trainId": 635},
|
732 |
+
{"name": "stringer", "id": 2619, "trainId": 636},
|
733 |
+
{"name": "iron", "id": 1338, "trainId": 637},
|
734 |
+
{"name": "spoon", "id": 2505, "trainId": 638},
|
735 |
+
{"name": "flag pole", "id": 955, "trainId": 639},
|
736 |
+
{"name": "toilet brush", "id": 2786, "trainId": 640},
|
737 |
+
{"name": "book stand", "id": 238, "trainId": 641},
|
738 |
+
{"name": "water faucet, water tap, tap, hydrant", "id": 3000, "trainId": 642},
|
739 |
+
{"name": "ticket office", "id": 2763, "trainId": 643},
|
740 |
+
{"name": "broom", "id": 299, "trainId": 644},
|
741 |
+
{"name": "dvd", "id": 822, "trainId": 645},
|
742 |
+
{"name": "ice bucket", "id": 1288, "trainId": 646},
|
743 |
+
{"name": "carapace, shell, cuticle, shield", "id": 3101, "trainId": 647},
|
744 |
+
{"name": "tureen", "id": 2894, "trainId": 648},
|
745 |
+
{"name": "folders", "id": 992, "trainId": 649},
|
746 |
+
{"name": "chess", "id": 489, "trainId": 650},
|
747 |
+
{"name": "root", "id": 2157, "trainId": 651},
|
748 |
+
{"name": "sewing machine", "id": 2309, "trainId": 652},
|
749 |
+
{"name": "model", "id": 1576, "trainId": 653},
|
750 |
+
{"name": "pen", "id": 1810, "trainId": 654},
|
751 |
+
{"name": "violin", "id": 2964, "trainId": 655},
|
752 |
+
{"name": "sweatshirt", "id": 2662, "trainId": 656},
|
753 |
+
{"name": "recycling materials", "id": 2087, "trainId": 657},
|
754 |
+
{"name": "mitten", "id": 1569, "trainId": 658},
|
755 |
+
{"name": "chopping board, cutting board", "id": 503, "trainId": 659},
|
756 |
+
{"name": "mask", "id": 1505, "trainId": 660},
|
757 |
+
{"name": "log", "id": 1468, "trainId": 661},
|
758 |
+
{"name": "mouse, computer mouse", "id": 1613, "trainId": 662},
|
759 |
+
{"name": "grill", "id": 1138, "trainId": 663},
|
760 |
+
{"name": "hole", "id": 1256, "trainId": 664},
|
761 |
+
{"name": "target", "id": 2715, "trainId": 665},
|
762 |
+
{"name": "trash bag", "id": 2846, "trainId": 666},
|
763 |
+
{"name": "chalk", "id": 477, "trainId": 667},
|
764 |
+
{"name": "sticks", "id": 2576, "trainId": 668},
|
765 |
+
{"name": "balloon", "id": 108, "trainId": 669},
|
766 |
+
{"name": "score", "id": 2245, "trainId": 670},
|
767 |
+
{"name": "hair spray", "id": 1162, "trainId": 671},
|
768 |
+
{"name": "roll", "id": 2149, "trainId": 672},
|
769 |
+
{"name": "runner", "id": 2183, "trainId": 673},
|
770 |
+
{"name": "engine", "id": 858, "trainId": 674},
|
771 |
+
{"name": "inflatable glove", "id": 1324, "trainId": 675},
|
772 |
+
{"name": "games", "id": 1055, "trainId": 676},
|
773 |
+
{"name": "pallets", "id": 1741, "trainId": 677},
|
774 |
+
{"name": "baskets", "id": 149, "trainId": 678},
|
775 |
+
{"name": "coop", "id": 615, "trainId": 679},
|
776 |
+
{"name": "dvd player", "id": 825, "trainId": 680},
|
777 |
+
{"name": "rocking horse", "id": 2143, "trainId": 681},
|
778 |
+
{"name": "buckets", "id": 304, "trainId": 682},
|
779 |
+
{"name": "bread rolls", "id": 283, "trainId": 683},
|
780 |
+
{"name": "shawl", "id": 2322, "trainId": 684},
|
781 |
+
{"name": "watering can", "id": 3017, "trainId": 685},
|
782 |
+
{"name": "spotlights", "id": 2510, "trainId": 686},
|
783 |
+
{"name": "post-it", "id": 1960, "trainId": 687},
|
784 |
+
{"name": "bowls", "id": 265, "trainId": 688},
|
785 |
+
{"name": "security camera", "id": 2282, "trainId": 689},
|
786 |
+
{"name": "runner cloth", "id": 2184, "trainId": 690},
|
787 |
+
{"name": "lock", "id": 1461, "trainId": 691},
|
788 |
+
{"name": "alarm, warning device, alarm system", "id": 3113, "trainId": 692},
|
789 |
+
{"name": "side", "id": 2372, "trainId": 693},
|
790 |
+
{"name": "roulette", "id": 2166, "trainId": 694},
|
791 |
+
{"name": "bone", "id": 232, "trainId": 695},
|
792 |
+
{"name": "cutlery", "id": 693, "trainId": 696},
|
793 |
+
{"name": "pool balls", "id": 1945, "trainId": 697},
|
794 |
+
{"name": "wheels", "id": 3039, "trainId": 698},
|
795 |
+
{"name": "spice rack", "id": 2494, "trainId": 699},
|
796 |
+
{"name": "plant pots", "id": 1908, "trainId": 700},
|
797 |
+
{"name": "towel ring", "id": 2827, "trainId": 701},
|
798 |
+
{"name": "bread box", "id": 280, "trainId": 702},
|
799 |
+
{"name": "video", "id": 2950, "trainId": 703},
|
800 |
+
{"name": "funfair", "id": 1044, "trainId": 704},
|
801 |
+
{"name": "breads", "id": 288, "trainId": 705},
|
802 |
+
{"name": "tripod", "id": 2863, "trainId": 706},
|
803 |
+
{"name": "ironing board", "id": 1342, "trainId": 707},
|
804 |
+
{"name": "skimmer", "id": 2409, "trainId": 708},
|
805 |
+
{"name": "hollow", "id": 1258, "trainId": 709},
|
806 |
+
{"name": "scratching post", "id": 2249, "trainId": 710},
|
807 |
+
{"name": "tricycle", "id": 2862, "trainId": 711},
|
808 |
+
{"name": "file box", "id": 920, "trainId": 712},
|
809 |
+
{"name": "mountain pass", "id": 1607, "trainId": 713},
|
810 |
+
{"name": "tombstones", "id": 2802, "trainId": 714},
|
811 |
+
{"name": "cooker", "id": 610, "trainId": 715},
|
812 |
+
{"name": "card game, cards", "id": 3129, "trainId": 716},
|
813 |
+
{"name": "golf bag", "id": 1108, "trainId": 717},
|
814 |
+
{"name": "towel paper", "id": 2823, "trainId": 718},
|
815 |
+
{"name": "chaise lounge", "id": 476, "trainId": 719},
|
816 |
+
{"name": "sun", "id": 2641, "trainId": 720},
|
817 |
+
{"name": "toilet paper holder", "id": 2788, "trainId": 721},
|
818 |
+
{"name": "rake", "id": 2070, "trainId": 722},
|
819 |
+
{"name": "key", "id": 1368, "trainId": 723},
|
820 |
+
{"name": "umbrella stand", "id": 2903, "trainId": 724},
|
821 |
+
{"name": "dartboard", "id": 699, "trainId": 725},
|
822 |
+
{"name": "transformer", "id": 2844, "trainId": 726},
|
823 |
+
{"name": "fireplace utensils", "id": 942, "trainId": 727},
|
824 |
+
{"name": "sweatshirts", "id": 2663, "trainId": 728},
|
825 |
+
{"name": "cellular telephone, cellular phone, cellphone, cell, mobile phone", "id": 457, "trainId": 729},
|
830 |
+
{"name": "tallboy", "id": 2701, "trainId": 730},
|
831 |
+
{"name": "stapler", "id": 2540, "trainId": 731},
|
832 |
+
{"name": "sauna", "id": 2231, "trainId": 732},
|
833 |
+
{"name": "test tube", "id": 2746, "trainId": 733},
|
834 |
+
{"name": "palette", "id": 1738, "trainId": 734},
|
835 |
+
{"name": "shopping carts", "id": 2350, "trainId": 735},
|
836 |
+
{"name": "tools", "id": 2808, "trainId": 736},
|
837 |
+
{"name": "push button, push, button", "id": 2025, "trainId": 737},
|
838 |
+
{"name": "star", "id": 2541, "trainId": 738},
|
839 |
+
{"name": "roof rack", "id": 2156, "trainId": 739},
|
840 |
+
{"name": "barbed wire", "id": 126, "trainId": 740},
|
841 |
+
{"name": "spray", "id": 2512, "trainId": 741},
|
842 |
+
{"name": "ear", "id": 831, "trainId": 742},
|
843 |
+
{"name": "sponge", "id": 2503, "trainId": 743},
|
844 |
+
{"name": "racket", "id": 2039, "trainId": 744},
|
845 |
+
{"name": "tins", "id": 2774, "trainId": 745},
|
846 |
+
{"name": "eyeglasses", "id": 886, "trainId": 746},
|
847 |
+
{"name": "file", "id": 919, "trainId": 747},
|
848 |
+
{"name": "scarfs", "id": 2240, "trainId": 748},
|
849 |
+
{"name": "sugar bowl", "id": 2636, "trainId": 749},
|
850 |
+
{"name": "flip flop", "id": 963, "trainId": 750},
|
851 |
+
{"name": "headstones", "id": 1218, "trainId": 751},
|
852 |
+
{"name": "laptop bag", "id": 1406, "trainId": 752},
|
853 |
+
{"name": "leash", "id": 1420, "trainId": 753},
|
854 |
+
{"name": "climbing frame", "id": 526, "trainId": 754},
|
855 |
+
{"name": "suit hanger", "id": 2639, "trainId": 755},
|
856 |
+
{"name": "floor spotlight", "id": 975, "trainId": 756},
|
857 |
+
{"name": "plate rack", "id": 1921, "trainId": 757},
|
858 |
+
{"name": "sewer", "id": 2305, "trainId": 758},
|
859 |
+
{"name": "hard drive", "id": 1193, "trainId": 759},
|
860 |
+
{"name": "sprinkler", "id": 2517, "trainId": 760},
|
861 |
+
{"name": "tools box", "id": 2809, "trainId": 761},
|
862 |
+
{"name": "necklace", "id": 1647, "trainId": 762},
|
863 |
+
{"name": "bulbs", "id": 314, "trainId": 763},
|
864 |
+
{"name": "steel industry", "id": 2560, "trainId": 764},
|
865 |
+
{"name": "club", "id": 545, "trainId": 765},
|
866 |
+
{"name": "jack", "id": 1345, "trainId": 766},
|
867 |
+
{"name": "door bars", "id": 775, "trainId": 767},
|
868 |
+
{"name": "control panel, instrument panel, control board, board, panel", "id": 603, "trainId": 768},
|
873 |
+
{"name": "hairbrush", "id": 1163, "trainId": 769},
|
874 |
+
{"name": "napkin holder", "id": 1641, "trainId": 770},
|
875 |
+
{"name": "office", "id": 1678, "trainId": 771},
|
876 |
+
{"name": "smoke detector", "id": 2450, "trainId": 772},
|
877 |
+
{"name": "utensils", "id": 2915, "trainId": 773},
|
878 |
+
{"name": "apron", "id": 42, "trainId": 774},
|
879 |
+
{"name": "scissors", "id": 2242, "trainId": 775},
|
880 |
+
{"name": "terminal", "id": 2741, "trainId": 776},
|
881 |
+
{"name": "grinder", "id": 1143, "trainId": 777},
|
882 |
+
{"name": "entry phone", "id": 862, "trainId": 778},
|
883 |
+
{"name": "newspaper stand", "id": 1654, "trainId": 779},
|
884 |
+
{"name": "pepper shaker", "id": 1826, "trainId": 780},
|
885 |
+
{"name": "onions", "id": 1689, "trainId": 781},
|
886 |
+
{"name": "central processing unit, cpu, c p u , central processor, processor, mainframe", "id": 3124, "trainId": 782},
|
891 |
+
{"name": "tape", "id": 2710, "trainId": 783},
|
892 |
+
{"name": "bat", "id": 152, "trainId": 784},
|
893 |
+
{"name": "coaster", "id": 549, "trainId": 785},
|
894 |
+
{"name": "calculator", "id": 360, "trainId": 786},
|
895 |
+
{"name": "potatoes", "id": 1982, "trainId": 787},
|
896 |
+
{"name": "luggage rack", "id": 1478, "trainId": 788},
|
897 |
+
{"name": "salt", "id": 2203, "trainId": 789},
|
898 |
+
{"name": "street number", "id": 2612, "trainId": 790},
|
899 |
+
{"name": "viewpoint", "id": 2956, "trainId": 791},
|
900 |
+
{"name": "sword", "id": 2681, "trainId": 792},
|
901 |
+
{"name": "cd", "id": 437, "trainId": 793},
|
902 |
+
{"name": "rowing machine", "id": 2171, "trainId": 794},
|
903 |
+
{"name": "plug", "id": 1933, "trainId": 795},
|
904 |
+
{"name": "andiron, firedog, dog, dog-iron", "id": 3110, "trainId": 796},
|
905 |
+
{"name": "pepper", "id": 1824, "trainId": 797},
|
906 |
+
{"name": "tongs", "id": 2803, "trainId": 798},
|
907 |
+
{"name": "bonfire", "id": 234, "trainId": 799},
|
908 |
+
{"name": "dog dish", "id": 764, "trainId": 800},
|
909 |
+
{"name": "belt", "id": 177, "trainId": 801},
|
910 |
+
{"name": "dumbbells", "id": 817, "trainId": 802},
|
911 |
+
{"name": "videocassette recorder, vcr", "id": 3145, "trainId": 803},
|
912 |
+
{"name": "hook", "id": 1262, "trainId": 804},
|
913 |
+
{"name": "envelopes", "id": 864, "trainId": 805},
|
914 |
+
{"name": "shower faucet", "id": 2359, "trainId": 806},
|
915 |
+
{"name": "watch", "id": 2992, "trainId": 807},
|
916 |
+
{"name": "padlock", "id": 1725, "trainId": 808},
|
917 |
+
{"name": "swimming pool ladder", "id": 2667, "trainId": 809},
|
918 |
+
{"name": "spanners", "id": 2484, "trainId": 810},
|
919 |
+
{"name": "gravy boat", "id": 1133, "trainId": 811},
|
920 |
+
{"name": "notice board", "id": 1667, "trainId": 812},
|
921 |
+
{"name": "trash bags", "id": 2847, "trainId": 813},
|
922 |
+
{"name": "fire alarm", "id": 932, "trainId": 814},
|
923 |
+
{"name": "ladle", "id": 1392, "trainId": 815},
|
924 |
+
{"name": "stethoscope", "id": 2573, "trainId": 816},
|
925 |
+
{"name": "rocket", "id": 2140, "trainId": 817},
|
926 |
+
{"name": "funnel", "id": 1046, "trainId": 818},
|
927 |
+
{"name": "bowling pins", "id": 264, "trainId": 819},
|
928 |
+
{"name": "valve", "id": 2927, "trainId": 820},
|
929 |
+
{"name": "thermometer", "id": 2752, "trainId": 821},
|
930 |
+
{"name": "cups", "id": 679, "trainId": 822},
|
931 |
+
{"name": "spice jar", "id": 2493, "trainId": 823},
|
932 |
+
{"name": "night light", "id": 1658, "trainId": 824},
|
933 |
+
{"name": "soaps", "id": 2466, "trainId": 825},
|
934 |
+
{"name": "games table", "id": 1057, "trainId": 826},
|
935 |
+
{"name": "slotted spoon", "id": 2444, "trainId": 827},
|
936 |
+
{"name": "reel", "id": 2093, "trainId": 828},
|
937 |
+
{"name": "scourer", "id": 2248, "trainId": 829},
|
938 |
+
{"name": "sleeping robe", "id": 2432, "trainId": 830},
|
939 |
+
{"name": "desk mat", "id": 726, "trainId": 831},
|
940 |
+
{"name": "dumbbell", "id": 816, "trainId": 832},
|
941 |
+
{"name": "hammer", "id": 1171, "trainId": 833},
|
942 |
+
{"name": "tie", "id": 2766, "trainId": 834},
|
943 |
+
{"name": "typewriter", "id": 2900, "trainId": 835},
|
944 |
+
{"name": "shaker", "id": 2313, "trainId": 836},
|
945 |
+
{"name": "cheese dish", "id": 488, "trainId": 837},
|
946 |
+
{"name": "sea star", "id": 2265, "trainId": 838},
|
947 |
+
{"name": "racquet", "id": 2043, "trainId": 839},
|
948 |
+
{"name": "butane gas cylinder", "id": 332, "trainId": 840},
|
949 |
+
{"name": "paper weight", "id": 1771, "trainId": 841},
|
950 |
+
{"name": "shaving brush", "id": 2320, "trainId": 842},
|
951 |
+
{"name": "sunglasses", "id": 2646, "trainId": 843},
|
952 |
+
{"name": "gear shift", "id": 1089, "trainId": 844},
|
953 |
+
{"name": "towel rail", "id": 2826, "trainId": 845},
|
954 |
+
{"name": "adding machine, totalizer, totaliser", "id": 3148, "trainId": 846},
|
955 |
+
]


def _get_ade20k_full_meta():
    stuff_ids = [k["id"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
    assert len(stuff_ids) == 847, len(stuff_ids)

    stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
    stuff_classes = [k["name"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]

    ret = {
        "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
        "stuff_classes": stuff_classes,
    }
    return ret


def register_all_ade20k_full(root):
    meta = _get_ade20k_full_meta()
    for name, dirname in [("val", "validation")]:
        image_dir = os.path.join(root, "ADE20K_2021_17_01/images_detectron2", dirname)
        gt_dir = os.path.join(root, "ADE20K_2021_17_01/annotations_detectron2", dirname)
        name = f"ade20k_full_sem_seg_{name}"
        DatasetCatalog.register(
            name,
            lambda x=image_dir, y=gt_dir: load_sem_seg(
                y, x, gt_ext="tif", image_ext="jpg"
            ),
        )
        MetadataCatalog.get(name).set(
            stuff_classes=meta["stuff_classes"][:],
            thing_classes=meta["stuff_classes"][:],  # the same as stuff_classes
            image_root=image_dir,
            sem_seg_root=gt_dir,
            evaluator_type="sem_seg",
            ignore_label=65535,  # NOTE: gt is saved in 16-bit TIFF images
        )


_root = os.getenv("DETECTRON2_DATASETS", "datasets")
register_all_ade20k_full(_root)
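# --- Editorial sketch (not part of the original file): minimal usage, assuming
# $DETECTRON2_DATASETS points at a directory containing ADE20K_2021_17_01 prepared
# with the images_detectron2 / annotations_detectron2 layout expected above.
#
#   from detectron2.data import DatasetCatalog, MetadataCatalog
#   import open_vocab_seg.data.datasets.register_ade20k_full  # noqa: F401  (import runs the registration above)
#
#   dicts = DatasetCatalog.get("ade20k_full_sem_seg_val")   # per-image records produced by load_sem_seg
#   meta = MetadataCatalog.get("ade20k_full_sem_seg_val")
#   assert len(meta.stuff_classes) == 847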
open_vocab_seg/data/datasets/register_cc3m.py
ADDED
@@ -0,0 +1,457 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import os
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
6 |
+
from detectron2.data.datasets import load_sem_seg
|
7 |
+
from detectron2.utils.file_io import PathManager
|
8 |
+
|
9 |
+
|
10 |
+
COCO_CATEGORIES = [
|
11 |
+
{"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
|
12 |
+
{"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
|
13 |
+
{"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
|
14 |
+
{"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
|
15 |
+
{"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
|
16 |
+
{"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
|
17 |
+
{"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
|
18 |
+
{"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
|
19 |
+
{"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
|
20 |
+
{"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
|
21 |
+
{"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
|
22 |
+
{"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
|
23 |
+
{"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
|
24 |
+
{"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
|
25 |
+
{"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
|
26 |
+
{"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
|
27 |
+
{"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
|
28 |
+
{"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
|
29 |
+
{"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
|
30 |
+
{"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
|
31 |
+
{"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
|
32 |
+
{"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
|
33 |
+
{"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
|
34 |
+
{"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
|
35 |
+
{"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
|
36 |
+
{"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
|
37 |
+
{"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
|
38 |
+
{"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
|
39 |
+
{"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
|
40 |
+
{"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
|
41 |
+
{"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
|
42 |
+
{"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
|
43 |
+
{"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
|
44 |
+
{"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
|
45 |
+
{"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
|
46 |
+
{"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
|
47 |
+
{"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
|
48 |
+
{"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
|
49 |
+
{"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
|
50 |
+
{"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
|
51 |
+
{"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
|
52 |
+
{"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
|
53 |
+
{"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
|
54 |
+
{"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
|
55 |
+
{"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
|
56 |
+
{"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
|
57 |
+
{"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
|
58 |
+
{"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
|
59 |
+
{"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
|
60 |
+
{"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
|
61 |
+
{"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
|
62 |
+
{"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
|
63 |
+
{"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
|
64 |
+
{"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
|
65 |
+
{"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
|
66 |
+
{"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
|
67 |
+
{"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
|
68 |
+
{"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
|
69 |
+
{"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
|
70 |
+
{"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
|
71 |
+
{"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
|
72 |
+
{"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
|
73 |
+
{"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
|
74 |
+
{"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
|
75 |
+
{"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
|
76 |
+
{"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
|
77 |
+
{"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
|
78 |
+
{"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
|
79 |
+
{"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
|
80 |
+
{"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
|
81 |
+
{"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
|
82 |
+
{"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
|
83 |
+
{"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
|
84 |
+
{"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
|
85 |
+
{"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
|
86 |
+
{"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
|
87 |
+
{"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
|
88 |
+
{"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
|
89 |
+
{"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
|
90 |
+
{"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
|
91 |
+
{"id": 92, "name": "banner", "supercategory": "textile"},
|
92 |
+
{"id": 93, "name": "blanket", "supercategory": "textile"},
|
93 |
+
{"id": 94, "name": "branch", "supercategory": "plant"},
|
94 |
+
{"id": 95, "name": "bridge", "supercategory": "building"},
|
95 |
+
{"id": 96, "name": "building-other", "supercategory": "building"},
|
96 |
+
{"id": 97, "name": "bush", "supercategory": "plant"},
|
97 |
+
{"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
|
98 |
+
{"id": 99, "name": "cage", "supercategory": "structural"},
|
99 |
+
{"id": 100, "name": "cardboard", "supercategory": "raw-material"},
|
100 |
+
{"id": 101, "name": "carpet", "supercategory": "floor"},
|
101 |
+
{"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
|
102 |
+
{"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
|
103 |
+
{"id": 104, "name": "cloth", "supercategory": "textile"},
|
104 |
+
{"id": 105, "name": "clothes", "supercategory": "textile"},
|
105 |
+
{"id": 106, "name": "clouds", "supercategory": "sky"},
|
106 |
+
{"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
|
107 |
+
{"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
|
108 |
+
{"id": 109, "name": "curtain", "supercategory": "textile"},
|
109 |
+
{"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
|
110 |
+
{"id": 111, "name": "dirt", "supercategory": "ground"},
|
111 |
+
{"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
|
112 |
+
{"id": 113, "name": "fence", "supercategory": "structural"},
|
113 |
+
{"id": 114, "name": "floor-marble", "supercategory": "floor"},
|
114 |
+
{"id": 115, "name": "floor-other", "supercategory": "floor"},
|
115 |
+
{"id": 116, "name": "floor-stone", "supercategory": "floor"},
|
116 |
+
{"id": 117, "name": "floor-tile", "supercategory": "floor"},
|
117 |
+
{"id": 118, "name": "floor-wood", "supercategory": "floor"},
|
118 |
+
{"id": 119, "name": "flower", "supercategory": "plant"},
|
119 |
+
{"id": 120, "name": "fog", "supercategory": "water"},
|
120 |
+
{"id": 121, "name": "food-other", "supercategory": "food-stuff"},
|
121 |
+
{"id": 122, "name": "fruit", "supercategory": "food-stuff"},
|
122 |
+
{"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
|
123 |
+
{"id": 124, "name": "grass", "supercategory": "plant"},
|
124 |
+
{"id": 125, "name": "gravel", "supercategory": "ground"},
|
125 |
+
{"id": 126, "name": "ground-other", "supercategory": "ground"},
|
126 |
+
{"id": 127, "name": "hill", "supercategory": "solid"},
|
127 |
+
{"id": 128, "name": "house", "supercategory": "building"},
|
128 |
+
{"id": 129, "name": "leaves", "supercategory": "plant"},
|
129 |
+
{"id": 130, "name": "light", "supercategory": "furniture-stuff"},
|
130 |
+
{"id": 131, "name": "mat", "supercategory": "textile"},
|
131 |
+
{"id": 132, "name": "metal", "supercategory": "raw-material"},
|
132 |
+
{"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
|
133 |
+
{"id": 134, "name": "moss", "supercategory": "plant"},
|
134 |
+
{"id": 135, "name": "mountain", "supercategory": "solid"},
|
135 |
+
{"id": 136, "name": "mud", "supercategory": "ground"},
|
136 |
+
{"id": 137, "name": "napkin", "supercategory": "textile"},
|
137 |
+
{"id": 138, "name": "net", "supercategory": "structural"},
|
138 |
+
{"id": 139, "name": "paper", "supercategory": "raw-material"},
|
139 |
+
{"id": 140, "name": "pavement", "supercategory": "ground"},
|
140 |
+
{"id": 141, "name": "pillow", "supercategory": "textile"},
|
141 |
+
{"id": 142, "name": "plant-other", "supercategory": "plant"},
|
142 |
+
{"id": 143, "name": "plastic", "supercategory": "raw-material"},
|
143 |
+
{"id": 144, "name": "platform", "supercategory": "ground"},
|
144 |
+
{"id": 145, "name": "playingfield", "supercategory": "ground"},
|
145 |
+
{"id": 146, "name": "railing", "supercategory": "structural"},
|
146 |
+
{"id": 147, "name": "railroad", "supercategory": "ground"},
|
147 |
+
{"id": 148, "name": "river", "supercategory": "water"},
|
148 |
+
{"id": 149, "name": "road", "supercategory": "ground"},
|
149 |
+
{"id": 150, "name": "rock", "supercategory": "solid"},
|
150 |
+
{"id": 151, "name": "roof", "supercategory": "building"},
|
151 |
+
{"id": 152, "name": "rug", "supercategory": "textile"},
|
152 |
+
{"id": 153, "name": "salad", "supercategory": "food-stuff"},
|
153 |
+
{"id": 154, "name": "sand", "supercategory": "ground"},
|
154 |
+
{"id": 155, "name": "sea", "supercategory": "water"},
|
155 |
+
{"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
|
156 |
+
{"id": 157, "name": "sky-other", "supercategory": "sky"},
|
157 |
+
{"id": 158, "name": "skyscraper", "supercategory": "building"},
|
158 |
+
{"id": 159, "name": "snow", "supercategory": "ground"},
|
159 |
+
{"id": 160, "name": "solid-other", "supercategory": "solid"},
|
160 |
+
{"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
|
161 |
+
{"id": 162, "name": "stone", "supercategory": "solid"},
|
162 |
+
{"id": 163, "name": "straw", "supercategory": "plant"},
|
163 |
+
{"id": 164, "name": "structural-other", "supercategory": "structural"},
|
164 |
+
{"id": 165, "name": "table", "supercategory": "furniture-stuff"},
|
165 |
+
{"id": 166, "name": "tent", "supercategory": "building"},
|
166 |
+
{"id": 167, "name": "textile-other", "supercategory": "textile"},
|
167 |
+
{"id": 168, "name": "towel", "supercategory": "textile"},
|
168 |
+
{"id": 169, "name": "tree", "supercategory": "plant"},
|
169 |
+
{"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
|
170 |
+
{"id": 171, "name": "wall-brick", "supercategory": "wall"},
|
171 |
+
{"id": 172, "name": "wall-concrete", "supercategory": "wall"},
|
172 |
+
{"id": 173, "name": "wall-other", "supercategory": "wall"},
|
173 |
+
{"id": 174, "name": "wall-panel", "supercategory": "wall"},
|
174 |
+
{"id": 175, "name": "wall-stone", "supercategory": "wall"},
|
175 |
+
{"id": 176, "name": "wall-tile", "supercategory": "wall"},
|
176 |
+
{"id": 177, "name": "wall-wood", "supercategory": "wall"},
|
177 |
+
{"id": 178, "name": "water-other", "supercategory": "water"},
|
178 |
+
{"id": 179, "name": "waterdrops", "supercategory": "water"},
|
179 |
+
{"id": 180, "name": "window-blind", "supercategory": "window"},
|
180 |
+
{"id": 181, "name": "window-other", "supercategory": "window"},
|
181 |
+
{"id": 182, "name": "wood", "supercategory": "solid"},
|
182 |
+
]
|
183 |
+
|
184 |
+
|
185 |
+
ADE20K_150_CATEGORIES = [
|
186 |
+
{"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"},
|
187 |
+
{"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"},
|
188 |
+
{"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"},
|
189 |
+
{"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"},
|
190 |
+
{"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"},
|
191 |
+
{"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"},
|
192 |
+
{"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"},
|
193 |
+
{"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"},
|
194 |
+
{"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "},
|
195 |
+
{"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"},
|
196 |
+
{"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"},
|
197 |
+
{"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"},
|
198 |
+
{"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"},
|
199 |
+
{"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"},
|
200 |
+
{"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"},
|
201 |
+
{"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"},
|
202 |
+
{"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"},
|
203 |
+
{"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"},
|
204 |
+
{"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"},
|
205 |
+
{"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"},
|
206 |
+
{"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"},
|
207 |
+
{"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"},
|
208 |
+
{"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"},
|
209 |
+
{"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"},
|
210 |
+
{"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"},
|
211 |
+
{"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"},
|
212 |
+
{"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"},
|
213 |
+
{"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"},
|
214 |
+
{"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"},
|
215 |
+
{"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"},
|
216 |
+
{"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"},
|
217 |
+
{"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"},
|
218 |
+
{"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"},
|
219 |
+
{"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"},
|
220 |
+
{"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"},
|
221 |
+
{"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"},
|
222 |
+
{"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"},
|
223 |
+
{"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"},
|
224 |
+
{"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"},
|
225 |
+
{"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"},
|
226 |
+
{"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"},
|
227 |
+
{"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"},
|
228 |
+
{"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"},
|
229 |
+
{"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"},
|
230 |
+
{"color": [6, 51, 255], "id": 44, "isthing": 1, "name": "chest of drawers, chest, bureau, dresser"},
|
236 |
+
{"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"},
|
237 |
+
{"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"},
|
238 |
+
{"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"},
|
239 |
+
{"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"},
|
240 |
+
{"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"},
|
241 |
+
{"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"},
|
242 |
+
{"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"},
|
243 |
+
{"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"},
|
244 |
+
{"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"},
|
245 |
+
{"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"},
|
246 |
+
{"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"},
|
247 |
+
{"color": [255, 71, 0], "id": 56, "isthing": 1, "name": "pool table, billiard table, snooker table"},
|
253 |
+
{"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"},
|
254 |
+
{"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"},
|
255 |
+
{"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"},
|
256 |
+
{"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"},
|
257 |
+
{"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"},
|
258 |
+
{"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"},
|
259 |
+
{"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"},
|
260 |
+
{"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"},
|
261 |
+
{"color": [0, 255, 133], "id": 65, "isthing": 1, "name": "toilet, can, commode, crapper, pot, potty, stool, throne"},
|
267 |
+
{"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"},
|
268 |
+
{"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"},
|
269 |
+
{"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"},
|
270 |
+
{"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"},
|
271 |
+
{"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"},
|
272 |
+
{"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"},
|
273 |
+
{"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"},
|
274 |
+
{"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"},
|
275 |
+
{"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"},
|
276 |
+
{"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"},
|
277 |
+
{"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"},
|
278 |
+
{"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"},
|
279 |
+
{"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"},
|
280 |
+
{"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"},
|
281 |
+
{"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"},
|
282 |
+
{"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"},
|
283 |
+
{"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"},
|
284 |
+
{"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"},
|
285 |
+
{"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"},
|
286 |
+
{"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"},
|
287 |
+
{"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"},
|
288 |
+
{"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"},
|
289 |
+
{"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"},
|
290 |
+
{"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"},
|
291 |
+
{"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"},
|
292 |
+
{"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"},
|
293 |
+
{"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"},
|
294 |
+
{"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"},
|
295 |
+
{"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"},
|
296 |
+
{"color": [0, 122, 255], "id": 95, "isthing": 1, "name": "bannister, banister, balustrade, balusters, handrail"},
|
302 |
+
{"color": [0, 255, 163], "id": 96, "isthing": 0, "name": "escalator, moving staircase, moving stairway"},
|
308 |
+
{"color": [255, 153, 0], "id": 97, "isthing": 1, "name": "ottoman, pouf, pouffe, puff, hassock"},
|
314 |
+
{"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"},
|
315 |
+
{"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"},
|
316 |
+
{"color": [143, 255, 0], "id": 100, "isthing": 0, "name": "poster, posting, placard, notice, bill, card"},
|
322 |
+
{"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"},
|
323 |
+
{"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"},
|
324 |
+
{"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"},
|
325 |
+
{"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"},
|
326 |
+
{"color": [133, 0, 255], "id": 105, "isthing": 0, "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter"},
|
332 |
+
{"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"},
|
333 |
+
{"color": [184, 0, 255], "id": 107, "isthing": 1, "name": "washer, automatic washer, washing machine"},
|
339 |
+
{"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"},
|
340 |
+
{"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"},
|
341 |
+
{"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"},
|
342 |
+
{"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"},
|
343 |
+
{"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"},
|
344 |
+
{"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"},
|
345 |
+
{"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"},
|
346 |
+
{"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"},
|
347 |
+
{"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"},
|
348 |
+
{"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"},
|
349 |
+
{"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"},
|
350 |
+
{"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"},
|
351 |
+
{"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"},
|
352 |
+
{"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"},
|
353 |
+
{"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"},
|
354 |
+
{"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"},
|
355 |
+
{"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"},
|
356 |
+
{"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"},
|
357 |
+
{"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"},
|
358 |
+
{"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"},
|
359 |
+
{"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"},
|
360 |
+
{"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"},
|
361 |
+
{"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"},
|
362 |
+
{"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"},
|
363 |
+
{"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"},
|
364 |
+
{"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"},
|
365 |
+
{"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"},
|
366 |
+
{"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"},
|
367 |
+
{"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"},
|
368 |
+
{"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"},
|
369 |
+
{"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"},
|
370 |
+
{"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"},
|
371 |
+
{"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"},
|
372 |
+
{"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"},
|
373 |
+
{"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"},
|
374 |
+
{"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"},
|
375 |
+
{"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"},
|
376 |
+
{"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"},
|
377 |
+
{"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"},
|
378 |
+
{"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"},
|
379 |
+
{"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"},
|
380 |
+
{"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"},
|
381 |
+
]

TEST_CATEGORIES = [
    {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "Oculus"},
    {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "Ukulele"},
]

COCO_BASE_CATEGORIES = [
    c
    for i, c in enumerate(COCO_CATEGORIES)
    if c["id"] - 1
    not in [20, 24, 32, 33, 40, 56, 86, 99, 105, 123, 144, 147, 148, 168, 171]
]
COCO_NOVEL_CATEGORIES = [
    c
    for i, c in enumerate(COCO_CATEGORIES)
    if c["id"] - 1
    in [20, 24, 32, 33, 40, 56, 86, 99, 105, 123, 144, 147, 148, 168, 171]
]


def load_cc_image(csv_file, img_key='filepath', caption_key='title', sep="\t"):
    print(f'Loading csv data from {csv_file}.')
    df = pd.read_csv(csv_file, sep=sep)

    input_files = df[img_key].tolist()
    captions = df[caption_key].tolist()

    print("Loaded {} images".format(len(input_files)))

    dataset_dicts = []
    for (img_path, text) in zip(input_files, captions):
        record = {}
        record["file_name"] = img_path
        record["caption"] = text
        dataset_dicts.append(record)

    return dataset_dicts


def _get_coco_stuff_meta(cat_list):
    # Id 0 is reserved for ignore_label, we change ignore_label for 0
    # to 255 in our pre-processing.
    stuff_ids = [k["id"] for k in cat_list]

    # For semantic segmentation, this mapping maps from contiguous stuff id
    # (in [0, 91], used in models) to ids in the dataset (used for processing results)
    stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
    stuff_classes = [k["name"] for k in cat_list]

    ret = {
        "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
        "stuff_classes": stuff_classes,
    }
    return ret


def register_cc_3m(csv_file):

    meta = _get_coco_stuff_meta(TEST_CATEGORIES)
    name = "cc_3m_train"

    DatasetCatalog.register(
        name,
        lambda x=csv_file: load_cc_image(x),
    )
    MetadataCatalog.get(name).set(
        csv_file=csv_file,
        evaluator_type="dummy",
        ignore_label=255,
        **meta,
    )


# _csv_file = "/home/jeffliang/zsseg/datasets/coco/coco_train_merge_captions.csv"
_csv_file = "/home/jeffliang/zsseg/configs/masked_images/pred/samples.csv"
register_cc_3m(_csv_file)
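# --- Editorial sketch (not part of the original file): register_cc_3m only exposes
# image/caption pairs read from a tab-separated csv with 'filepath' and 'title'
# columns, so the registered split can be inspected directly:
#
#   from detectron2.data import DatasetCatalog
#   records = DatasetCatalog.get("cc_3m_train")
#   print(records[0]["file_name"], records[0]["caption"])
#
# Note that a hard-coded, user-specific csv path is registered at import time;
# point _csv_file at your own csv before relying on this dataset.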
open_vocab_seg/data/datasets/register_coco_stuff.py
ADDED
@@ -0,0 +1,250 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import os
|
3 |
+
|
4 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
5 |
+
from detectron2.data.datasets import load_sem_seg
|
6 |
+
|
7 |
+
|
8 |
+
COCO_CATEGORIES = [
|
9 |
+
{"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
|
10 |
+
{"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
|
11 |
+
{"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
|
12 |
+
{"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
|
13 |
+
{"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
|
14 |
+
{"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
|
15 |
+
{"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
|
16 |
+
{"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
|
17 |
+
{"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
|
18 |
+
{"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
|
19 |
+
{"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
|
20 |
+
{"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
|
21 |
+
{"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
|
22 |
+
{"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
|
23 |
+
{"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
|
24 |
+
{"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
|
25 |
+
{"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
|
26 |
+
{"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
|
27 |
+
{"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
|
28 |
+
{"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
|
29 |
+
{"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
|
30 |
+
{"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
|
31 |
+
{"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
|
32 |
+
{"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
|
33 |
+
{"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
|
34 |
+
{"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
|
35 |
+
{"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
|
36 |
+
{"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
|
37 |
+
{"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
|
38 |
+
{"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
|
39 |
+
{"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
|
40 |
+
{"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
|
41 |
+
{"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
|
42 |
+
{"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
|
43 |
+
{"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
|
44 |
+
{"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
|
45 |
+
{"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
|
46 |
+
{"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
|
47 |
+
{"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
|
48 |
+
{"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
|
49 |
+
{"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
|
50 |
+
{"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
|
51 |
+
{"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
|
52 |
+
{"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
|
53 |
+
{"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
|
54 |
+
{"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
|
55 |
+
{"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
|
56 |
+
{"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
|
57 |
+
{"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
|
58 |
+
{"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
|
59 |
+
{"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
|
60 |
+
{"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
|
61 |
+
{"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
|
62 |
+
{"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
|
63 |
+
{"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
|
64 |
+
{"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
|
65 |
+
{"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
|
66 |
+
{"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
|
67 |
+
{"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
|
68 |
+
{"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
|
69 |
+
{"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
|
70 |
+
{"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
|
71 |
+
{"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
|
72 |
+
{"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
|
73 |
+
{"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
|
74 |
+
{"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
|
75 |
+
{"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
|
76 |
+
{"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
|
77 |
+
{"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
|
78 |
+
{"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
|
79 |
+
{"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
|
80 |
+
{"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
|
81 |
+
{"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
|
82 |
+
{"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
|
83 |
+
{"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
|
84 |
+
{"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
|
85 |
+
{"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
|
86 |
+
{"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
|
87 |
+
{"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
|
88 |
+
{"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
|
89 |
+
{"id": 92, "name": "banner", "supercategory": "textile"},
|
90 |
+
{"id": 93, "name": "blanket", "supercategory": "textile"},
|
91 |
+
{"id": 94, "name": "branch", "supercategory": "plant"},
|
92 |
+
{"id": 95, "name": "bridge", "supercategory": "building"},
|
93 |
+
{"id": 96, "name": "building-other", "supercategory": "building"},
|
94 |
+
{"id": 97, "name": "bush", "supercategory": "plant"},
|
95 |
+
{"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
|
96 |
+
{"id": 99, "name": "cage", "supercategory": "structural"},
|
97 |
+
{"id": 100, "name": "cardboard", "supercategory": "raw-material"},
|
98 |
+
{"id": 101, "name": "carpet", "supercategory": "floor"},
|
99 |
+
{"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
|
100 |
+
{"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
|
101 |
+
{"id": 104, "name": "cloth", "supercategory": "textile"},
|
102 |
+
{"id": 105, "name": "clothes", "supercategory": "textile"},
|
103 |
+
{"id": 106, "name": "clouds", "supercategory": "sky"},
|
104 |
+
{"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
|
105 |
+
{"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
|
106 |
+
{"id": 109, "name": "curtain", "supercategory": "textile"},
|
107 |
+
{"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
|
108 |
+
{"id": 111, "name": "dirt", "supercategory": "ground"},
|
109 |
+
{"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
|
110 |
+
{"id": 113, "name": "fence", "supercategory": "structural"},
|
111 |
+
{"id": 114, "name": "floor-marble", "supercategory": "floor"},
|
112 |
+
{"id": 115, "name": "floor-other", "supercategory": "floor"},
|
113 |
+
{"id": 116, "name": "floor-stone", "supercategory": "floor"},
|
114 |
+
{"id": 117, "name": "floor-tile", "supercategory": "floor"},
|
115 |
+
{"id": 118, "name": "floor-wood", "supercategory": "floor"},
|
116 |
+
{"id": 119, "name": "flower", "supercategory": "plant"},
|
117 |
+
{"id": 120, "name": "fog", "supercategory": "water"},
|
118 |
+
{"id": 121, "name": "food-other", "supercategory": "food-stuff"},
|
119 |
+
{"id": 122, "name": "fruit", "supercategory": "food-stuff"},
|
120 |
+
{"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
|
121 |
+
{"id": 124, "name": "grass", "supercategory": "plant"},
|
122 |
+
{"id": 125, "name": "gravel", "supercategory": "ground"},
|
123 |
+
{"id": 126, "name": "ground-other", "supercategory": "ground"},
|
124 |
+
{"id": 127, "name": "hill", "supercategory": "solid"},
|
125 |
+
{"id": 128, "name": "house", "supercategory": "building"},
|
126 |
+
{"id": 129, "name": "leaves", "supercategory": "plant"},
|
127 |
+
{"id": 130, "name": "light", "supercategory": "furniture-stuff"},
|
128 |
+
{"id": 131, "name": "mat", "supercategory": "textile"},
|
129 |
+
{"id": 132, "name": "metal", "supercategory": "raw-material"},
|
130 |
+
{"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
|
131 |
+
{"id": 134, "name": "moss", "supercategory": "plant"},
|
132 |
+
{"id": 135, "name": "mountain", "supercategory": "solid"},
|
133 |
+
{"id": 136, "name": "mud", "supercategory": "ground"},
|
134 |
+
{"id": 137, "name": "napkin", "supercategory": "textile"},
|
135 |
+
{"id": 138, "name": "net", "supercategory": "structural"},
|
136 |
+
{"id": 139, "name": "paper", "supercategory": "raw-material"},
|
137 |
+
{"id": 140, "name": "pavement", "supercategory": "ground"},
|
138 |
+
{"id": 141, "name": "pillow", "supercategory": "textile"},
|
139 |
+
{"id": 142, "name": "plant-other", "supercategory": "plant"},
|
140 |
+
{"id": 143, "name": "plastic", "supercategory": "raw-material"},
|
141 |
+
{"id": 144, "name": "platform", "supercategory": "ground"},
|
142 |
+
{"id": 145, "name": "playingfield", "supercategory": "ground"},
|
143 |
+
{"id": 146, "name": "railing", "supercategory": "structural"},
|
144 |
+
{"id": 147, "name": "railroad", "supercategory": "ground"},
|
145 |
+
{"id": 148, "name": "river", "supercategory": "water"},
|
146 |
+
{"id": 149, "name": "road", "supercategory": "ground"},
|
147 |
+
{"id": 150, "name": "rock", "supercategory": "solid"},
|
148 |
+
{"id": 151, "name": "roof", "supercategory": "building"},
|
149 |
+
{"id": 152, "name": "rug", "supercategory": "textile"},
|
150 |
+
{"id": 153, "name": "salad", "supercategory": "food-stuff"},
|
151 |
+
{"id": 154, "name": "sand", "supercategory": "ground"},
|
152 |
+
{"id": 155, "name": "sea", "supercategory": "water"},
|
153 |
+
{"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
|
154 |
+
{"id": 157, "name": "sky-other", "supercategory": "sky"},
|
155 |
+
{"id": 158, "name": "skyscraper", "supercategory": "building"},
|
156 |
+
{"id": 159, "name": "snow", "supercategory": "ground"},
|
157 |
+
{"id": 160, "name": "solid-other", "supercategory": "solid"},
|
158 |
+
{"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
|
159 |
+
{"id": 162, "name": "stone", "supercategory": "solid"},
|
160 |
+
{"id": 163, "name": "straw", "supercategory": "plant"},
|
161 |
+
{"id": 164, "name": "structural-other", "supercategory": "structural"},
|
162 |
+
{"id": 165, "name": "table", "supercategory": "furniture-stuff"},
|
163 |
+
{"id": 166, "name": "tent", "supercategory": "building"},
|
164 |
+
{"id": 167, "name": "textile-other", "supercategory": "textile"},
|
165 |
+
{"id": 168, "name": "towel", "supercategory": "textile"},
|
166 |
+
{"id": 169, "name": "tree", "supercategory": "plant"},
|
167 |
+
{"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
|
168 |
+
{"id": 171, "name": "wall-brick", "supercategory": "wall"},
|
169 |
+
{"id": 172, "name": "wall-concrete", "supercategory": "wall"},
|
170 |
+
{"id": 173, "name": "wall-other", "supercategory": "wall"},
|
171 |
+
{"id": 174, "name": "wall-panel", "supercategory": "wall"},
|
172 |
+
{"id": 175, "name": "wall-stone", "supercategory": "wall"},
|
173 |
+
{"id": 176, "name": "wall-tile", "supercategory": "wall"},
|
174 |
+
{"id": 177, "name": "wall-wood", "supercategory": "wall"},
|
175 |
+
{"id": 178, "name": "water-other", "supercategory": "water"},
|
176 |
+
{"id": 179, "name": "waterdrops", "supercategory": "water"},
|
177 |
+
{"id": 180, "name": "window-blind", "supercategory": "window"},
|
178 |
+
{"id": 181, "name": "window-other", "supercategory": "window"},
|
179 |
+
{"id": 182, "name": "wood", "supercategory": "solid"},
|
180 |
+
]
|
181 |
+
|
182 |
+
def _get_coco_stuff_meta(cat_list):
|
183 |
+
# Id 0 is reserved for ignore_label, we change ignore_label for 0
|
184 |
+
# to 255 in our pre-processing.
|
185 |
+
stuff_ids = [k["id"] for k in cat_list]
|
186 |
+
|
187 |
+
# For semantic segmentation, this mapping maps from contiguous stuff id
|
188 |
+
# (in [0, 91], used in models) to ids in the dataset (used for processing results)
|
189 |
+
stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
|
190 |
+
stuff_classes = [k["name"] for k in cat_list]
|
191 |
+
|
192 |
+
ret = {
|
193 |
+
"stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
|
194 |
+
"stuff_classes": stuff_classes,
|
195 |
+
}
|
196 |
+
return ret
|
197 |
+
|
198 |
+
|
199 |
+
def register_all_coco_stuff_10k(root):
|
200 |
+
root = os.path.join(root, "coco", "coco_stuff_10k")
|
201 |
+
meta = _get_coco_stuff_meta(COCO_CATEGORIES)
|
202 |
+
for name, image_dirname, sem_seg_dirname in [
|
203 |
+
("train", "images_detectron2/train", "annotations_detectron2/train"),
|
204 |
+
]:
|
205 |
+
image_dir = os.path.join(root, image_dirname)
|
206 |
+
gt_dir = os.path.join(root, sem_seg_dirname)
|
207 |
+
name = f"coco_2017_{name}_stuff_10k_sem_seg"
|
208 |
+
DatasetCatalog.register(
|
209 |
+
name,
|
210 |
+
lambda x=image_dir, y=gt_dir: load_sem_seg(
|
211 |
+
y, x, gt_ext="png", image_ext="jpg"
|
212 |
+
),
|
213 |
+
)
|
214 |
+
MetadataCatalog.get(name).set(
|
215 |
+
image_root=image_dir,
|
216 |
+
sem_seg_root=gt_dir,
|
217 |
+
evaluator_type="sem_seg",
|
218 |
+
ignore_label=255,
|
219 |
+
**meta,
|
220 |
+
)
|
221 |
+
|
222 |
+
|
223 |
+
def register_all_coco_stuff(root):
|
224 |
+
root = os.path.join(root, "coco")
|
225 |
+
meta = _get_coco_stuff_meta(COCO_CATEGORIES)
|
226 |
+
|
227 |
+
for name, image_dirname, sem_seg_dirname in [
|
228 |
+
("train", "train2017", "stuffthingmaps_detectron2/train2017"),
|
229 |
+
]:
|
230 |
+
image_dir = os.path.join(root, image_dirname)
|
231 |
+
gt_dir = os.path.join(root, sem_seg_dirname)
|
232 |
+
all_name = f"coco_2017_{name}_stuff_sem_seg"
|
233 |
+
DatasetCatalog.register(
|
234 |
+
all_name,
|
235 |
+
lambda x=image_dir, y=gt_dir: load_sem_seg(
|
236 |
+
y, x, gt_ext="png", image_ext="jpg"
|
237 |
+
),
|
238 |
+
)
|
239 |
+
MetadataCatalog.get(all_name).set(
|
240 |
+
image_root=image_dir,
|
241 |
+
sem_seg_root=gt_dir,
|
242 |
+
evaluator_type="sem_seg",
|
243 |
+
ignore_label=255,
|
244 |
+
**meta,
|
245 |
+
)
|
246 |
+
|
247 |
+
|
248 |
+
_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
249 |
+
register_all_coco_stuff_10k(_root)
|
250 |
+
register_all_coco_stuff(_root)
|
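A quick way to sanity-check the registration above is to query the catalogs after importing this module. This is only an illustrative sketch, not part of the commit: it assumes detectron2 is installed, that the module path below matches the file listing, and that DETECTRON2_DATASETS points at the layout the register functions expect.

# Illustrative sketch only; the split name comes from the f-string in register_all_coco_stuff above.
from detectron2.data import DatasetCatalog, MetadataCatalog
import open_vocab_seg.data.datasets.register_coco_stuff  # noqa: F401  # runs the register_* calls

meta = MetadataCatalog.get("coco_2017_train_stuff_sem_seg")
print(len(meta.stuff_classes))                     # number of entries in COCO_CATEGORIES
print(meta.stuff_dataset_id_to_contiguous_id[92])  # dataset id 92 ("banner") -> contiguous id

# Building the dataset dicts requires the images/annotations to actually exist on disk.
dicts = DatasetCatalog.get("coco_2017_train_stuff_sem_seg")
print(dicts[0]["file_name"], dicts[0]["sem_seg_file_name"])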
open_vocab_seg/data/datasets/register_pascal_context.py
ADDED
@@ -0,0 +1,588 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import os

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import load_sem_seg

PASCALCONTEX59_NAMES = (
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "table",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
    "bag",
    "bed",
    "bench",
    "book",
    "building",
    "cabinet",
    "ceiling",
    "cloth",
    "computer",
    "cup",
    "door",
    "fence",
    "floor",
    "flower",
    "food",
    "grass",
    "ground",
    "keyboard",
    "light",
    "mountain",
    "mouse",
    "curtain",
    "platform",
    "sign",
    "plate",
    "road",
    "rock",
    "shelves",
    "sidewalk",
    "sky",
    "snow",
    "bedclothes",
    "track",
    "tree",
    "truck",
    "wall",
    "water",
    "window",
    "wood",
)

PASCALCONTEX459_NAMES = (
    "accordion",
    "aeroplane",
    "air conditioner",
    "antenna",
    "artillery",
    "ashtray",
    "atrium",
    "baby carriage",
    "bag",
    "ball",
    "balloon",
    "bamboo weaving",
    "barrel",
    "baseball bat",
    "basket",
    "basketball backboard",
    "bathtub",
    "bed",
    "bedclothes",
    "beer",
    "bell",
    "bench",
    "bicycle",
    "binoculars",
    "bird",
    "bird cage",
    "bird feeder",
    "bird nest",
    "blackboard",
    "board",
    "boat",
    "bone",
    "book",
    "bottle",
    "bottle opener",
    "bowl",
    "box",
    "bracelet",
    "brick",
    "bridge",
    "broom",
    "brush",
    "bucket",
    "building",
    "bus",
    "cabinet",
    "cabinet door",
    "cage",
    "cake",
    "calculator",
    "calendar",
    "camel",
    "camera",
    "camera lens",
    "can",
    "candle",
    "candle holder",
    "cap",
    "car",
    "card",
    "cart",
    "case",
    "casette recorder",
    "cash register",
    "cat",
    "cd",
    "cd player",
    "ceiling",
    "cell phone",
    "cello",
    "chain",
    "chair",
    "chessboard",
    "chicken",
    "chopstick",
    "clip",
    "clippers",
    "clock",
    "closet",
    "cloth",
    "clothes tree",
    "coffee",
    "coffee machine",
    "comb",
    "computer",
    "concrete",
    "cone",
    "container",
    "control booth",
    "controller",
    "cooker",
    "copying machine",
    "coral",
    "cork",
    "corkscrew",
    "counter",
    "court",
    "cow",
    "crabstick",
    "crane",
    "crate",
    "cross",
    "crutch",
    "cup",
    "curtain",
    "cushion",
    "cutting board",
    "dais",
    "disc",
    "disc case",
    "dishwasher",
    "dock",
    "dog",
    "dolphin",
    "door",
    "drainer",
    "dray",
    "drink dispenser",
    "drinking machine",
    "drop",
    "drug",
    "drum",
    "drum kit",
    "duck",
    "dumbbell",
    "earphone",
    "earrings",
    "egg",
    "electric fan",
    "electric iron",
    "electric pot",
    "electric saw",
    "electronic keyboard",
    "engine",
    "envelope",
    "equipment",
    "escalator",
    "exhibition booth",
    "extinguisher",
    "eyeglass",
    "fan",
    "faucet",
    "fax machine",
    "fence",
    "ferris wheel",
    "fire extinguisher",
    "fire hydrant",
    "fire place",
    "fish",
    "fish tank",
    "fishbowl",
    "fishing net",
    "fishing pole",
    "flag",
    "flagstaff",
    "flame",
    "flashlight",
    "floor",
    "flower",
    "fly",
    "foam",
    "food",
    "footbridge",
    "forceps",
    "fork",
    "forklift",
    "fountain",
    "fox",
    "frame",
    "fridge",
    "frog",
    "fruit",
    "funnel",
    "furnace",
    "game controller",
    "game machine",
    "gas cylinder",
    "gas hood",
    "gas stove",
    "gift box",
    "glass",
    "glass marble",
    "globe",
    "glove",
    "goal",
    "grandstand",
    "grass",
    "gravestone",
    "ground",
    "guardrail",
    "guitar",
    "gun",
    "hammer",
    "hand cart",
    "handle",
    "handrail",
    "hanger",
    "hard disk drive",
    "hat",
    "hay",
    "headphone",
    "heater",
    "helicopter",
    "helmet",
    "holder",
    "hook",
    "horse",
    "horse-drawn carriage",
    "hot-air balloon",
    "hydrovalve",
    "ice",
    "inflator pump",
    "ipod",
    "iron",
    "ironing board",
    "jar",
    "kart",
    "kettle",
    "key",
    "keyboard",
    "kitchen range",
    "kite",
    "knife",
    "knife block",
    "ladder",
    "ladder truck",
    "ladle",
    "laptop",
    "leaves",
    "lid",
    "life buoy",
    "light",
    "light bulb",
    "lighter",
    "line",
    "lion",
    "lobster",
    "lock",
    "machine",
    "mailbox",
    "mannequin",
    "map",
    "mask",
    "mat",
    "match book",
    "mattress",
    "menu",
    "metal",
    "meter box",
    "microphone",
    "microwave",
    "mirror",
    "missile",
    "model",
    "money",
    "monkey",
    "mop",
    "motorbike",
    "mountain",
    "mouse",
    "mouse pad",
    "musical instrument",
    "napkin",
    "net",
    "newspaper",
    "oar",
    "ornament",
    "outlet",
    "oven",
    "oxygen bottle",
    "pack",
    "pan",
    "paper",
    "paper box",
    "paper cutter",
    "parachute",
    "parasol",
    "parterre",
    "patio",
    "pelage",
    "pen",
    "pen container",
    "pencil",
    "person",
    "photo",
    "piano",
    "picture",
    "pig",
    "pillar",
    "pillow",
    "pipe",
    "pitcher",
    "plant",
    "plastic",
    "plate",
    "platform",
    "player",
    "playground",
    "pliers",
    "plume",
    "poker",
    "poker chip",
    "pole",
    "pool table",
    "postcard",
    "poster",
    "pot",
    "pottedplant",
    "printer",
    "projector",
    "pumpkin",
    "rabbit",
    "racket",
    "radiator",
    "radio",
    "rail",
    "rake",
    "ramp",
    "range hood",
    "receiver",
    "recorder",
    "recreational machines",
    "remote control",
    "road",
    "robot",
    "rock",
    "rocket",
    "rocking horse",
    "rope",
    "rug",
    "ruler",
    "runway",
    "saddle",
    "sand",
    "saw",
    "scale",
    "scanner",
    "scissors",
    "scoop",
    "screen",
    "screwdriver",
    "sculpture",
    "scythe",
    "sewer",
    "sewing machine",
    "shed",
    "sheep",
    "shell",
    "shelves",
    "shoe",
    "shopping cart",
    "shovel",
    "sidecar",
    "sidewalk",
    "sign",
    "signal light",
    "sink",
    "skateboard",
    "ski",
    "sky",
    "sled",
    "slippers",
    "smoke",
    "snail",
    "snake",
    "snow",
    "snowmobiles",
    "sofa",
    "spanner",
    "spatula",
    "speaker",
    "speed bump",
    "spice container",
    "spoon",
    "sprayer",
    "squirrel",
    "stage",
    "stair",
    "stapler",
    "stick",
    "sticky note",
    "stone",
    "stool",
    "stove",
    "straw",
    "stretcher",
    "sun",
    "sunglass",
    "sunshade",
    "surveillance camera",
    "swan",
    "sweeper",
    "swim ring",
    "swimming pool",
    "swing",
    "switch",
    "table",
    "tableware",
    "tank",
    "tap",
    "tape",
    "tarp",
    "telephone",
    "telephone booth",
    "tent",
    "tire",
    "toaster",
    "toilet",
    "tong",
    "tool",
    "toothbrush",
    "towel",
    "toy",
    "toy car",
    "track",
    "train",
    "trampoline",
    "trash bin",
    "tray",
    "tree",
    "tricycle",
    "tripod",
    "trophy",
    "truck",
    "tube",
    "turtle",
    "tvmonitor",
    "tweezers",
    "typewriter",
    "umbrella",
    "unknown",
    "vacuum cleaner",
    "vending machine",
    "video camera",
    "video game console",
    "video player",
    "video tape",
    "violin",
    "wakeboard",
    "wall",
    "wallet",
    "wardrobe",
    "washing machine",
    "watch",
    "water",
    "water dispenser",
    "water pipe",
    "water skate board",
    "watermelon",
    "whale",
    "wharf",
    "wheel",
    "wheelchair",
    "window",
    "window blinds",
    "wineglass",
    "wire",
    "wood",
    "wool",

)


def _get_voc_meta(cat_list):
    ret = {
        "stuff_classes": cat_list,
    }
    return ret


def register_pascal_context_59(root):
    root = os.path.join(root, "VOCdevkit/VOC2010")
    meta = _get_voc_meta(PASCALCONTEX59_NAMES)
    for name, image_dirname, sem_seg_dirname in [
        ("val", "JPEGImages", "annotations_detectron2/pc59_val"),
    ]:
        image_dir = os.path.join(root, image_dirname)
        gt_dir = os.path.join(root, sem_seg_dirname)
        all_name = f"pascal_context_59_sem_seg_{name}"
        DatasetCatalog.register(
            all_name,
            lambda x=image_dir, y=gt_dir: load_sem_seg(
                y, x, gt_ext="png", image_ext="jpg"
            ),
        )
        MetadataCatalog.get(all_name).set(
            image_root=image_dir,
            sem_seg_root=gt_dir,
            evaluator_type="sem_seg",
            ignore_label=255,
            **meta,
        )

def register_pascal_context_459(root):
    root = os.path.join(root, "VOCdevkit/VOC2010")
    meta = _get_voc_meta(PASCALCONTEX459_NAMES)
    for name, image_dirname, sem_seg_dirname in [
        ("val", "JPEGImages", "annotations_detectron2/pc459_val"),
    ]:
        image_dir = os.path.join(root, image_dirname)
        gt_dir = os.path.join(root, sem_seg_dirname)
        all_name = f"pascal_context_459_sem_seg_{name}"
        DatasetCatalog.register(
            all_name,
            lambda x=image_dir, y=gt_dir: load_sem_seg(
                y, x, gt_ext="tif", image_ext="jpg"
            ),
        )
        MetadataCatalog.get(all_name).set(
            image_root=image_dir,
            sem_seg_root=gt_dir,
            evaluator_type="sem_seg",
            ignore_label=65535,  # NOTE: gt is saved in 16-bit TIFF images
            **meta,
        )

_root = os.getenv("DETECTRON2_DATASETS", "datasets")
register_pascal_context_59(_root)
register_pascal_context_459(_root)
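Unlike the other splits, the PC-459 ground truth is stored as 16-bit TIFF label maps, with 65535 as the ignore value, which is why ignore_label differs from the usual 255 here. A minimal sketch of reading one of those label maps (the file name below is a hypothetical placeholder, assuming annotations exported as described above):

# Hypothetical path, illustrative only.
import numpy as np
from PIL import Image

gt = np.asarray(Image.open("annotations_detectron2/pc459_val/2008_000002.tif"))
valid = gt != 65535            # 65535 marks ignored pixels
ids = np.unique(gt[valid])     # contiguous ids indexing PASCALCONTEX459_NAMES
print(gt.dtype, ids[:10])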
open_vocab_seg/data/datasets/register_voc_seg.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import os

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import load_sem_seg

PASCALVOC20_NAMES = (
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
)

def _get_voc_meta(cat_list):
    ret = {
        "stuff_classes": cat_list,
    }
    return ret


def register_pascalvoc(root):
    root = os.path.join(root, "VOCdevkit/VOC2012")
    meta = _get_voc_meta(PASCALVOC20_NAMES)

    for name, image_dirname, sem_seg_dirname in [
        ("val", "JPEGImages", "annotations_detectron2/val"),
    ]:
        image_dir = os.path.join(root, image_dirname)
        gt_dir = os.path.join(root, sem_seg_dirname)
        all_name = f"pascalvoc20_sem_seg_{name}"
        DatasetCatalog.register(
            all_name,
            lambda x=image_dir, y=gt_dir: load_sem_seg(
                y, x, gt_ext="png", image_ext="jpg"
            ),
        )
        MetadataCatalog.get(all_name).set(
            image_root=image_dir,
            sem_seg_root=gt_dir,
            evaluator_type="sem_seg",
            ignore_label=255,
            **meta,
        )

_root = os.getenv("DETECTRON2_DATASETS", "datasets")
register_pascalvoc(_root)
open_vocab_seg/evaluation/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

from .generalized_sem_seg_evaluation import GeneralizedSemSegEvaluator
open_vocab_seg/evaluation/generalized_sem_seg_evaluation.py
ADDED
@@ -0,0 +1,159 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import itertools
import json
import numpy as np
import os
from collections import OrderedDict
import PIL.Image as Image
import torch

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.comm import all_gather, is_main_process, synchronize
from detectron2.utils.file_io import PathManager

from detectron2.evaluation import SemSegEvaluator


class GeneralizedSemSegEvaluator(SemSegEvaluator):
    """
    Evaluate semantic segmentation metrics.
    """

    def __init__(
        self,
        dataset_name,
        distributed=True,
        output_dir=None,
        *,
        num_classes=None,
        ignore_label=None,
        post_process_func=None,
    ):
        super().__init__(
            dataset_name,
            distributed=distributed,
            output_dir=output_dir,
            num_classes=num_classes,
            ignore_label=ignore_label,
        )
        meta = MetadataCatalog.get(dataset_name)
        try:
            self._evaluation_set = meta.evaluation_set
        except AttributeError:
            self._evaluation_set = None
        self.post_process_func = (
            post_process_func
            if post_process_func is not None
            else lambda x, **kwargs: x
        )

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a model.
                It is a list of dicts. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name".
            outputs: the outputs of a model. It is either a list of semantic segmentation predictions
                (Tensor [H, W]) or a list of dicts with key "sem_seg" that contains the semantic
                segmentation prediction in the same format.
        """
        for input, output in zip(inputs, outputs):
            output = self.post_process_func(
                output["sem_seg"], image=np.array(Image.open(input["file_name"]))
            )
            output = output.argmax(dim=0).to(self._cpu_device)
            # builtin int/float/bool are used below; the np.int/np.float/np.bool
            # aliases were deprecated and removed in NumPy >= 1.24
            pred = np.array(output, dtype=int)
            with PathManager.open(
                self.input_file_to_gt_file[input["file_name"]], "rb"
            ) as f:
                gt = np.array(Image.open(f), dtype=int)

            gt[gt == self._ignore_label] = self._num_classes

            self._conf_matrix += np.bincount(
                (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
                minlength=self._conf_matrix.size,
            ).reshape(self._conf_matrix.shape)

            self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))

    def evaluate(self):
        """
        Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):

        * Mean intersection-over-union averaged across classes (mIoU)
        * Frequency Weighted IoU (fwIoU)
        * Mean pixel accuracy averaged across classes (mACC)
        * Pixel Accuracy (pACC)
        """
        if self._distributed:
            synchronize()
            conf_matrix_list = all_gather(self._conf_matrix)
            self._predictions = all_gather(self._predictions)
            self._predictions = list(itertools.chain(*self._predictions))
            if not is_main_process():
                return

            self._conf_matrix = np.zeros_like(self._conf_matrix)
            for conf_matrix in conf_matrix_list:
                self._conf_matrix += conf_matrix

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)
            file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(self._predictions))

        acc = np.full(self._num_classes, np.nan, dtype=float)
        iou = np.full(self._num_classes, np.nan, dtype=float)
        tp = self._conf_matrix.diagonal()[:-1].astype(float)
        pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(float)
        class_weights = pos_gt / np.sum(pos_gt)
        pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(float)
        acc_valid = pos_gt > 0
        acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
        iou_valid = (pos_gt + pos_pred) > 0
        union = pos_gt + pos_pred - tp
        iou[acc_valid] = tp[acc_valid] / union[acc_valid]
        macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
        miou = np.sum(iou[acc_valid]) / np.sum(iou_valid)
        fiou = np.sum(iou[acc_valid] * class_weights[acc_valid])
        pacc = np.sum(tp) / np.sum(pos_gt)

        res = {}
        res["mIoU"] = 100 * miou
        res["fwIoU"] = 100 * fiou
        for i, name in enumerate(self._class_names):
            res["IoU-{}".format(name)] = 100 * iou[i]
        res["mACC"] = 100 * macc
        res["pACC"] = 100 * pacc
        for i, name in enumerate(self._class_names):
            res["ACC-{}".format(name)] = 100 * acc[i]
        if self._evaluation_set is not None:
            for set_name, set_inds in self._evaluation_set.items():
                iou_list = []
                set_inds = np.array(set_inds, int)
                mask = np.zeros((len(iou),)).astype(bool)
                mask[set_inds] = 1
                miou = np.sum(iou[mask][acc_valid[mask]]) / np.sum(iou_valid[mask])
                pacc = np.sum(tp[mask]) / np.sum(pos_gt[mask])
                res["mIoU-{}".format(set_name)] = 100 * miou
                res["pAcc-{}".format(set_name)] = 100 * pacc
                iou_list.append(miou)
                miou = np.sum(iou[~mask][acc_valid[~mask]]) / np.sum(iou_valid[~mask])
                pacc = np.sum(tp[~mask]) / np.sum(pos_gt[~mask])
                res["mIoU-un{}".format(set_name)] = 100 * miou
                res["pAcc-un{}".format(set_name)] = 100 * pacc
                iou_list.append(miou)
                res["hIoU-{}".format(set_name)] = (
                    100 * len(iou_list) / sum([1 / iou for iou in iou_list])
                )
        if self._output_dir:
            file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(res, f)
        results = OrderedDict({"sem_seg": res})
        self._logger.info(results)
        return results
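When a dataset's metadata defines evaluation_set (index sets of, e.g., seen classes), the evaluator above additionally reports mIoU over that set, mIoU over its complement (the "un*" keys), and their harmonic mean as hIoU. A tiny numeric sketch of that last step, with made-up values:

# Made-up numbers, illustrating the hIoU formula used above.
miou_seen, miou_unseen = 0.60, 0.30
iou_list = [miou_seen, miou_unseen]
hiou = len(iou_list) / sum(1.0 / v for v in iou_list)  # harmonic mean
print(round(100 * hiou, 2))  # 40.0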
open_vocab_seg/mask_former_model.py
ADDED
@@ -0,0 +1,254 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.structures import ImageList

from .modeling.criterion import SetCriterion
from .modeling.matcher import HungarianMatcher


@META_ARCH_REGISTRY.register()
class MaskFormer(nn.Module):
    """
    Main class for mask classification semantic segmentation architectures.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        criterion: nn.Module,
        num_queries: int,
        panoptic_on: bool,
        object_mask_threshold: float,
        overlap_threshold: float,
        metadata,
        size_divisibility: int,
        sem_seg_postprocess_before_inference: bool,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            criterion: a module that defines the loss
            num_queries: int, number of queries
            panoptic_on: bool, whether to output panoptic segmentation prediction
            object_mask_threshold: float, threshold to filter query based on classification score
                for panoptic segmentation inference
            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
                segmentation inference
            size_divisibility: Some backbones require the input height and width to be divisible by a
                specific integer. We can use this to override such requirement.
            sem_seg_postprocess_before_inference: whether to resize the prediction back
                to original input size before semantic segmentation inference or after.
                For high-resolution dataset like Mapillary, resizing predictions before
                inference will cause OOM error.
            pixel_mean, pixel_std: list or tuple with #channels element, representing
                the per-channel mean and std to be used to normalize the input image
        """
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        self.criterion = criterion
        self.num_queries = num_queries
        self.overlap_threshold = overlap_threshold
        self.panoptic_on = panoptic_on
        self.object_mask_threshold = object_mask_threshold
        self.metadata = metadata
        if size_divisibility < 0:
            # use backbone size_divisibility if not set
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility
        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
        )
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @classmethod
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())

        # Loss parameters:
        deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
        dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
        mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT

        # building criterion
        matcher = HungarianMatcher(
            cost_class=1,
            cost_mask=mask_weight,
            cost_dice=dice_weight,
        )

        weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight}
        if deep_supervision:
            dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)

        losses = ["labels", "masks"]

        criterion = SetCriterion(
            sem_seg_head.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=no_object_weight,
            losses=losses,
        )

        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
            "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
            "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
            "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
            "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
            "sem_seg_postprocess_before_inference": (
                cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
                or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
            ),
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }

    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                    * "image": Tensor, image in (C, H, W) format.
                    * "instances": per-region ground truth
                    * Other information that's included in the original dicts, such as:
                      "height", "width" (int): the output resolution of the model (may be different
                      from input resolution), used in inference.
        Returns:
            list[dict]:
                each dict has the results for one image. The dict contains the following keys:

                * "sem_seg":
                    A Tensor that represents the
                    per-pixel segmentation predicted by the head.
                    The prediction has shape KxHxW that represents the logits of
                    each class for each pixel.
                * "panoptic_seg":
                    A tuple that represents the panoptic output
                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
                        Each dict contains keys "id", "category_id", "isthing".
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features)

        if self.training:
            # mask classification target
            if "instances" in batched_inputs[0]:
                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
                targets = self.prepare_targets(gt_instances, images)
            else:
                targets = None

            # bipartite matching-based loss
            losses = self.criterion(outputs, targets)

            for k in list(losses.keys()):
                if k in self.criterion.weight_dict:
                    losses[k] *= self.criterion.weight_dict[k]
                else:
                    # remove this loss if not specified in `weight_dict`
                    losses.pop(k)

            return losses
        else:
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]
            # upsample masks
            mask_pred_results = F.interpolate(
                mask_pred_results,
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )

            processed_results = []
            for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
                mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
            ):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])

                if self.sem_seg_postprocess_before_inference:
                    mask_pred_result = sem_seg_postprocess(
                        mask_pred_result, image_size, height, width
                    )

                # semantic segmentation inference
                r = self.semantic_inference(mask_cls_result, mask_pred_result)
                if not self.sem_seg_postprocess_before_inference:
                    r = sem_seg_postprocess(r, image_size, height, width)
                processed_results.append({"sem_seg": r})

                # panoptic segmentation inference
                if self.panoptic_on:
                    panoptic_r = self.panoptic_inference(
                        mask_cls_result, mask_pred_result
                    )
                    processed_results[-1]["panoptic_seg"] = panoptic_r

            return processed_results

    def prepare_targets(self, targets, images):
        h, w = images.tensor.shape[-2:]
        new_targets = []
        for targets_per_image in targets:
            # pad gt
            gt_masks = targets_per_image.gt_masks
            padded_masks = torch.zeros(
                (gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device
            )
            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
            new_targets.append(
                {
                    "labels": targets_per_image.gt_classes,
                    "masks": padded_masks,
                }
            )
        return new_targets

    def semantic_inference(self, mask_cls, mask_pred):
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
        return semseg
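semantic_inference collapses the per-query predictions into a per-pixel class map: softmax the class logits, drop the trailing "no object" column, sigmoid the mask logits, and take the query-weighted sum. A shape sketch with random tensors (the sizes are arbitrary, illustrative only):

# Random tensors, illustrating the einsum above.
import torch
import torch.nn.functional as F

Q, C, H, W = 100, 171, 32, 32                  # queries, classes, spatial size
mask_cls = torch.randn(Q, C + 1)               # last column = "no object"
mask_pred = torch.randn(Q, H, W)

probs = F.softmax(mask_cls, dim=-1)[..., :-1]  # (Q, C)
masks = mask_pred.sigmoid()                    # (Q, H, W)
semseg = torch.einsum("qc,qhw->chw", probs, masks)
print(semseg.shape)                            # torch.Size([171, 32, 32])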
open_vocab_seg/modeling/.DS_Store
ADDED
Binary file (6.15 kB).
open_vocab_seg/modeling/__init__.py
ADDED
@@ -0,0 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

from .backbone.swin import D2SwinTransformer
from .backbone.clip_resnet import D2ModifiedResNet
from .heads.mask_former_head import MaskFormerHead
from .heads.open_vocab_mask_former_head import OpenVocabMaskFormerHead
from .heads.pixel_decoder import BasePixelDecoder
open_vocab_seg/modeling/backbone/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
open_vocab_seg/modeling/backbone/clip_resnet.py
ADDED
@@ -0,0 +1,206 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

from collections import OrderedDict
import torch
import torch.nn as nn
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(
            planes, planes, 3, padding=1 * dilation, bias=False, dilation=dilation
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(
                OrderedDict(
                    [
                        ("-1", nn.AvgPool2d(stride)),
                        (
                            "0",
                            nn.Conv2d(
                                inplanes,
                                planes * self.expansion,
                                1,
                                stride=1,
                                bias=False,
                            ),
                        ),
                        ("1", nn.BatchNorm2d(planes * self.expansion)),
                    ]
                )
            )

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, width=64, strides=[2, 1, 2, 2, 2], multi_grid=[1, 1, 1]):
        super().__init__()

        # the 3-layer stem
        self.conv1 = nn.Conv2d(
            3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(
            width // 2, width // 2, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(strides[0]) if strides[0] > 1 else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0], stride=strides[1])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=strides[2])
        self.layer3 = self._make_layer(width * 4, layers[2], stride=strides[3])
        self.layer4 = self._make_layer(
            width * 8, layers[3], stride=strides[4], dilations=multi_grid
        )
        self.num_features = [width * 4, width * 8, width * 16, width * 32]

    def _make_layer(self, planes, blocks, stride=1, dilations=None):
        if dilations is None:
            dilations = [1] * blocks
        layers = [Bottleneck(self._inplanes, planes, stride, dilation=dilations[0])]
        self._inplanes = planes * Bottleneck.expansion

        for i in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes, dilation=dilations[i]))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [
                (self.conv1, self.bn1),
                (self.conv2, self.bn2),
                (self.conv3, self.bn3),
            ]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        output = {}
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)  # 1/4,1/4
        x = self.layer1(x)
        output["res2"] = x
        x = self.layer2(x)  # 1/8,1/8
        output["res3"] = x
        x = self.layer3(x)  # 1/16,1/16
        output["res4"] = x
        x = self.layer4(x)  # 1/32,1/32
        output["res5"] = x
        return output


@BACKBONE_REGISTRY.register()
class D2ModifiedResNet(ModifiedResNet, Backbone):
    def __init__(self, cfg, input_shape):
        depth = cfg.MODEL.RESNETS.DEPTH
        num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
        width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
        bottleneck_channels = num_groups * width_per_group
        num_blocks_per_stage = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }[depth]
        strides = [2, 1, 2, 2, 2]
        multi_grid = cfg.MODEL.RESNETS.RES5_MULTI_GRID
        if cfg.MODEL.RESNETS.STEM_TYPE == "deeplab":
            strides = [1, 1, 2, 2, 2]
        super().__init__(
            num_blocks_per_stage,
            bottleneck_channels,
            strides=strides,
            multi_grid=multi_grid,
        )
        self._out_features = cfg.MODEL.RESNETS.OUT_FEATURES

        self._out_feature_strides = {
            "res2": 4,
            "res3": 8,
            "res4": 16,
            "res5": 32,
        }
        self._out_feature_channels = {
            "res2": self.num_features[0],
            "res3": self.num_features[1],
            "res4": self.num_features[2],
            "res5": self.num_features[3],
        }

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        outputs = {}
        y = super().forward(x)
        for k in y.keys():
            if k in self._out_features:
                outputs[k] = y[k]
        return outputs

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name],
                stride=self._out_feature_strides[name],
            )
            for name in self._out_features
        }

    @property
    def size_divisibility(self):
        return 32
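A quick shape check for the backbone above, sketched for illustration: it assumes a ResNet-50 layout ([3, 4, 6, 3] blocks), the default width of 64, and the module path from the file listing. With the default strides the four outputs sit at strides 4/8/16/32 with channels width*4 through width*32.

# Illustrative only.
import torch
from open_vocab_seg.modeling.backbone.clip_resnet import ModifiedResNet

model = ModifiedResNet(layers=[3, 4, 6, 3], width=64)
feats = model(torch.randn(1, 3, 224, 224))
for name, f in feats.items():
    print(name, tuple(f.shape))
# res2 (1, 256, 56, 56), res3 (1, 512, 28, 28), res4 (1, 1024, 14, 14), res5 (1, 2048, 7, 7)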
open_vocab_seg/modeling/backbone/swin.py
ADDED
@@ -0,0 +1,832 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu, Yutong Lin, Yixuan Wei
# --------------------------------------------------------

# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec


class Mlp(nn.Module):
    """Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchMerging(nn.Module):
    """Patch Merging Layer
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x


class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, depth, num_heads, window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None,
                 drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=nn.LayerNorm, downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        w_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W


class PatchEmbed(nn.Module):
    """Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x


class SwinTransformer(nn.Module):
    """Swin Transformer backbone.
    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
        https://arxiv.org/pdf/2103.14030
    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute postion embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, pretrain_img_size=224, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4.0, qkv_bias=True, qk_scale=None,
                 drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.2, norm_layer=nn.LayerNorm, ape=False,
                 patch_norm=True, out_indices=(0, 1, 2, 3), norm_indices=None, frozen_stages=-1,
                 use_checkpoint=False, projection=False, project_dim=256):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.norm_indices = norm_indices if norm_indices is not None else out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [
                pretrain_img_size[0] // patch_size[0],
                pretrain_img_size[1] // patch_size[1],
            ]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in self.norm_indices:
            if i_layer >= len(self.num_features):
                continue
            layer = norm_layer(num_features[i_layer])
            layer_name = f"norm{i_layer}"
            self.add_module(layer_name, layer)
        # add projector head
        self.projection = projection
        if projection:
            self.project_dim = project_dim
            self.norm = norm_layer(self.num_features[-1])
            self.projector = nn.Linear(self.num_features[-1], project_dim, bias=False)
        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Forward function."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = {}
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                if i in self.norm_indices:
                    norm_layer = getattr(self, f"norm{i}")
                    x_out = norm_layer(x_out)
                out = (
                    x_out.view(-1, H, W, self.num_features[i])
                    .permute(0, 3, 1, 2)
                    .contiguous()
                )
                outs["res{}".format(i + 2)] = out
        if self.projection:
            x_out = self.norm(x_out)
            x_out = x_out.view(-1, H, W, self.num_features[-1]).contiguous()
            outs["fc"] = self.projector(x_out).permute(0, 3, 1, 2)

        return outs

    def train(self, mode=True):
        """Convert the model into training mode while keep layers freezed."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()


@BACKBONE_REGISTRY.register()
class D2SwinTransformer(SwinTransformer, Backbone):
    def __init__(self, cfg, input_shape):

        pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
        patch_size = cfg.MODEL.SWIN.PATCH_SIZE
        in_chans = 3
        embed_dim = cfg.MODEL.SWIN.EMBED_DIM
        depths = cfg.MODEL.SWIN.DEPTHS
        num_heads = cfg.MODEL.SWIN.NUM_HEADS
        window_size = cfg.MODEL.SWIN.WINDOW_SIZE
        mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
        qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
        qk_scale = cfg.MODEL.SWIN.QK_SCALE
        drop_rate = cfg.MODEL.SWIN.DROP_RATE
        attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
        drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
        norm_layer = nn.LayerNorm
        ape = cfg.MODEL.SWIN.APE
        patch_norm = cfg.MODEL.SWIN.PATCH_NORM
        norm_indices = cfg.MODEL.SWIN.NORM_INDICES
        projection = cfg.MODEL.SWIN.PROJECTION
        project_dim = cfg.MODEL.SWIN.PROJECT_DIM
        super().__init__(
            pretrain_img_size,
            patch_size,
            in_chans,
            embed_dim,
            depths,
            num_heads,
            window_size,
            mlp_ratio,
            qkv_bias,
            qk_scale,
            drop_rate,
            attn_drop_rate,
            drop_path_rate,
            norm_layer,
            ape,
            patch_norm,
            norm_indices=norm_indices,
            projection=projection,
            project_dim=project_dim,
        )

        self._out_features = cfg.MODEL.SWIN.OUT_FEATURES

        self._out_feature_strides = {
            "res2": 4,
            "res3": 8,
            "res4": 16,
            "res5": 32,
            "fc": 32,
        }
        self._out_feature_channels = {
            "res2": self.num_features[0],
            "res3": self.num_features[1],
            "res4": self.num_features[2],
            "res5": self.num_features[3],
            "fc": self.num_features[3],
        }

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert (
            x.dim() == 4
        ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        y = super().forward(x)
        for k in y.keys():
            if k in self._out_features:
                outputs[k] = y[k]
        return outputs

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name],
                stride=self._out_feature_strides[name],
            )
            for name in self._out_features
        }

    @property
    def size_divisibility(self):
        return 32
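A small sanity-check sketch (not part of the diff): window_partition and window_reverse above are exact inverses whenever H and W are multiples of window_size, which SwinTransformerBlock guarantees by padding before partitioning.

# Sketch only: round-tripping through the window helpers.
import torch
from open_vocab_seg.modeling.backbone.swin import window_partition, window_reverse

B, H, W, C, ws = 2, 14, 14, 96, 7
x = torch.randn(B, H, W, C)
windows = window_partition(x, ws)           # (B * H//ws * W//ws, ws, ws, C) == (8, 7, 7, 96)
x_back = window_reverse(windows, ws, H, W)  # (2, 14, 14, 96)
assert torch.equal(x, x_back)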
open_vocab_seg/modeling/clip_adapter/__init__.py
ADDED
@@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

from .text_template import (
    PredefinedPromptExtractor,
    ImageNetPromptExtractor,
    VILDPromptExtractor,
)
from .adapter import ClipAdapter, MaskFormerClipAdapter


def build_text_prompt(cfg):
    if cfg.TEXT_TEMPLATES == "predefined":
        text_templates = PredefinedPromptExtractor(cfg.PREDEFINED_PROMPT_TEMPLATES)
    elif cfg.TEXT_TEMPLATES == "imagenet":
        text_templates = ImageNetPromptExtractor()
    elif cfg.TEXT_TEMPLATES == "vild":
        text_templates = VILDPromptExtractor()
    else:
        raise NotImplementedError(
            "Prompt learner {} is not supported".format(cfg.TEXT_TEMPLATES)
        )
    return text_templates

from .clip import tokenize
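Usage sketch (not part of the diff): build_text_prompt expects a config node carrying TEXT_TEMPLATES (and PREDEFINED_PROMPT_TEMPLATES when "predefined" is selected); the exact node name used elsewhere in the repo, e.g. cfg.MODEL.CLIP_ADAPTER, is an assumption here.

# Sketch only, assuming a detectron2-style CfgNode is passed in.
from detectron2.config import CfgNode as CN
from open_vocab_seg.modeling.clip_adapter import build_text_prompt

clip_adapter_cfg = CN()
clip_adapter_cfg.TEXT_TEMPLATES = "vild"  # or "imagenet" / "predefined"
prompt_extractor = build_text_prompt(clip_adapter_cfg)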
open_vocab_seg/modeling/clip_adapter/adapter.py
ADDED
@@ -0,0 +1,206 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
# Modified by Feng Liang from
# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/adapter.py

from typing import List
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.structures import BitMasks
from .utils import build_clip_model, crop_with_mask
from .text_template import PromptExtractor


PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
PIXEL_STD = (0.26862954, 0.26130258, 0.27577711)


class ClipAdapter(nn.Module):
    def __init__(self, clip_model_name: str, mask_prompt_depth: int, text_templates: PromptExtractor):
        super().__init__()
        self.clip_model = build_clip_model(clip_model_name, mask_prompt_depth)
        self.text_templates = text_templates
        self.text_templates.init_buffer(self.clip_model)
        self.text_feature_buffer = {}

    def forward(self, image: torch.Tensor, text: List[str], **kwargs):
        image = self._preprocess_image(image, **kwargs)
        text_feature = self.get_text_features(text)  # k,feat_dim
        image_features = self.get_image_features(image)
        return self.get_sim_logits(text_feature, image_features)

    def _preprocess_image(self, image: torch.Tensor):
        return image

    def _get_text_features(self, noun_list: List[str]):
        left_noun_list = [
            noun for noun in noun_list if noun not in self.text_feature_buffer
        ]
        if len(left_noun_list) > 0:
            left_text_features = self.text_templates(left_noun_list, self.clip_model)
            self.text_feature_buffer.update(
                {
                    noun: text_feature
                    for noun, text_feature in zip(left_noun_list, left_text_features)
                }
            )
        return torch.stack([self.text_feature_buffer[noun] for noun in noun_list])

    def get_text_features(self, noun_list: List[str]):
        return self._get_text_features(noun_list)

    def get_image_features(self, image: torch.Tensor):
        image_features = self.clip_model.visual(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features

    def get_sim_logits(self, text_features: torch.Tensor, image_features: torch.Tensor, temperature: float = 100):
        return temperature * image_features @ text_features.T

    def normalize_feature(self, feat: torch.Tensor):
        return feat / feat.norm(dim=-1, keepdim=True)


class MaskFormerClipAdapter(ClipAdapter):
    def __init__(
        self,
        clip_model_name: str,
        text_templates: PromptExtractor,
        mask_fill: str = "mean",
        mask_expand_ratio: float = 1.0,
        mask_thr: float = 0.5,
        mask_matting: bool = False,
        region_resized: bool = True,
        mask_prompt_depth: int = 0,
        mask_prompt_fwd: bool = False,
    ):
        super().__init__(clip_model_name, mask_prompt_depth, text_templates)
        self.non_object_embedding = nn.Parameter(
            torch.empty(1, self.clip_model.text_projection.shape[-1])
        )
        nn.init.normal_(
            self.non_object_embedding.data,
            std=self.clip_model.transformer.width ** -0.5,
        )
        # for test
        self.mask_fill = mask_fill
        if self.mask_fill == "zero":
            self.mask_fill = (0.0, 0.0, 0.0)
        elif self.mask_fill == "mean":
            self.mask_fill = [255.0 * c for c in PIXEL_MEAN]
        else:
            raise NotImplementedError(
                "Unknown mask_fill method: {}".format(self.mask_fill)
            )
        self.mask_expand_ratio = mask_expand_ratio
        self.mask_thr = mask_thr
        self.mask_matting = mask_matting
        self.region_resized = region_resized
        self.mask_prompt_fwd = mask_prompt_fwd
        self.register_buffer(
            "pixel_mean", torch.Tensor(PIXEL_MEAN).reshape(1, 3, 1, 1) * 255.0
        )
        self.register_buffer(
            "pixel_std", torch.Tensor(PIXEL_STD).reshape(1, 3, 1, 1) * 255.0
        )

    def forward(
        self,
        image: torch.Tensor,
        text: List[str],
        mask: torch.Tensor,
        normalize: bool = True,
        fwd_w_region_mask: bool = False,
    ):
        (regions, unnorm_regions), region_masks, valid_flag = self._preprocess_image(image, mask, normalize=normalize)
        if regions is None:
            return None, valid_flag
        if isinstance(regions, list):
            assert NotImplementedError
            image_features = torch.cat(
                [self.get_image_features(image_i) for image_i in regions], dim=0
            )
        else:
            if self.mask_prompt_fwd:
                image_features = self.get_image_features(regions, region_masks)
            else:
                image_features = self.get_image_features(regions)
        text_feature = self.get_text_features(text)  # k,feat_dim
        return self.get_sim_logits(text_feature, image_features), unnorm_regions, valid_flag

    def get_image_features(self, image, region_masks=None):
        image_features = self.clip_model.visual(image, region_masks)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features

    def _preprocess_image(self, image: torch.Tensor, mask: torch.Tensor, normalize: bool = True):
        """crop, mask and normalize the image

        Args:
            image ([type]): [C,H,W]
            mask ([type]): [K,H,W]
            normalize (bool, optional): [description]. Defaults to True.
        """
        dtype = mask.dtype
        bin_mask = mask > self.mask_thr
        valid = bin_mask.sum(dim=(-1, -2)) > 0
        bin_mask = bin_mask[valid]
        mask = mask[valid]
        if not self.mask_matting:
            mask = bin_mask
        bin_mask = BitMasks(bin_mask)
        bboxes = bin_mask.get_bounding_boxes()
        # crop, mask
        regions = []
        region_masks = []
        for bbox, single_mask in zip(bboxes, mask):
            region, region_mask = crop_with_mask(
                image.type(dtype),
                single_mask.type(dtype),
                bbox,
                fill=self.mask_fill,
                expand_ratio=self.mask_expand_ratio,
            )
            regions.append(region.unsqueeze(0))
            region_masks.append(region_mask.unsqueeze(0))
        if len(regions) == 0:
            return None, valid
        unnorm_regions = regions
        if normalize:
            regions = [(r - self.pixel_mean) / self.pixel_std for r in regions]
        # resize
        if self.region_resized:
            regions = [
                F.interpolate(r, size=(224, 224), mode="bicubic") for r in regions
            ]
            regions = torch.cat(regions)
            region_masks = [
                F.interpolate(r, size=(224, 224), mode="nearest") for r in region_masks
            ]
            region_masks = torch.cat(region_masks)
            unnorm_regions = [
                F.interpolate(r, size=(224, 224), mode="bicubic") for r in unnorm_regions
            ]
            unnorm_regions = torch.cat(unnorm_regions)
        return (regions, unnorm_regions), region_masks, valid

    def get_text_features(self, noun_list: List[str]):
        object_text_features = self._get_text_features(noun_list)
        non_object_text_features = (
            self.non_object_embedding
            / self.non_object_embedding.norm(dim=-1, keepdim=True)
        )
        return torch.cat([object_text_features, non_object_text_features], dim=0)
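Sketch (not part of the diff) of the scoring that get_sim_logits performs: per-region image features and per-class text features (plus the learned non-object row) are both L2-normalized, so the matmul is a cosine similarity scaled by a fixed temperature of 100.

# Sketch only, with random stand-in features.
import torch

K, D, R = 5, 512, 3                      # class names, feature dim, mask regions
text = torch.randn(K + 1, D)             # +1 row for the non-object embedding
image = torch.randn(R, D)
text = text / text.norm(dim=-1, keepdim=True)
image = image / image.norm(dim=-1, keepdim=True)
logits = 100 * image @ text.T            # (R, K+1), as in get_sim_logits
probs = logits.softmax(dim=-1)           # per-region class distribution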
open_vocab_seg/modeling/clip_adapter/clip/__init__.py
ADDED
@@ -0,0 +1 @@
from .clip import *
open_vocab_seg/modeling/clip_adapter/clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
open_vocab_seg/modeling/clip_adapter/clip/clip.py
ADDED
@@ -0,0 +1,285 @@
import hashlib
import os
import urllib
import warnings
from collections import OrderedDict
from typing import Union, List

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


if torch.__version__.split(".") < ["1", "7", "1"]:
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if (
            hashlib.sha256(open(download_target, "rb").read()).hexdigest()
            == expected_sha256
        ):
            return download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if (
        hashlib.sha256(open(download_target, "rb").read()).hexdigest()
        != expected_sha256
    ):
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match"
        )

    return download_target


def _transform(n_px):
    return Compose(
        [
            Resize(n_px, interpolation=BICUBIC),
            CenterCrop(n_px),
            lambda image: image.convert("RGB"),
            ToTensor(),
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(
    name: str,
    mask_prompt_depth: int = 0,
    device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
    jit=False,
):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(
                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
            )
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")
        if 'state_dict' in state_dict:
            new_state_dict = OrderedDict()
            for k, v in state_dict['state_dict'].items():
                if k.startswith('module.'):
                    name = k[7:]  # remove `module.`
                    new_state_dict[name] = v
            state_dict = new_state_dict

    if not jit:
        model = build_model(state_dict or model.state_dict(), mask_prompt_depth).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(
        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
    )
    device_node = [
        n
        for n in device_holder.graph.findAllNodes("prim::Constant")
        if "Device" in repr(n)
    ][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith(
                    "cuda"
                ):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(
            lambda: torch.ones([]).float(), example_inputs=[]
        )
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [
                        1,
                        2,
                    ]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(
    texts: Union[str, List[str]],
    context_length: int = 77,
    truncate: bool = False,
    return_length: bool = False,
) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    length = []
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
                length.append(context_length)
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        else:
            length.append(len(tokens))
        result[i, : len(tokens)] = torch.tensor(tokens)
    if return_length:
        return result, length
    return result
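Usage sketch (not part of the diff): tokenize() pads every prompt to CLIP's 77-token context window and can optionally return the unpadded token counts.

# Sketch only.
from open_vocab_seg.modeling.clip_adapter.clip import tokenize

tokens = tokenize(["a photo of a cat", "a photo of a dog"])
print(tokens.shape)  # torch.Size([2, 77])

tokens, lengths = tokenize("a photo of a cat", return_length=True)
print(lengths)       # token count incl. <|startoftext|> / <|endoftext|>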
open_vocab_seg/modeling/clip_adapter/clip/model.py
ADDED
@@ -0,0 +1,613 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
# Modified by Feng Liang from https://github.com/openai/CLIP/blob/main/clip/model.py

from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(
                OrderedDict(
                    [
                        ("-1", nn.AvgPool2d(stride)),
                        (
                            "0",
                            nn.Conv2d(
                                inplanes,
                                planes * self.expansion,
                                1,
                                stride=1,
                                bias=False,
                            ),
                        ),
                        ("1", nn.BatchNorm2d(planes * self.expansion)),
                    ]
                )
            )

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(
        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads
        self.grid_size = spacial_dim

    def forward(self, x, mask=None, return_cls=True):
        b, c, gh, gw = x.shape
        # remove irrelated feature
        if mask is not None:
            mask = F.interpolate(mask[:, None, ...], size=(gh, gw)).squeeze(
                1
            )  # [N,H,W] -> [N,grid,grid]
            mask = (mask > 0.5).reshape(mask.shape[0], -1)
            mask = torch.cat([mask, mask.new_ones(mask.shape[0], 1)], dim=1)
            if x.size()[0] == 1:
                x = x.expand(mask.shape[0], c, gh, gw)

        x = x.reshape(x.shape[0], c, gh * gw).permute(2, 0, 1)  # NCHW -> (HW)NC

        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        positional_embedding = self.positional_embedding
        if not (self.positional_embedding.shape[0] == x.shape[0]):
            cls_pos = positional_embedding[0:1, :]
            per_pos_embedding = (
                F.interpolate(
                    positional_embedding[1:, :]
                    .permute(1, 0)
                    .view(1, -1, self.grid_size, self.grid_size),
                    size=(gh, gw),
                    mode="bicubic",
                )
                .reshape(-1, gh * gw)
                .permute(1, 0)
            )
            positional_embedding = torch.cat([cls_pos, per_pos_embedding])

        x = x + positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
            ),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
            key_padding_mask=mask,
        )

        if return_cls:
            return x[0]
        else:
            return x


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(
            3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(
            width // 2, width // 2, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(
            input_resolution // 32, embed_dim, heads, output_dim
        )

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, mask: torch.Tensor = None, return_cls=True):
        def stem(x):
            for conv, bn in [
                (self.conv1, self.bn1),
                (self.conv2, self.bn2),
                (self.conv3, self.bn3),
            ]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)  # 1/4,1/4
        x = self.layer1(x)
        x = self.layer2(x)  # 1/8,1/8
        x = self.layer3(x)  # 1/16,1/16
        x = self.layer4(x)  # 1/32,1/32
        b, c, gh, gw = x.shape
        x = self.attnpool(x, mask, return_cls)
        if not return_cls:
            return x[1:].permute(1, 0, 2).reshape(b, gh, gw, x.shape[-1])  # N,L,C
        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, d_model * 4)),
                    ("gelu", QuickGELU()),
                    ("c_proj", nn.Linear(d_model * 4, d_model)),
                ]
            )
        )
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor, **kwargs):
        self.attn_mask = (
            self.attn_mask.to(dtype=x.dtype, device=x.device)
            if self.attn_mask is not None
            else None
        )
        return self.attn(
            x, x, x, need_weights=False, attn_mask=self.attn_mask, **kwargs
        )[0]

    def forward(self, x: torch.Tensor, **kwargs):
        x = x + self.attention(self.ln_1(x), **kwargs)
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(
            *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        )

    def forward(self, x: torch.Tensor, **kwargs):
        for block in self.resblocks:
            x = block(x, **kwargs)
        return x


class VisionTransformer(nn.Module):
    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        mask_prompt_depth: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
    ):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
        )
        self.grid_size = input_resolution // patch_size
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

        self.mask_pool = nn.AvgPool2d(patch_size, stride=patch_size)
        self.mask_prompt_depth = mask_prompt_depth
        self.mask_embedding = nn.Parameter(torch.zeros(self.mask_prompt_depth, self.grid_size * self.grid_size, width))

    def forward(self, x: torch.Tensor, m: torch.Tensor = None):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        if m is not None:
            m = self.mask_pool(m.to(torch.float).squeeze()).reshape(m.shape[0], -1).unsqueeze(-1)
            m = torch.ceil(m)
            if self.mask_embedding.shape[1] == 1:
                mask_embedding = self.mask_embedding.to(x.dtype).repeat(1, x.shape[1], 1)
            else:
                mask_embedding = self.mask_embedding.to(x.dtype)
            x = x * m + mask_embedding[0].unsqueeze(0) * (1 - m)

        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        if m is not None:
            for i, blk in enumerate(self.transformer.resblocks):
                d = i + 1
                x = blk(x)
                if d < self.mask_prompt_depth:
                    masked_x = x[1:, :, :] * m.permute(1, 0, 2) + \
                               mask_embedding[d].unsqueeze(0).permute(1, 0, 2) * (1 - m.permute(1, 0, 2))
                    x = torch.cat([x[:1, :, :], masked_x], dim=0)
        else:
            x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        # vision
        image_resolution: int,
        vision_layers: Union[Tuple[int, int, int, int], int],
        vision_width: int,
        vision_patch_size: int,
        mask_prompt_depth: int,
        # text
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
    ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width,
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                mask_prompt_depth=mask_prompt_depth,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim,
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width)
        )
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [
                self.visual.layer1,
                self.visual.layer2,
                self.visual.layer3,
                self.visual.layer4,
            ]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * (
            (2 * self.transformer.layers) ** -0.5
        )
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image, **kwargs):
        return self.visual(image.type(self.dtype), **kwargs)

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [
                *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                "in_proj_bias",
                "bias_k",
                "bias_v",
            ]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict, mask_prompt_depth: int = 0):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [
                k
                for k in state_dict.keys()
                if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
            ]
        )
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round(
            (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
        )
        image_resolution = vision_patch_size * grid_size
    else:
        assert mask_prompt_depth == 0, 'ResNets do not support mask prompt tuning'
        counts: list = [
            len(
                set(
                    k.split(".")[2]
                    for k in state_dict
                    if k.startswith(f"visual.layer{b}")
                )
            )
            for b in [1, 2, 3, 4]
        ]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round(
            (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
        )
        vision_patch_size = None
        assert (
            output_width ** 2 + 1
            == state_dict["visual.attnpool.positional_embedding"].shape[0]
        )
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        set(
            k.split(".")[2]
            for k in state_dict
            if k.startswith(f"transformer.resblocks")
        )
    )

    model = CLIP(
        embed_dim,
        image_resolution,
        vision_layers,
        vision_width,
        vision_patch_size,
        mask_prompt_depth,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict, strict=False)
    return model.eval()
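A minimal sketch of driving build_model above directly from a plain CLIP state_dict (the checkpoint path is hypothetical; in this repo the model is normally constructed via the clip.load helper, which handles downloading and unwrapping checkpoints):

import torch
from open_vocab_seg.modeling.clip_adapter.clip.model import build_model

# Hypothetical path: a CLIP state_dict previously saved with torch.save().
state_dict = torch.load("clip_vit_b16_state_dict.pth", map_location="cpu")

# mask_prompt_depth > 0 enables the mask-prompt embeddings of the VisionTransformer above (ViT backbones only).
model = build_model(state_dict, mask_prompt_depth=3)
print(sum(p.numel() for p in model.parameters()))  # total parameter count of the rebuilt model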
open_vocab_seg/modeling/clip_adapter/clip/simple_tokenizer.py
ADDED
@@ -0,0 +1,150 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
    )


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            "<|startoftext|>": "<|startoftext|>",
            "<|endoftext|>": "<|endoftext|>",
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text
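A minimal encode/decode round-trip with the SimpleTokenizer above (assumes ftfy and regex are installed, as this file requires; the example string is illustrative):

from open_vocab_seg.modeling.clip_adapter.clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("A photo of a cat!")
print(ids)                    # BPE token ids (input is lower-cased and whitespace-normalised first)
print(tokenizer.decode(ids))  # roughly "a photo of a cat ! "; decode re-inserts word boundaries at </w>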
open_vocab_seg/modeling/clip_adapter/text_template.py
ADDED
@@ -0,0 +1,156 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
# Modified by Feng Liang from
# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/text_prompt.py
# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/utils.py

from typing import List

# import clip
from .clip import tokenize
import torch
from torch import nn

IMAGENET_PROMPT = [
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
]

VILD_PROMPT = [
    "a photo of a {}.",
    "This is a photo of a {}",
    "There is a {} in the scene",
    "There is the {} in the scene",
    "a photo of a {} in the scene",
    "a photo of a small {}.",
    "a photo of a medium {}.",
    "a photo of a large {}.",
    "This is a photo of a small {}.",
    "This is a photo of a medium {}.",
    "This is a photo of a large {}.",
    "There is a small {} in the scene.",
    "There is a medium {} in the scene.",
    "There is a large {} in the scene.",
]

class PromptExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self._buffer_init = False

    def init_buffer(self, clip_model):
        self._buffer_init = True

    def forward(self, noun_list: List[str], clip_model: nn.Module):
        raise NotImplementedError()


class PredefinedPromptExtractor(PromptExtractor):
    def __init__(self, templates: List[str]):
        super().__init__()
        self.templates = templates

    def forward(self, noun_list: List[str], clip_model: nn.Module):
        text_features_bucket = []
        for template in self.templates:
            noun_tokens = [tokenize(template.format(noun)) for noun in noun_list]
            text_inputs = torch.cat(noun_tokens).to(
                clip_model.text_projection.data.device
            )
            text_features = clip_model.encode_text(text_inputs)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            text_features_bucket.append(text_features)
        del text_inputs
        # ensemble by averaging
        text_features = torch.stack(text_features_bucket).mean(dim=0)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        return text_features


class ImageNetPromptExtractor(PredefinedPromptExtractor):
    def __init__(self):
        super().__init__(IMAGENET_PROMPT)


class VILDPromptExtractor(PredefinedPromptExtractor):
    def __init__(self):
        super().__init__(VILD_PROMPT)
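A minimal sketch of how the prompt templates above are consumed (clip_model stands for an already-loaded instance of the CLIP class from model.py; the class names are illustrative):

import torch
from open_vocab_seg.modeling.clip_adapter.text_template import VILDPromptExtractor

extractor = VILDPromptExtractor()
# clip_model is assumed to expose encode_text() and text_projection, as the CLIP class above does.
with torch.no_grad():
    text_features = extractor(["cat", "dog", "grass"], clip_model)
print(text_features.shape)  # [3, embed_dim]: one template-ensembled, L2-normalised embedding per class name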
open_vocab_seg/modeling/clip_adapter/utils.py
ADDED
@@ -0,0 +1,81 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

from typing import Tuple
import numpy as np
import torch
from .clip import load as clip_load
from detectron2.utils.comm import get_local_rank, synchronize


def expand_box(
    x1: float,
    y1: float,
    x2: float,
    y2: float,
    expand_ratio: float = 1.0,
    max_h: int = None,
    max_w: int = None,
):
    cx = 0.5 * (x1 + x2)
    cy = 0.5 * (y1 + y2)
    w = x2 - x1
    h = y2 - y1
    w = w * expand_ratio
    h = h * expand_ratio
    box = [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h]
    if max_h is not None:
        box[1] = max(0, box[1])
        box[3] = min(max_h - 1, box[3])
    if max_w is not None:
        box[0] = max(0, box[0])
        box[2] = min(max_w - 1, box[2])
    return [int(b) for b in box]


def mask2box(mask: torch.Tensor):
    # use naive way
    row = torch.nonzero(mask.sum(dim=0))[:, 0]
    if len(row) == 0:
        return None
    x1 = row.min()
    x2 = row.max()
    col = np.nonzero(mask.sum(dim=1))[:, 0]
    y1 = col.min()
    y2 = col.max()
    return x1, y1, x2 + 1, y2 + 1


def crop_with_mask(
    image: torch.Tensor,
    mask: torch.Tensor,
    bbox: torch.Tensor,
    fill: Tuple[float, float, float] = (0, 0, 0),
    expand_ratio: float = 1.0,
):
    l, t, r, b = expand_box(*bbox, expand_ratio)
    _, h, w = image.shape
    l = max(l, 0)
    t = max(t, 0)
    r = min(r, w)
    b = min(b, h)
    new_image = torch.cat(
        [image.new_full((1, b - t, r - l), fill_value=val) for val in fill]
    )
    mask_bool = mask.bool()
    return image[:, t:b, l:r] * mask[None, t:b, l:r] + (~ mask_bool[None, t:b, l:r]) * new_image, mask[None, t:b, l:r]


def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True):
    rank = get_local_rank()
    if rank == 0:
        # download on rank 0 only
        model, _ = clip_load(model, mask_prompt_depth=mask_prompt_depth, device="cpu")
    synchronize()
    if rank != 0:
        model, _ = clip_load(model, mask_prompt_depth=mask_prompt_depth, device="cpu")
    synchronize()
    if frozen:
        for param in model.parameters():
            param.requires_grad = False
    return model
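A minimal sketch of crop_with_mask above on a toy mask (detectron2 must be importable since this module pulls it in; the image, mask, and fill values are made up):

import torch
from open_vocab_seg.modeling.clip_adapter.utils import crop_with_mask

image = torch.rand(3, 240, 320)
mask = torch.zeros(240, 320)
mask[60:180, 80:240] = 1.0            # toy binary foreground mask
bbox = (80, 60, 240, 180)             # (x1, y1, x2, y2) around the foreground
patch, patch_mask = crop_with_mask(image, mask, bbox, fill=(0.48, 0.46, 0.41))
print(patch.shape, patch_mask.shape)  # [3, 120, 160] and [1, 120, 160]; background pixels take the fill colour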
open_vocab_seg/modeling/criterion.py
ADDED
@@ -0,0 +1,229 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

"""
MaskFormer criterion.
"""
import torch
import torch.nn.functional as F
from torch import nn

from detectron2.utils.comm import get_world_size

from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list


def dice_loss(inputs, targets, num_masks):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks


def sigmoid_focal_loss(
    inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2
):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_masks


class SetCriterion(nn.Module):
    """This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        if eos_coef > 0:
            empty_weight = torch.ones(self.num_classes + 1)
            empty_weight[-1] = self.eos_coef
            self.register_buffer("empty_weight", empty_weight)
            self.use_ignore_idx = False
        else:
            self.use_ignore_idx = True
        self.cur_target = []

    def loss_labels(self, outputs, targets, indices, num_masks):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)]
        )
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device,
        )
        target_classes[idx] = target_classes_o
        if self.use_ignore_idx:
            loss_ce = F.cross_entropy(
                src_logits.transpose(1, 2),
                target_classes,
                ignore_index=self.num_classes,
            )
        else:
            if "empty_weight" in outputs:
                empty_weight = torch.cat(
                    [outputs["empty_weight"], self.empty_weight[-1:]]
                ).detach()
            else:
                empty_weight = self.empty_weight
            loss_ce = F.cross_entropy(
                src_logits.transpose(1, 2), target_classes, empty_weight
            )
        losses = {"loss_ce": loss_ce}
        return losses

    def loss_masks(self, outputs, targets, indices, num_masks):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        src_masks = src_masks[src_idx]
        masks = [t["masks"] for t in targets]
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]

        # upsample predictions to the target size
        src_masks = F.interpolate(
            src_masks[:, None],
            size=target_masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(src_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
            "loss_dice": dice_loss(src_masks, target_masks, num_masks),
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
        )
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
        )
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks):
        loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_masks)

    def forward(self, outputs, targets):
        """This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes accross all nodes, for normalization purposes
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    l_dict = self.get_loss(
                        loss, aux_outputs, targets, indices, num_masks
                    )
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses

    def clean_buffer(self):
        self.cur_target = []
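A minimal sketch of the two standalone loss functions above on dummy tensors (importing this module requires detectron2 and the repo's utils.misc; the shapes are illustrative):

import torch
from open_vocab_seg.modeling.criterion import dice_loss, sigmoid_focal_loss

pred = torch.randn(4, 64 * 64)                     # 4 predicted mask logits, flattened to one row per mask
target = (torch.rand(4, 64 * 64) > 0.5).float()    # matching binary ground-truth masks
num_masks = 4                                      # normaliser, as used inside SetCriterion.loss_masks
print(dice_loss(pred, target, num_masks))
print(sigmoid_focal_loss(pred, target, num_masks))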
open_vocab_seg/modeling/heads/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
open_vocab_seg/modeling/heads/mask_former_head.py
ADDED
@@ -0,0 +1,135 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import logging
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Union

import fvcore.nn.weight_init as weight_init
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from ..transformer.transformer_predictor import TransformerPredictor
from .pixel_decoder import build_pixel_decoder


@SEM_SEG_HEADS_REGISTRY.register()
class MaskFormerHead(nn.Module):

    _version = 2

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Do not warn if train from scratch
            scratch = True
            logger = logging.getLogger(__name__)
            for k in list(state_dict.keys()):
                newk = k
                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
                    newk = k.replace(prefix, prefix + "pixel_decoder.")
                    # logger.debug(f"{k} ==> {newk}")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
        """
        super().__init__()
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature

        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        return {
            "input_shape": {
                k: v
                for k, v in input_shape.items()
                if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
            },
            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
            "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
            "transformer_predictor": TransformerPredictor(
                cfg,
                cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
                if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
                else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
                mask_classification=True,
            ),
        }

    def forward(self, features):
        return self.layers(features)

    def layers(self, features):
        (
            mask_features,
            transformer_encoder_features,
        ) = self.pixel_decoder.forward_features(features)
        if self.transformer_in_feature == "transformer_encoder":
            assert (
                transformer_encoder_features is not None
            ), "Please use the TransformerEncoderPixelDecoder."
            predictions = self.predictor(transformer_encoder_features, mask_features)
        else:
            predictions = self.predictor(
                features[self.transformer_in_feature], mask_features
            )
        return predictions
open_vocab_seg/modeling/heads/open_vocab_mask_former_head.py
ADDED
@@ -0,0 +1,145 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
# Modified by Feng Liang from
# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/heads/zero_shot_mask_former_head.py

import logging
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Union

import fvcore.nn.weight_init as weight_init
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from ..transformer.open_vocab_transformer_predictor import OpenVocabTransformerPredictor
from .pixel_decoder import build_pixel_decoder


@SEM_SEG_HEADS_REGISTRY.register()
class OpenVocabMaskFormerHead(nn.Module):

    _version = 2

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Do not warn if train from scratch
            scratch = True
            logger = logging.getLogger(__name__)
            for k in list(state_dict.keys()):
                newk = k
                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
                    newk = k.replace(prefix, prefix + "pixel_decoder.")
                    # logger.debug(f"{k} ==> {newk}")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
        """
        super().__init__()
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature

        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        return {
            "input_shape": {
                k: v
                for k, v in input_shape.items()
                if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
            },
            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
            "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
            "transformer_predictor": OpenVocabTransformerPredictor(
                cfg,
                cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
                if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
                else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
                mask_classification=True,
            ),
        }

    def forward(self, features):
        return self.layers(features)
|
122 |
+
|
123 |
+
def layers(self, features):
|
124 |
+
(
|
125 |
+
mask_features,
|
126 |
+
transformer_encoder_features,
|
127 |
+
) = self.pixel_decoder.forward_features(features)
|
128 |
+
if self.transformer_in_feature == "transformer_encoder":
|
129 |
+
assert (
|
130 |
+
transformer_encoder_features is not None
|
131 |
+
), "Please use the TransformerEncoderPixelDecoder."
|
132 |
+
predictions = self.predictor(transformer_encoder_features, mask_features)
|
133 |
+
else:
|
134 |
+
predictions = self.predictor(
|
135 |
+
features[self.transformer_in_feature], mask_features
|
136 |
+
)
|
137 |
+
return predictions
|
138 |
+
|
139 |
+
def freeze_pretrained(self):
|
140 |
+
for name, module in self.named_children():
|
141 |
+
if name not in ["predictor"]:
|
142 |
+
for param in module.parameters():
|
143 |
+
param.requires_grad = False
|
144 |
+
else:
|
145 |
+
module.freeze_pretrained()
|
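The freeze_pretrained() method above leaves only the predictor's open-vocabulary class-embedding branch trainable, which is how this head can be fine-tuned on top of pretrained MaskFormer weights. A minimal usage sketch (not part of the original file), assuming `model` is an already-built OVSeg model that uses this head; the variable name is hypothetical:

# Hypothetical fine-tuning setup: freeze the pixel decoder and most of the
# predictor; only the class-embedding MLP keeps requires_grad=True.
model.sem_seg_head.freeze_pretrained()

trainable = [n for n, p in model.sem_seg_head.named_parameters() if p.requires_grad]
print(f"trainable tensors left in the head: {len(trainable)}")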
open_vocab_seg/modeling/heads/pixel_decoder.py
ADDED
@@ -0,0 +1,308 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import logging
from typing import Callable, Dict, List, Optional, Tuple, Union

import fvcore.nn.weight_init as weight_init
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from ..transformer.position_encoding import PositionEmbeddingSine
from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer


def build_pixel_decoder(cfg, input_shape):
    """
    Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`.
    """
    name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
    model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
    forward_features = getattr(model, "forward_features", None)
    if not callable(forward_features):
        raise ValueError(
            "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
            f"Please implement forward_features for {name} to only return mask features."
        )
    return model


@SEM_SEG_HEADS_REGISTRY.register()
class BasePixelDecoder(nn.Module):
    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        conv_dim: int,
        mask_dim: int,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            conv_dims: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): normalization for all conv layers
        """
        super().__init__()

        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        feature_channels = [v.channels for k, v in input_shape]

        lateral_convs = []
        output_convs = []

        use_bias = norm == ""
        for idx, in_channels in enumerate(feature_channels):
            if idx == len(self.in_features) - 1:
                output_norm = get_norm(norm, conv_dim)
                output_conv = Conv2d(
                    in_channels,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(None)
                output_convs.append(output_conv)
            else:
                lateral_norm = get_norm(norm, conv_dim)
                output_norm = get_norm(norm, conv_dim)

                lateral_conv = Conv2d(
                    in_channels,
                    conv_dim,
                    kernel_size=1,
                    bias=use_bias,
                    norm=lateral_norm,
                )
                output_conv = Conv2d(
                    conv_dim,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(lateral_conv)
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("adapter_{}".format(idx + 1), lateral_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(lateral_conv)
                output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

        self.mask_dim = mask_dim
        self.mask_features = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        weight_init.c2_xavier_fill(self.mask_features)

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        ret = {}
        ret["input_shape"] = {
            k: v
            for k, v in input_shape.items()
            if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
        }
        ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
        ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
        return ret

    def forward_features(self, features):
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                y = output_conv(x)
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
        return self.mask_features(y), None

    def forward(self, features, targets=None):
        logger = logging.getLogger(__name__)
        logger.warning(
            "Calling forward() may cause unpredicted behavior of PixelDecoder module."
        )
        return self.forward_features(features)


class TransformerEncoderOnly(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        if mask is not None:
            mask = mask.flatten(1)

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        return memory.permute(1, 2, 0).view(bs, c, h, w)


@SEM_SEG_HEADS_REGISTRY.register()
class TransformerEncoderPixelDecoder(BasePixelDecoder):
    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        transformer_dropout: float,
        transformer_nheads: int,
        transformer_dim_feedforward: int,
        transformer_enc_layers: int,
        transformer_pre_norm: bool,
        conv_dim: int,
        mask_dim: int,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            transformer_dropout: dropout probability in transformer
            transformer_nheads: number of heads in transformer
            transformer_dim_feedforward: dimension of feedforward network
            transformer_enc_layers: number of transformer encoder layers
            transformer_pre_norm: whether to use pre-layernorm or not
            conv_dims: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): normalization for all conv layers
        """
        super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm)

        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        in_channels = feature_channels[len(self.in_features) - 1]
        self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
        weight_init.c2_xavier_fill(self.input_proj)
        self.transformer = TransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            normalize_before=transformer_pre_norm,
        )
        N_steps = conv_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # update layer
        use_bias = norm == ""
        output_norm = get_norm(norm, conv_dim)
        output_conv = Conv2d(
            conv_dim,
            conv_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias,
            norm=output_norm,
            activation=F.relu,
        )
        weight_init.c2_xavier_fill(output_conv)
        delattr(self, "layer_{}".format(len(self.in_features)))
        self.add_module("layer_{}".format(len(self.in_features)), output_conv)
        self.output_convs[0] = output_conv

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        ret = super().from_config(cfg, input_shape)
        ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
        ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
        ret[
            "transformer_enc_layers"
        ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS  # a separate config
        ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        return ret

    def forward_features(self, features):
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                transformer = self.input_proj(x)
                pos = self.pe_layer(x)
                transformer = self.transformer(transformer, None, pos)
                y = output_conv(transformer)
                # save intermediate feature as input to Transformer decoder
                transformer_encoder_features = transformer
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
        return self.mask_features(y), transformer_encoder_features

    def forward(self, features, targets=None):
        logger = logging.getLogger(__name__)
        logger.warning(
            "Calling forward() may cause unpredicted behavior of PixelDecoder module."
        )
        return self.forward_features(features)
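As a quick sanity check of the FPN-style top-down path above, BasePixelDecoder can be built directly with keyword arguments (detectron2's @configurable also accepts explicit arguments). This is only a shape-level sketch; the channel counts and spatial sizes are illustrative, not taken from any OVSeg config:

import torch
from detectron2.layers import ShapeSpec

# Illustrative backbone feature pyramid; not from a real config.
shapes = {
    "res2": ShapeSpec(channels=96, stride=4),
    "res3": ShapeSpec(channels=192, stride=8),
    "res4": ShapeSpec(channels=384, stride=16),
    "res5": ShapeSpec(channels=768, stride=32),
}
decoder = BasePixelDecoder(shapes, conv_dim=256, mask_dim=256, norm="GN")

features = {
    "res2": torch.randn(1, 96, 64, 64),
    "res3": torch.randn(1, 192, 32, 32),
    "res4": torch.randn(1, 384, 16, 16),
    "res5": torch.randn(1, 768, 8, 8),
}
mask_features, encoder_features = decoder.forward_features(features)
# mask_features: [1, 256, 64, 64] at stride 4; encoder_features is None for the base decoder.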
open_vocab_seg/modeling/matcher.py
ADDED
@@ -0,0 +1,187 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch import nn


def batch_dice_loss(inputs, targets):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss


def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    hw = inputs.shape[1]

    prob = inputs.sigmoid()
    focal_pos = ((1 - prob) ** gamma) * F.binary_cross_entropy_with_logits(
        inputs, torch.ones_like(inputs), reduction="none"
    )
    focal_neg = (prob ** gamma) * F.binary_cross_entropy_with_logits(
        inputs, torch.zeros_like(inputs), reduction="none"
    )
    if alpha >= 0:
        focal_pos = focal_pos * alpha
        focal_neg = focal_neg * (1 - alpha)

    loss = torch.einsum("nc,mc->nm", focal_pos, targets) + torch.einsum(
        "nc,mc->nm", focal_neg, (1 - targets)
    )

    return loss / hw


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(
        self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1
    ):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
            cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice
        assert (
            cost_class != 0 or cost_mask != 0 or cost_dice != 0
        ), "all costs cant be 0"

    @torch.no_grad()
    def memory_efficient_forward(self, outputs, targets):
        """More memory-friendly matching"""
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # Work out the mask padding size
        masks = [v["masks"] for v in targets]
        h_max = max([m.shape[1] for m in masks])
        w_max = max([m.shape[2] for m in masks])

        indices = []

        # Iterate through batch size
        for b in range(bs):

            out_prob = outputs["pred_logits"][b].softmax(
                -1
            )  # [num_queries, num_classes]
            out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]

            tgt_ids = targets[b]["labels"]
            # gt masks are already padded when preparing target
            tgt_mask = targets[b]["masks"].to(out_mask)

            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be ommitted.
            cost_class = -out_prob[:, tgt_ids]

            # Downsample gt masks to save memory
            tgt_mask = F.interpolate(
                tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest"
            )

            # Flatten spatial dimension
            out_mask = out_mask.flatten(1)  # [batch_size * num_queries, H*W]
            tgt_mask = tgt_mask[:, 0].flatten(1)  # [num_total_targets, H*W]

            # Compute the focal loss between masks
            cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask)

            # Compute the dice loss betwen masks
            cost_dice = batch_dice_loss(out_mask, tgt_mask)

            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()

            indices.append(linear_sum_assignment(C))
        return [
            (
                torch.as_tensor(i, dtype=torch.int64),
                torch.as_tensor(j, dtype=torch.int64),
            )
            for i, j in indices
        ]

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        return self.memory_efficient_forward(outputs, targets)

    def __repr__(self):
        head = "Matcher " + self.__class__.__name__
        body = [
            "cost_class: {}".format(self.cost_class),
            "cost_mask: {}".format(self.cost_mask),
            "cost_dice: {}".format(self.cost_dice),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
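A small, self-contained sketch of how the matcher is driven; the cost weights and tensor sizes below are illustrative only and are not taken from the OVSeg configs:

import torch

matcher = HungarianMatcher(cost_class=1.0, cost_mask=20.0, cost_dice=1.0)  # illustrative weights

outputs = {
    "pred_logits": torch.randn(2, 100, 172),    # [batch, queries, num_classes + 1]
    "pred_masks": torch.randn(2, 100, 32, 32),  # [batch, queries, H_pred, W_pred]
}
targets = [
    {"labels": torch.tensor([3, 17]), "masks": torch.rand(2, 128, 128) > 0.5},
    {"labels": torch.tensor([5]), "masks": torch.rand(1, 128, 128) > 0.5},
]
indices = matcher(outputs, targets)
# indices[b] is a (query_idx, target_idx) pair of int64 tensors for image b;
# ground-truth masks are downsampled to the prediction resolution before costing.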
open_vocab_seg/modeling/transformer/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
open_vocab_seg/modeling/transformer/open_vocab_transformer_predictor.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

from torch import nn
from detectron2.config import configurable
from .transformer_predictor import TransformerPredictor, MLP


class OpenVocabTransformerPredictor(TransformerPredictor):
    @configurable
    def __init__(
        self,
        in_channels,
        mask_classification=True,
        *,
        embedding_dim: int,
        embed_hidden_dim: int,
        embed_layers: int,
        hidden_dim: int,
        num_queries: int,
        nheads: int,
        dropout: float,
        dim_feedforward: int,
        enc_layers: int,
        dec_layers: int,
        pre_norm: bool,
        deep_supervision: bool,
        mask_dim: int,
        enforce_input_project: bool,
    ):
        super().__init__(
            in_channels,
            False,
            num_classes=embedding_dim,
            hidden_dim=hidden_dim,
            num_queries=num_queries,
            nheads=nheads,
            dropout=dropout,
            dim_feedforward=dim_feedforward,
            enc_layers=enc_layers,
            dec_layers=dec_layers,
            pre_norm=pre_norm,
            deep_supervision=deep_supervision,
            mask_dim=mask_dim,
            enforce_input_project=enforce_input_project,
        )
        self.mask_classification = mask_classification
        # output FFNs
        if self.mask_classification:
            self.class_embed = MLP(
                hidden_dim, embed_hidden_dim, embedding_dim, embed_layers
            )

    def freeze_pretrained(self):
        for name, module in self.named_children():
            if name not in ["class_embed"]:
                for param in module.parameters():
                    param.requires_grad = False

    @classmethod
    def from_config(cls, cfg, in_channels, mask_classification):
        ret = {}
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        ret["embedding_dim"] = cfg.MODEL.SEM_SEG_HEAD.EMBEDDING_DIM
        ret["embed_hidden_dim"] = cfg.MODEL.SEM_SEG_HEAD.EMBED_HIDDEN_DIM
        ret["embed_layers"] = cfg.MODEL.SEM_SEG_HEAD.EMBED_LAYERS
        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        # Transformer parameters:
        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
        ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ

        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM

        return ret
open_vocab_seg/modeling/transformer/position_encoding.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

"""
Various positional encodings for the transformer.
"""
import math

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros(
                (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
            )
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
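A brief shape-level sketch, with illustrative sizes (the decoders above set num_pos_feats to conv_dim // 2, so the concatenated x/y encoding matches the feature channels):

import torch

pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
feat = torch.randn(2, 256, 32, 48)        # [batch, channels, H, W]
pos = pe(feat)
assert pos.shape == (2, 256, 32, 48)      # 2 * num_pos_feats channels, same spatial size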
open_vocab_seg/modeling/transformer/transformer.py
ADDED
@@ -0,0 +1,380 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

"""
Transformer class.

Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class Transformer(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        if mask is not None:
            mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(
            tgt,
            memory,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed,
        )
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
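A minimal shape-level sketch of the DETR-style interface above; all sizes are illustrative and not tied to a particular OVSeg config:

import torch

model = Transformer(
    d_model=256,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    return_intermediate_dec=True,
)
src = torch.randn(1, 256, 16, 16)      # feature map, flattened internally to HWxNxC
pos = torch.randn(1, 256, 16, 16)      # e.g. output of PositionEmbeddingSine
queries = torch.randn(100, 256)        # learned query embeddings (nn.Embedding.weight)
hs, memory = model(src, None, queries, pos)
# hs: [num_decoder_layers, 1, 100, 256]; memory: [1, 256, 16, 16]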
open_vocab_seg/modeling/transformer/transformer_predictor.py
ADDED
@@ -0,0 +1,179 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d

from .position_encoding import PositionEmbeddingSine
from .transformer import Transformer


class TransformerPredictor(nn.Module):
    @configurable
    def __init__(
        self,
        in_channels,
        mask_classification=True,
        *,
        num_classes: int,
        hidden_dim: int,
        num_queries: int,
        nheads: int,
        dropout: float,
        dim_feedforward: int,
        enc_layers: int,
        dec_layers: int,
        pre_norm: bool,
        deep_supervision: bool,
        mask_dim: int,
        enforce_input_project: bool,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
            num_classes: number of classes
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dropout: dropout in Transformer
            dim_feedforward: feature dimension in feedforward network
            enc_layers: number of Transformer encoder layers
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            deep_supervision: whether to add supervision to every decoder layers
            mask_dim: mask feature dimension
            enforce_input_project: add input project 1x1 conv even if input
                channels and hidden dim is identical
        """
        super().__init__()

        self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        transformer = Transformer(
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            normalize_before=pre_norm,
            return_intermediate_dec=deep_supervision,
        )

        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model

        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1)
            weight_init.c2_xavier_fill(self.input_proj)
        else:
            self.input_proj = nn.Sequential()
        self.aux_loss = deep_supervision

        # output FFNs
        if self.mask_classification:
            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    @classmethod
    def from_config(cls, cfg, in_channels, mask_classification):
        ret = {}
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        # Transformer parameters:
        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
        ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ

        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM

        return ret

    def forward(self, x, mask_features):
        pos = self.pe_layer(x)

        src = x
        mask = None
        hs, memory = self.transformer(
            self.input_proj(src), mask, self.query_embed.weight, pos
        )

        if self.mask_classification:
            outputs_class = self.class_embed(hs)
            out = {"pred_logits": outputs_class[-1]}
        else:
            out = {}

        if self.aux_loss:
            # [l, bs, queries, embed]
            mask_embed = self.mask_embed(hs)
            outputs_seg_masks = torch.einsum(
                "lbqc,bchw->lbqhw", mask_embed, mask_features
            )
            out["pred_masks"] = outputs_seg_masks[-1]
            out["aux_outputs"] = self._set_aux_loss(
                outputs_class if self.mask_classification else None, outputs_seg_masks
            )
        else:
            # FIXME h_boxes takes the last one computed, keep this in mind
            # [bs, queries, embed]
            mask_embed = self.mask_embed(hs[-1])
            outputs_seg_masks = torch.einsum(
                "bqc,bchw->bqhw", mask_embed, mask_features
            )
            out["pred_masks"] = outputs_seg_masks
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        if self.mask_classification:
            return [
                {"pred_logits": a, "pred_masks": b}
                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
            ]
        else:
            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
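A minimal sketch exercising the predictor end-to-end with explicit keyword arguments (detectron2's @configurable also accepts them); every value below is illustrative, not taken from an OVSeg config:

import torch

predictor = TransformerPredictor(
    in_channels=256,
    mask_classification=True,
    num_classes=133,          # illustrative class count
    hidden_dim=256,
    num_queries=100,
    nheads=8,
    dropout=0.1,
    dim_feedforward=2048,
    enc_layers=0,             # encoder can be skipped; decoder-only predictor
    dec_layers=6,
    pre_norm=False,
    deep_supervision=True,
    mask_dim=256,
    enforce_input_project=False,
)
x = torch.randn(1, 256, 16, 16)              # transformer input feature (e.g. res5 or encoder output)
mask_features = torch.randn(1, 256, 64, 64)  # stride-4 features from the pixel decoder
out = predictor(x, mask_features)
# out["pred_logits"]: [1, 100, 134]; out["pred_masks"]: [1, 100, 64, 64]; plus "aux_outputs".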
open_vocab_seg/ovseg_model.py
ADDED
@@ -0,0 +1,460 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
# Modified by Feng Liang from
# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/zero_shot_mask_former_model.py

import logging
from typing import Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.structures import ImageList
from detectron2.utils.logger import log_first_n
from .modeling.clip_adapter import (
    ClipAdapter,
    MaskFormerClipAdapter,
    build_text_prompt,
)
from .mask_former_model import MaskFormer
from .utils.misc import get_gt_binary_masks


@META_ARCH_REGISTRY.register()
class OVSeg(MaskFormer):
    """
    Main class for zero-shot mask classification semantic segmentation architectures.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        clip_adapter: nn.Module,
        criterion: nn.Module,
        num_queries: int,
        panoptic_on: bool,
        object_mask_threshold: float,
        overlap_threshold: float,
        metadata,
        size_divisibility: int,
        sem_seg_postprocess_before_inference: bool,
        clip_ensemble: bool,
        clip_ensemble_weight: float,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            criterion: a module that defines the loss
            clip_adapter: adapter for clip-based mask classification
            num_queries: int, number of queries
            panoptic_on: bool, whether to output panoptic segmentation prediction
            object_mask_threshold: float, threshold to filter query based on classification score
                for panoptic segmentation inference
            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
                segmentation inference
            size_divisibility: Some backbones require the input height and width to be divisible by a
                specific integer. We can use this to override such a requirement.
            sem_seg_postprocess_before_inference: whether to resize the prediction back
                to the original input size before semantic segmentation inference or after.
                For high-resolution datasets like Mapillary, resizing predictions before
                inference will cause an OOM error.
            pixel_mean, pixel_std: list or tuple with #channels elements, representing
                the per-channel mean and std to be used to normalize the input image
        """
        super().__init__(
            backbone=backbone,
            sem_seg_head=sem_seg_head,
            criterion=criterion,
            num_queries=num_queries,
            panoptic_on=panoptic_on,
            object_mask_threshold=object_mask_threshold,
            overlap_threshold=overlap_threshold,
            metadata=metadata,
            size_divisibility=size_divisibility,
            sem_seg_postprocess_before_inference=sem_seg_postprocess_before_inference,
            pixel_mean=pixel_mean,
            pixel_std=pixel_std,
        )
        self.clip_adapter: ClipAdapter = clip_adapter

        self.clip_ensemble: bool = clip_ensemble
        self.clip_ensemble_weight: float = clip_ensemble_weight

    @classmethod
    def from_config(cls, cfg):
        init_kwargs = MaskFormer.from_config(cfg)
        text_templates = build_text_prompt(cfg.MODEL.CLIP_ADAPTER)

        clip_adapter = MaskFormerClipAdapter(
            cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME,
            text_templates,
            mask_fill=cfg.MODEL.CLIP_ADAPTER.MASK_FILL,
            mask_expand_ratio=cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO,
            mask_thr=cfg.MODEL.CLIP_ADAPTER.MASK_THR,
            mask_matting=cfg.MODEL.CLIP_ADAPTER.MASK_MATTING,
            region_resized=cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED,
            mask_prompt_depth=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH,
            mask_prompt_fwd=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD,
        )
        init_kwargs["clip_adapter"] = clip_adapter
        init_kwargs["clip_ensemble"] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE
        init_kwargs[
            "clip_ensemble_weight"
        ] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT

        return init_kwargs

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "instances": per-region ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.
        Returns:
            list[dict]:
                each dict has the results for one image. The dict contains the following keys:

                * "sem_seg":
                    A Tensor that represents the
                    per-pixel segmentation predicted by the head.
                    The prediction has shape KxHxW that represents the logits of
                    each class for each pixel.
                * "panoptic_seg":
                    A tuple that represents the panoptic output
                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
                        Each dict contains keys "id", "category_id", "isthing".
        """
        dataset_name = [x["meta"]["dataset_name"] for x in batched_inputs]
        assert len(set(dataset_name)) == 1
        dataset_name = dataset_name[0]

        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features)
        class_names = self.get_class_name_list(dataset_name)
        text_features = self.clip_adapter.get_text_features(class_names)
        outputs["pred_logits"] = self.clip_adapter.get_sim_logits(
            text_features, self.clip_adapter.normalize_feature(outputs["pred_logits"])
        )
        if self.training:
            if "aux_outputs" in outputs.keys():
                for i in range(len(outputs["aux_outputs"])):
                    outputs["aux_outputs"][i][
                        "pred_logits"
                    ] = self.clip_adapter.get_sim_logits(
                        text_features,
                        self.clip_adapter.normalize_feature(
                            outputs["aux_outputs"][i]["pred_logits"]
                        ),
                    )
            # mask classification target
            if "instances" in batched_inputs[0]:
                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
                targets = self.prepare_targets(gt_instances, images)
            else:
                targets = None

            # bipartite matching-based loss
            losses = self.criterion(outputs, targets)

            for k in list(losses.keys()):
                if k in self.criterion.weight_dict:
                    losses[k] *= self.criterion.weight_dict[k]
                else:
                    # remove this loss if not specified in `weight_dict`
                    losses.pop(k)

            return losses
        else:
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]
            # upsample masks
            mask_pred_results = F.interpolate(
                mask_pred_results,
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )

            processed_results = []
            for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
                mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
            ):
                height = image_size[0]
                width = image_size[1]
                mask_pred_result = sem_seg_postprocess(
                    mask_pred_result, image_size, height, width
                )
                image = input_per_image["image"].to(self.device)

                r, regions = self.semantic_inference(
                    mask_cls_result, mask_pred_result, image, class_names
                )

                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = sem_seg_postprocess(r, image_size, height, width)
                processed_results.append({"sem_seg": r})

                # panoptic segmentation inference
                if self.panoptic_on:
                    panoptic_r = self.panoptic_inference(
                        mask_cls_result, mask_pred_result
                    )
                    processed_results[-1]["panoptic_seg"] = panoptic_r

            return processed_results

    def semantic_inference(self, mask_cls, mask_pred, image, class_names):
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()

        regions = None
        if self.clip_ensemble:
            clip_cls, regions, valid_flag = self.clip_adapter(
                image, class_names, mask_pred, normalize=True
            )
            if clip_cls is None:
                clip_cls = torch.empty(0, mask_cls.shape[-1] + 1, device=self.device)
            # softmax before index or after?
            clip_cls = F.softmax(clip_cls[:, :-1], dim=-1)
            if self.clip_ensemble_weight > 0:
                map_back_clip_cls = mask_cls.new_ones(mask_cls.shape)
                map_back_clip_cls[valid_flag] = clip_cls
                mask_cls = torch.pow(mask_cls, 1 - self.clip_ensemble_weight) * \
                    torch.pow(map_back_clip_cls, self.clip_ensemble_weight)
            else:
                # only clip model predictions are used
                mask_cls = clip_cls
                mask_pred = mask_pred[valid_flag]
        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
        return semseg, regions

    def get_class_name_list(self, dataset_name):
        class_names = [
            c.strip() for c in MetadataCatalog.get(dataset_name).stuff_classes
        ]
        return class_names


@META_ARCH_REGISTRY.register()
class OVSegDEMO(MaskFormer):
    """
    Main class for zero-shot mask classification semantic segmentation architectures.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        clip_adapter: nn.Module,
        criterion: nn.Module,
        num_queries: int,
        panoptic_on: bool,
        object_mask_threshold: float,
        overlap_threshold: float,
        metadata,
        size_divisibility: int,
        sem_seg_postprocess_before_inference: bool,
        clip_ensemble: bool,
        clip_ensemble_weight: float,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            criterion: a module that defines the loss
            clip_adapter: adapter for clip-based mask classification
            num_queries: int, number of queries
            panoptic_on: bool, whether to output panoptic segmentation prediction
            object_mask_threshold: float, threshold to filter query based on classification score
                for panoptic segmentation inference
            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
                segmentation inference
            size_divisibility: Some backbones require the input height and width to be divisible by a
                specific integer. We can use this to override such a requirement.
            sem_seg_postprocess_before_inference: whether to resize the prediction back
                to the original input size before semantic segmentation inference or after.
                For high-resolution datasets like Mapillary, resizing predictions before
                inference will cause an OOM error.
            pixel_mean, pixel_std: list or tuple with #channels elements, representing
                the per-channel mean and std to be used to normalize the input image
        """
        super().__init__(
            backbone=backbone,
            sem_seg_head=sem_seg_head,
            criterion=criterion,
            num_queries=num_queries,
            panoptic_on=panoptic_on,
            object_mask_threshold=object_mask_threshold,
            overlap_threshold=overlap_threshold,
            metadata=metadata,
            size_divisibility=size_divisibility,
            sem_seg_postprocess_before_inference=sem_seg_postprocess_before_inference,
            pixel_mean=pixel_mean,
            pixel_std=pixel_std,
        )
        self.clip_adapter: ClipAdapter = clip_adapter

        self.clip_ensemble: bool = clip_ensemble
        self.clip_ensemble_weight: float = clip_ensemble_weight

    @classmethod
    def from_config(cls, cfg):
        init_kwargs = MaskFormer.from_config(cfg)
        text_templates = build_text_prompt(cfg.MODEL.CLIP_ADAPTER)

        clip_adapter = MaskFormerClipAdapter(
            cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME,
            text_templates,
            mask_fill=cfg.MODEL.CLIP_ADAPTER.MASK_FILL,
            mask_expand_ratio=cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO,
            mask_thr=cfg.MODEL.CLIP_ADAPTER.MASK_THR,
            mask_matting=cfg.MODEL.CLIP_ADAPTER.MASK_MATTING,
            region_resized=cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED,
            mask_prompt_depth=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH,
            mask_prompt_fwd=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD,
        )
        init_kwargs["clip_adapter"] = clip_adapter
        init_kwargs["clip_ensemble"] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE
        init_kwargs[
            "clip_ensemble_weight"
        ] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT

        return init_kwargs

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "instances": per-region ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.
        Returns:
            list[dict]:
                each dict has the results for one image. The dict contains the following keys:

                * "sem_seg":
                    A Tensor that represents the
                    per-pixel segmentation predicted by the head.
                    The prediction has shape KxHxW that represents the logits of
                    each class for each pixel.
                * "panoptic_seg":
                    A tuple that represents the panoptic output
                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
                        Each dict contains keys "id", "category_id", "isthing".
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features)
        class_names = batched_inputs[0]["class_names"]
        if len(class_names) == 1:
            # Because classification is performed in a 'contrastive' manner, add 'others' to represent other concepts
            class_names.append('others')
        text_features = self.clip_adapter.get_text_features(class_names)
        outputs["pred_logits"] = self.clip_adapter.get_sim_logits(
            text_features, self.clip_adapter.normalize_feature(outputs["pred_logits"])
        )
        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
        # upsample masks
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
            mode="bilinear",
            align_corners=False,
        )

        processed_results = []
        for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
            mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
        ):
            height = image_size[0]
            width = image_size[1]
            mask_pred_result = sem_seg_postprocess(
                mask_pred_result, image_size, height, width
            )
            image = input_per_image["image"].to(self.device)

            r, regions = self.demo_inference(mask_cls_result, mask_pred_result, image, class_names)

            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = sem_seg_postprocess(r, image_size, height, width)
            processed_results.append({"sem_seg": r})

        return processed_results

    def demo_inference(self, mask_cls, mask_pred, image, class_names):
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()

        regions = None
        if self.clip_ensemble:
            clip_cls, regions, valid_flag = self.clip_adapter(
                image, class_names, mask_pred, normalize=True
            )
            if clip_cls is None:
                clip_cls = torch.empty(0, mask_cls.shape[-1] + 1, device=self.device)
            # softmax before index or after?
            clip_cls = F.softmax(clip_cls[:, :-1], dim=-1)
            if self.clip_ensemble_weight > 0:
                map_back_clip_cls = mask_cls.new_ones(mask_cls.shape)
                map_back_clip_cls[valid_flag] = clip_cls
                mask_cls = torch.pow(mask_cls, 1 - self.clip_ensemble_weight) * \
                    torch.pow(map_back_clip_cls, self.clip_ensemble_weight)
            else:
                # only clip model predictions are used
                mask_cls = clip_cls
                mask_pred = mask_pred[valid_flag]
        bin_mask = mask_pred > self.clip_adapter.mask_thr
        select_cls = torch.zeros(sum(valid_flag), mask_cls.shape[-1], device=self.device)
        select_mask = torch.argmax(mask_cls, dim=0)
        if len(class_names) == 2 and class_names[-1] == 'others':
            select_mask = select_mask[:-1]
        for idx in select_mask:
            select_cls[idx] = mask_cls[idx]
        semseg = torch.einsum("qc,qhw->chw", select_cls, bin_mask.float())
        return semseg, regions
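Both semantic_inference and demo_inference above fuse the segmentation head's class scores with CLIP's scores through a geometric weighting, mask_cls ** (1 - w) * clip_cls ** w, applied only to the queries for which the CLIP adapter returned a valid crop. A standalone sketch of that fusion step; the query count, class count and weight are illustrative assumptions, not the repo's defaults:

# Standalone sketch (not part of the repo) of the geometric ensemble used above.
import torch

def ensemble_scores(mask_cls, clip_cls, valid_flag, weight=0.7):
    # mask_cls:   (num_queries, num_classes) softmax scores from the segmentation head
    # clip_cls:   (num_valid, num_classes) softmax scores from CLIP, one row per valid mask
    # valid_flag: (num_queries,) bool, True where CLIP produced a prediction
    map_back = mask_cls.new_ones(mask_cls.shape)
    map_back[valid_flag] = clip_cls
    # weight = 0 keeps only the head's scores, weight = 1 keeps only CLIP's
    return mask_cls.pow(1 - weight) * map_back.pow(weight)

mask_cls = torch.softmax(torch.randn(100, 171), dim=-1)   # illustrative sizes
valid_flag = torch.rand(100) > 0.5
clip_cls = torch.softmax(torch.randn(int(valid_flag.sum()), 171), dim=-1)
fused = ensemble_scores(mask_cls, clip_cls, valid_flag)
print(fused.shape)  # torch.Size([100, 171])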
open_vocab_seg/test_time_augmentation.py
ADDED
@@ -0,0 +1,217 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import copy
from itertools import count
import math
import numpy as np
import torch
from fvcore.transforms import HFlipTransform
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from detectron2.data.detection_utils import read_image
from detectron2.modeling import DatasetMapperTTA
from detectron2.modeling.postprocessing import sem_seg_postprocess
import logging
from detectron2.utils.logger import log_every_n, log_first_n

__all__ = [
    "SemanticSegmentorWithTTA",
]


class SemanticSegmentorWithTTA(nn.Module):
    """
    A SemanticSegmentor with test-time augmentation enabled.
    Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.
    """

    def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
        """
        Args:
            cfg (CfgNode):
            model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`.
            batch_size (int): batch the augmented images into this batch size for inference.
        """
        super().__init__()
        if isinstance(model, DistributedDataParallel):
            model = model.module
        self.cfg = cfg.clone()

        self.model = model

        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size

    def _inference_with_model(self, inputs):
        if self.cfg.TEST.SLIDING_WINDOW:
            log_first_n(logging.INFO, "Using sliding window to test")

            outputs = []

            for input in inputs:
                image_size = input["image"].shape[1:]  # h, w
                if self.cfg.TEST.SLIDING_TILE_SIZE > 0:
                    tile_size = (
                        self.cfg.TEST.SLIDING_TILE_SIZE,
                        self.cfg.TEST.SLIDING_TILE_SIZE,
                    )
                else:
                    selected_mapping = {256: 224, 512: 256, 768: 512, 896: 512}
                    tile_size = min(image_size)
                    tile_size = selected_mapping[tile_size]
                    tile_size = (tile_size, tile_size)
                extra_info = {
                    k: v
                    for k, v in input.items()
                    if k not in ["image", "height", "width"]
                }
                log_every_n(
                    logging.INFO, "split {} to {}".format(image_size, tile_size)
                )
                overlap = self.cfg.TEST.SLIDING_OVERLAP
                stride = math.ceil(tile_size[0] * (1 - overlap))
                tile_rows = int(
                    math.ceil((image_size[0] - tile_size[0]) / stride) + 1
                )  # strided convolution formula
                tile_cols = int(math.ceil((image_size[1] - tile_size[1]) / stride) + 1)
                full_probs = None
                count_predictions = None
                tile_counter = 0

                for row in range(tile_rows):
                    for col in range(tile_cols):
                        x1 = int(col * stride)
                        y1 = int(row * stride)
                        x2 = min(x1 + tile_size[1], image_size[1])
                        y2 = min(y1 + tile_size[0], image_size[0])
                        x1 = max(
                            int(x2 - tile_size[1]), 0
                        )  # for portrait images the x1 underflows sometimes
                        y1 = max(
                            int(y2 - tile_size[0]), 0
                        )  # for very few rows y1 underflows

                        img = input["image"][:, y1:y2, x1:x2]
                        padded_img = nn.functional.pad(
                            img,
                            (
                                0,
                                tile_size[1] - img.shape[-1],
                                0,
                                tile_size[0] - img.shape[-2],
                            ),
                        )
                        tile_counter += 1
                        padded_input = {"image": padded_img}
                        padded_input.update(extra_info)
                        padded_prediction = self.model([padded_input])[0]["sem_seg"]
                        prediction = padded_prediction[
                            :, 0 : img.shape[1], 0 : img.shape[2]
                        ]
                        if full_probs is None:
                            full_probs = prediction.new_zeros(
                                prediction.shape[0], image_size[0], image_size[1]
                            )
                        if count_predictions is None:
                            count_predictions = prediction.new_zeros(
                                prediction.shape[0], image_size[0], image_size[1]
                            )
                        count_predictions[:, y1:y2, x1:x2] += 1
                        full_probs[
                            :, y1:y2, x1:x2
                        ] += prediction  # accumulate the predictions also in the overlapping regions

                full_probs /= count_predictions
                full_probs = sem_seg_postprocess(
                    full_probs,
                    image_size,
                    input.get("height", image_size[0]),
                    input.get("width", image_size[1]),
                )
                outputs.append({"sem_seg": full_probs})

            return outputs
        else:
            log_first_n(logging.INFO, "Using whole image to test")
            return self.model(inputs)

    def _batch_inference(self, batched_inputs):
        """
        Execute inference on a list of inputs,
        using batch size = self.batch_size, instead of the length of the list.
        Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward`
        """
        outputs = []
        inputs = []
        for idx, input in zip(count(), batched_inputs):
            inputs.append(input)
            if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1:
                with torch.no_grad():
                    outputs.extend(self._inference_with_model(inputs))
                inputs = []
        return outputs

    def __call__(self, batched_inputs):
        """
        Same input/output format as :meth:`SemanticSegmentor.forward`
        """

        def _maybe_read_image(dataset_dict):
            ret = copy.copy(dataset_dict)
            if "image" not in ret:
                image = read_image(ret.pop("file_name"), self.model.input_format)
                image = torch.from_numpy(
                    np.ascontiguousarray(image.transpose(2, 0, 1))
                )  # CHW
                ret["image"] = image
            if "height" not in ret and "width" not in ret:
                ret["height"] = image.shape[1]
                ret["width"] = image.shape[2]
            return ret

        return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs]

    def _inference_one_image(self, input):
        """
        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor
        Returns:
            dict: one output dict
        """
        augmented_inputs, tfms = self._get_augmented_inputs(input)
        # 1: forward with all augmented images
        outputs = self._batch_inference(augmented_inputs)
        # Delete variables that are no longer needed to avoid running out of memory
        del augmented_inputs
        # 2: merge the results
        # handle flip specially
        # outputs = [output.detach() for output in outputs]
        return self._merge_auged_output(outputs, tfms)

    def _merge_auged_output(self, outputs, tfms):
        new_outputs = []
        for output, tfm in zip(outputs, tfms):
            if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
                new_outputs.append(output["sem_seg"].flip(dims=[2]))
            else:
                new_outputs.append(output["sem_seg"])
        del outputs
        # to avoid OOM with torch.stack
        final_predictions = new_outputs[0]
        for i in range(1, len(new_outputs)):
            final_predictions += new_outputs[i]
        final_predictions = final_predictions / len(new_outputs)
        del new_outputs
        return {"sem_seg": final_predictions}

    def _get_augmented_inputs(self, input):
        augmented_inputs = self.tta_mapper(input)
        tfms = [x.pop("transforms") for x in augmented_inputs]
        return augmented_inputs, tfms
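The sliding-window branch of _inference_with_model above tiles the image with overlapping crops, runs the model on each padded tile, and divides the accumulated probabilities by the per-pixel visit count; the grid size follows the strided-convolution formula. A small sketch of just that tiling arithmetic, with an illustrative image size, tile size and overlap rather than config values:

# Sketch (not part of the repo) of the tile-grid arithmetic used in the sliding-window branch.
import math

def tile_grid(image_size, tile_size, overlap):
    # stride shrinks as overlap grows; e.g. overlap=0.5 means half-tile steps
    stride = math.ceil(tile_size[0] * (1 - overlap))
    tile_rows = int(math.ceil((image_size[0] - tile_size[0]) / stride) + 1)
    tile_cols = int(math.ceil((image_size[1] - tile_size[1]) / stride) + 1)
    boxes = []
    for row in range(tile_rows):
        for col in range(tile_cols):
            x1, y1 = col * stride, row * stride
            x2 = min(x1 + tile_size[1], image_size[1])
            y2 = min(y1 + tile_size[0], image_size[0])
            # clamp back so edge tiles keep the full tile size
            boxes.append((max(x2 - tile_size[1], 0), max(y2 - tile_size[0], 0), x2, y2))
    return boxes

print(len(tile_grid((512, 768), (256, 256), overlap=0.5)))  # 3 rows x 5 cols = 15 tiles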