Spaces: Runtime error

Commit a06fad0 by Qihang Yu
Parent(s): f6d10ab

Add kMaX-DeepLab
- app.py +71 -4
- configs/coco/panoptic-segmentation/kmax_convnext_base.yaml +13 -0
- configs/coco/panoptic-segmentation/kmax_convnext_large.yaml +13 -0
- configs/coco/panoptic-segmentation/kmax_convnext_small.yaml +13 -0
- configs/coco/panoptic-segmentation/kmax_convnext_tiny.yaml +13 -0
- configs/coco/panoptic-segmentation/kmax_r50.yaml +91 -0
- convert-pretrained-model-to-d2.py +36 -0
- convert-tf-weights-to-d2.py +400 -0
- demo/demo.ipynb +213 -0
- demo/demo.py +156 -0
- demo/predictor.py +166 -0
- docs/clustering_view_of_mask_transformer.png +0 -0
- docs/kmax_decoder.png +0 -0
- kmax_deeplab/__init__.py +15 -0
- kmax_deeplab/config.py +96 -0
- kmax_deeplab/data/__init__.py +1 -0
- kmax_deeplab/data/dataset_mappers/__init__.py +0 -0
- kmax_deeplab/data/dataset_mappers/coco_panoptic_kmaxdeeplab_dataset_mapper.py +326 -0
- kmax_deeplab/data/datasets/__init__.py +3 -0
- kmax_deeplab/data/datasets/register_coco_panoptic_annos_semseg.py +182 -0
- kmax_deeplab/evaluation/__init__.py +0 -0
- kmax_deeplab/evaluation/instance_evaluation.py +107 -0
- kmax_deeplab/evaluation/panoptic_evaluation.py +269 -0
- kmax_deeplab/kmax_model.py +446 -0
- kmax_deeplab/modeling/__init__.py +4 -0
- kmax_deeplab/modeling/backbone/__init__.py +0 -0
- kmax_deeplab/modeling/backbone/convnext.py +210 -0
- kmax_deeplab/modeling/backbone/resnet.py +697 -0
- kmax_deeplab/modeling/criterion.py +432 -0
- kmax_deeplab/modeling/matcher.py +128 -0
- kmax_deeplab/modeling/meta_arch/__init__.py +0 -0
- kmax_deeplab/modeling/meta_arch/kmax_deeplab_head.py +88 -0
- kmax_deeplab/modeling/pixel_decoder/__init__.py +0 -0
- kmax_deeplab/modeling/pixel_decoder/kmax_pixel_decoder.py +370 -0
- kmax_deeplab/modeling/transformer_decoder/__init__.py +1 -0
- kmax_deeplab/modeling/transformer_decoder/kmax_transformer_decoder.py +453 -0
- pakages.txt +4 -0
- requirements.txt +34 -0
- train_net.py +266 -0
- train_net_utils.py +225 -0
app.py
CHANGED
@@ -1,7 +1,74 @@
+import os
+import sys
+
+os.system("pip install gdown")
+
+os.system("pip install imutils")
+
+os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
+
+os.system("pip install git+https://github.com/cocodataset/panopticapi.git")
+
 import gradio as gr
-return "Hello " + name + "!!"
+# check pytorch installation:
+import detectron2
+from detectron2.utils.logger import setup_logger
+
+# import some common libraries
+import numpy as np
+import cv2
+import torch
+
+# import some common detectron2 utilities
+from detectron2 import model_zoo
+from detectron2.engine import DefaultPredictor
+from detectron2.config import get_cfg
+from detectron2.utils.visualizer import Visualizer, ColorMode
+from detectron2.data import MetadataCatalog
+from detectron2.projects.deeplab import add_deeplab_config
+coco_metadata = MetadataCatalog.get("coco_2017_val_panoptic")
+
+# import kMaXDeepLab project
+from kmax_deeplab import add_kmax_deeplab_config
+
+from PIL import Image
+import imutils
+
+cfg = get_cfg()
+cfg.MODEL.DEVICE = 'cpu'
+add_deeplab_config(cfg)
+add_kmax_deeplab_config(cfg)
+cfg.merge_from_file("configs/coco/panoptic-segmentation/kmax_convnext_large.yaml")
+os.system("gdown 1b6rEnKw4PNTdqSdWpmb0P9dsvN0pkOiN")
+cfg.MODEL.WEIGHTS = './kmax_convnext_large.pth'
+cfg.MODEL.KMAX_DEEPLAB.TEST.SEMANTIC_ON = True
+cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON = True
+cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON = True
+predictor = DefaultPredictor(cfg)
+
+os.system("wget https://i.imgur.com/Vj17K5z.jpg")
+
+def inference(img):
+    im = cv2.imread(img)
+    im = imutils.resize(im, width=512)
+    outputs = predictor(im)
+    v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
+    panoptic_result = v.draw_panoptic_seg(outputs["panoptic_seg"][0].to("cpu"), outputs["panoptic_seg"][1]).get_image()
+    v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
+    instance_result = v.draw_instance_predictions(outputs["instances"].to("cpu")).get_image()
+    v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
+    semantic_result = v.draw_sem_seg(outputs["sem_seg"].argmax(0).to("cpu")).get_image()
+    return Image.fromarray(np.uint8(panoptic_result)).convert('RGB'), Image.fromarray(np.uint8(instance_result)).convert('RGB'), Image.fromarray(np.uint8(semantic_result)).convert('RGB')
+
+
+title = "kMaX-DeepLab"
+description = "Gradio demo for kMaX-DeepLab. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.01527' target='_blank'>kMaX-DeepLab</a> | <a href='https://github.com/google-research/deeplab2' target='_blank'>Github Repo</a></p>"
+
+examples = [['Vj17K5z.jpg']]
+
+gr.Interface(inference, inputs=gr.inputs.Image(type="filepath"), outputs=[gr.outputs.Image(label="Panoptic segmentation",type="pil"),gr.outputs.Image(label="instance segmentation",type="pil"),gr.outputs.Image(label="semantic segmentation",type="pil")], title=title,
+             description=description,
+             article=article,
+             examples=examples).launch(enable_queue=True,cache_examples=True)
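As a quick sanity check outside the Gradio UI, the `inference` function above can be called directly once the setup code (config, downloaded weights, example image) has run in the same session. A minimal sketch, not part of the commit:

# Sketch only: exercise the Space's inference() on the example image app.py downloads,
# and save the three PIL images it returns (panoptic, instance, semantic).
panoptic_img, instance_img, semantic_img = inference("Vj17K5z.jpg")
panoptic_img.save("panoptic.png")
instance_img.save("instance.png")
semantic_img.save("semantic.png")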
configs/coco/panoptic-segmentation/kmax_convnext_base.yaml
ADDED
@@ -0,0 +1,13 @@
_BASE_: kmax_r50.yaml
MODEL:
  # backbone part.
  BACKBONE:
    NAME: "D2ConvNeXt"
  WEIGHTS: "./convnext_base_22k_1k_384.pkl"
  CONVNEXT:
    IN_CHANNELS: 3
    DEPTHS: [3, 3, 27, 3]
    DIMS: [128, 256, 512, 1024]
    # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_convnext_base_os32.textproto#L28
    DROP_PATH_RATE: 0.5
    OUT_INDICES: [0, 1, 2, 3]
configs/coco/panoptic-segmentation/kmax_convnext_large.yaml
ADDED
@@ -0,0 +1,13 @@
_BASE_: kmax_r50.yaml
MODEL:
  # backbone part.
  BACKBONE:
    NAME: "D2ConvNeXt"
  WEIGHTS: "./convnext_large_22k_1k_384.pkl"
  CONVNEXT:
    IN_CHANNELS: 3
    DEPTHS: [3, 3, 27, 3]
    DIMS: [192, 384, 768, 1536]
    # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_convnext_large_os32.textproto#L28
    DROP_PATH_RATE: 0.6
    OUT_INDICES: [0, 1, 2, 3]
configs/coco/panoptic-segmentation/kmax_convnext_small.yaml
ADDED
@@ -0,0 +1,13 @@
_BASE_: kmax_r50.yaml
MODEL:
  # backbone part.
  BACKBONE:
    NAME: "D2ConvNeXt"
  WEIGHTS: "./convnext_small_22k_1k_384.pkl"
  CONVNEXT:
    IN_CHANNELS: 3
    DEPTHS: [3, 3, 27, 3]
    DIMS: [96, 192, 384, 768]
    # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_convnext_small_os32.textproto#L28
    DROP_PATH_RATE: 0.4
    OUT_INDICES: [0, 1, 2, 3]
configs/coco/panoptic-segmentation/kmax_convnext_tiny.yaml
ADDED
@@ -0,0 +1,13 @@
_BASE_: kmax_r50.yaml
MODEL:
  # backbone part.
  BACKBONE:
    NAME: "D2ConvNeXt"
  WEIGHTS: "./convnext_tiny_22k_1k_384.pkl"
  CONVNEXT:
    IN_CHANNELS: 3
    DEPTHS: [3, 3, 9, 3]
    DIMS: [96, 192, 384, 768]
    # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_convnext_tiny_os32.textproto#L28
    DROP_PATH_RATE: 0.3
    OUT_INDICES: [0, 1, 2, 3]
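The four ConvNeXt configs only override the backbone block; everything else (losses, decoders, solver, input pipeline) is inherited from kmax_r50.yaml via _BASE_. Collected in one place for reference (a summary of the four files above, not a file in the commit):

# Backbone overrides per variant, as declared in the four configs above:
# name -> (DEPTHS, DIMS, DROP_PATH_RATE)
CONVNEXT_VARIANTS = {
    "tiny":  ([3, 3, 9, 3],  [96, 192, 384, 768],   0.3),
    "small": ([3, 3, 27, 3], [96, 192, 384, 768],   0.4),
    "base":  ([3, 3, 27, 3], [128, 256, 512, 1024], 0.5),
    "large": ([3, 3, 27, 3], [192, 384, 768, 1536], 0.6),
}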
configs/coco/panoptic-segmentation/kmax_r50.yaml
ADDED
@@ -0,0 +1,91 @@
MODEL:
  # backbone part.
  BACKBONE:
    FREEZE_AT: 0
    NAME: "custom_bn_build_resnet_backbone" # we customize the momentum and eps in syncbn, to align with tf implementation.
  WEIGHTS: "../R-50.pkl"
  PIXEL_MEAN: [127.5, 127.5, 127.5]
  PIXEL_STD: [127.5, 127.5, 127.5]
  RESNETS:
    DEPTH: 50
    STEM_TYPE: "basic" # not used
    STEM_OUT_CHANNELS: 64
    STRIDE_IN_1X1: False
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
    NORM: "SyncBN"
    RES5_MULTI_GRID: [1, 1, 1] # not used

  # kmax part.
  META_ARCHITECTURE: "kMaXDeepLab"
  SEM_SEG_HEAD:
    NAME: "kMaXDeepLabHead"
    IGNORE_VALUE: 255
    NUM_CLASSES: 133
    LOSS_WEIGHT: 1.0

  KMAX_DEEPLAB:
    SAVE_VIS_NUM: 0
    SHARE_FINAL_MATCHING: True
    DEEP_SUPERVISION: True
    NO_OBJECT_WEIGHT: 1e-5
    CLASS_WEIGHT: 3.0
    DICE_WEIGHT: 3.0
    MASK_WEIGHT: 0.3
    INSDIS_WEIGHT: 1.0
    AUX_SEMANTIC_WEIGHT: 1.0

    PIXEL_DEC:
      NAME: "kMaXPixelDecoder"
      IN_FEATURES: ["res2", "res3", "res4", "res5"]
      DEC_LAYERS: [1, 5, 1, 1]
      LAYER_TYPES: ["axial", "axial", "bottleneck", "bottleneck"]
      DEC_CHANNELS: [512, 256, 128, 64]

    TRANS_DEC:
      NAME: "kMaXTransformerDecoder"
      DEC_LAYERS: [2, 2, 2]
      NUM_OBJECT_QUERIES: 128
      IN_CHANNELS: [2048, 1024, 512] # [512 * 4, 256 * 4, 128 * 4]
      DROP_PATH_PROB: 0.2

    TEST:
      SEMANTIC_ON: False
      INSTANCE_ON: False # Save some time :)
      PANOPTIC_ON: True
      OBJECT_MASK_THRESHOLD: 0.4
      CLASS_THRESHOLD_THING: 0.7
      CLASS_THRESHOLD_STUFF: 0.5
      REORDER_CLASS_WEIGHT: 1.0
      REORDER_MASK_WEIGHT: 1.0
      OVERLAP_THRESHOLD: 0.8

DATASETS:
  TRAIN: ("coco_2017_train_panoptic",)
  TEST: ("coco_2017_val_panoptic",)
SOLVER:
  IMS_PER_BATCH: 64
  BASE_LR: 0.0005
  LR_SCHEDULER_NAME: "TF2WarmupPolyLR"
  MAX_ITER: 150000
  WARMUP_ITERS: 5000
  WEIGHT_DECAY: 0.05
  OPTIMIZER: "ADAMW"
  BACKBONE_MULTIPLIER: 0.1
  CLIP_GRADIENTS:
    ENABLED: False
  AMP:
    ENABLED: True
INPUT:
  IMAGE_SIZE: [1281, 1281]
  MIN_SCALE: 0.2
  MAX_SCALE: 2.0
  FORMAT: "RGB"
  DATASET_MAPPER_NAME: "coco_panoptic_lsj"
  MIN_SIZE_TEST: 1281
  MAX_SIZE_TEST: 1281
TEST:
  EVAL_PERIOD: 5000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: True
  NUM_WORKERS: 4
VERSION: 2
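Two of these values are easy to misread: NUM_CLASSES: 133 is the COCO panoptic label space (80 "thing" plus 53 "stuff" categories), and the 127.5 mean/std pair simply rescales 0-255 pixels into [-1, 1], which appears to follow the TF DeepLab2 preprocessing the backbone comment also aligns with. A one-line check of the normalization (illustration only, not part of the commit):

import numpy as np

# detectron2 normalizes as (x - PIXEL_MEAN) / PIXEL_STD, here mapping [0, 255] -> [-1, 1].
pixels = np.array([0.0, 127.5, 255.0])
print((pixels - 127.5) / 127.5)  # [-1.  0.  1.]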
convert-pretrained-model-to-d2.py
ADDED
@@ -0,0 +1,36 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import pickle as pkl
import sys

import torch

"""
Usage:
  # download pretrained swin model:
  wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
  # run the conversion
  ./convert-pretrained-model-to-d2.py swin_tiny_patch4_window7_224.pth swin_tiny_patch4_window7_224.pkl
  # Then, use swin_tiny_patch4_window7_224.pkl with the following changes in config:
MODEL:
  WEIGHTS: "/path/to/swin_tiny_patch4_window7_224.pkl"
INPUT:
  FORMAT: "RGB"
"""

if __name__ == "__main__":
    input = sys.argv[1]

    obj = torch.load(input, map_location="cpu")["model"]

    # Clean unused convnext weight
    if "norm.weight" in obj:
        del obj["norm.weight"]
    if "norm.bias" in obj:
        del obj["norm.bias"]

    res = {"model": obj, "__author__": "third_party", "matching_heuristics": True}

    with open(sys.argv[2], "wb") as f:
        pkl.dump(res, f)
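The script wraps the torch checkpoint in the pickle layout detectron2's checkpointer expects ("model" dict plus the matching_heuristics flag). A quick way to inspect the result is sketched below; not part of the commit, and the filename is just the example output name from the docstring:

import pickle as pkl

# Inspect a checkpoint produced by convert-pretrained-model-to-d2.py.
with open("swin_tiny_patch4_window7_224.pkl", "rb") as f:
    ckpt = pkl.load(f)

print(ckpt["__author__"], ckpt["matching_heuristics"])  # "third_party" True
for name, tensor in list(ckpt["model"].items())[:5]:
    print(name, tuple(tensor.shape))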
convert-tf-weights-to-d2.py
ADDED
@@ -0,0 +1,400 @@
import tensorflow as tf
import pickle as pkl
import sys

import torch
import numpy as np

def load_tf_weights(ckpt_path):
    # https://stackoverflow.com/questions/40118062/how-to-read-weights-saved-in-tensorflow-checkpoint-file
    from tensorflow.python.training import py_checkpoint_reader
    reader = py_checkpoint_reader.NewCheckpointReader(ckpt_path)
    state_dict = {}
    for k in reader.get_variable_to_shape_map():
        if '.OPTIMIZER_SLOT' in k or 'optimizer' in k or '_CHECKPOINTABLE_OBJECT_GRAPH' in k or 'save_counter' in k or 'global_step' in k:
            continue
        v = reader.get_tensor(k)
        state_dict[k.replace('/.ATTRIBUTES/VARIABLE_VALUE', '')] = v
    for k in sorted(state_dict.keys()):
        print(k, state_dict[k].shape)
    return state_dict

def map_bn(name1, name2):
    res = {}
    res[name1 + '/gamma'] = name2 + ".weight"
    res[name1 + '/beta'] = name2 + ".bias"
    res[name1 + '/moving_mean'] = name2 + ".running_mean"
    res[name1 + '/moving_variance'] = name2 + ".running_var"
    return res


def map_conv(name1, name2, dw=False, bias=False):
    res = {}
    if dw:
        res[name1 + '/depthwise_kernel'] = name2 + ".weight"
    else:
        res[name1 + '/kernel'] = name2 + ".weight"
    if bias:
        res[name1 + '/bias'] = name2 + ".bias"
    return res


def tf_2_torch_mapping_r50():
    res = {}
    res.update(map_conv('encoder/_stem/_conv', 'backbone.stem.conv1'))
    res.update(map_bn('encoder/_stem/_batch_norm', 'backbone.stem.conv1.norm'))
    block_num = {2: 3, 3: 4, 4: 6, 5: 3}
    for stage_idx in range(2, 6):
        for block_idx in range(1, block_num[stage_idx] + 1):
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv1_bn_act/_conv',
                                f'backbone.res{stage_idx}.{block_idx-1}.conv1'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv1_bn_act/_batch_norm',
                              f'backbone.res{stage_idx}.{block_idx-1}.conv1.norm'))
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv2_bn_act/_conv',
                                f'backbone.res{stage_idx}.{block_idx-1}.conv2'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv2_bn_act/_batch_norm',
                              f'backbone.res{stage_idx}.{block_idx-1}.conv2.norm'))
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv3_bn/_conv',
                                f'backbone.res{stage_idx}.{block_idx-1}.conv3'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv3_bn/_batch_norm',
                              f'backbone.res{stage_idx}.{block_idx-1}.conv3.norm'))
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_shortcut/_conv',
                                f'backbone.res{stage_idx}.{block_idx-1}.shortcut'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_shortcut/_batch_norm',
                              f'backbone.res{stage_idx}.{block_idx-1}.shortcut.norm'))
    return res

def tf_2_torch_mapping_convnext():
    res = {}
    for i in range(4):
        if i == 0:
            res.update(map_conv(f'encoder/downsample_layers/{i}/layer_with_weights-0',
                                f'backbone.downsample_layers.{i}.0', bias=True))
            res.update(map_bn(f'encoder/downsample_layers/{i}/layer_with_weights-1',
                              f'backbone.downsample_layers.{i}.1'))
        else:
            res.update(map_conv(f'encoder/downsample_layers/{i}/layer_with_weights-1',
                                f'backbone.downsample_layers.{i}.1', bias=True))
            res.update(map_bn(f'encoder/downsample_layers/{i}/layer_with_weights-0',
                              f'backbone.downsample_layers.{i}.0'))

    block_num = {0: 3, 1: 3, 2: 27, 3: 3}
    for stage_idx in range(4):
        for block_idx in range(block_num[stage_idx]):
            res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/depthwise_conv',
                                f'backbone.stages.{stage_idx}.{block_idx}.dwconv', bias=True))
            res.update(map_bn(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/norm',
                              f'backbone.stages.{stage_idx}.{block_idx}.norm'))
            res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/pointwise_conv1',
                                f'backbone.stages.{stage_idx}.{block_idx}.pwconv1', bias=True))
            res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/pointwise_conv2',
                                f'backbone.stages.{stage_idx}.{block_idx}.pwconv2', bias=True))
            res[f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/layer_scale'] = f'backbone.stages.{stage_idx}.{block_idx}.gamma'

    return res

def tf_2_torch_mapping_pixel_dec():
    res = {}
    for i in range(4):
        res.update(map_bn(f'pixel_decoder/_backbone_norms/{i}', f'sem_seg_head.pixel_decoder._in_norms.{i}'))
        res.update(map_bn(f'pixel_decoder/_backbone_norms/{i}', f'sem_seg_head.pixel_decoder._in_norms.{i}'))
        res.update(map_bn(f'pixel_decoder/_backbone_norms/{i}', f'sem_seg_head.pixel_decoder._in_norms.{i}'))
        res.update(map_bn(f'pixel_decoder/_backbone_norms/{i}', f'sem_seg_head.pixel_decoder._in_norms.{i}'))

    for i in range(3):
        res.update(map_conv(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn1/_conv',
                            f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_low.conv'))
        res.update(map_bn(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn1/_batch_norm',
                          f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_low.norm'))
        res.update(map_conv(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn2/_conv',
                            f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_high.conv'))
        res.update(map_bn(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn2/_batch_norm',
                          f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_high.norm'))

    num_blocks = {0: 1, 1: 5, 2: 1, 3: 1}
    for stage_idx in range(4):
        for block_idx in range(1, 1 + num_blocks[stage_idx]):
            res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_shortcut/_conv',
                                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._shortcut.conv'))
            res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_shortcut/_batch_norm',
                              f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._shortcut.norm'))
            res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv1_bn_act/_conv',
                                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv1_bn_act.conv'))
            res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv1_bn_act/_batch_norm',
                              f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv1_bn_act.norm'))
            res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv3_bn/_conv',
                                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv3_bn.conv'))
            res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv3_bn/_batch_norm',
                              f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv3_bn.norm'))
            if stage_idx <= 1:
                for attn in ['height', 'width']:
                    res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_qkv',
                                      f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_qkv'))
                    res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_retrieved_output',
                                      f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_retrieved_output'))
                    res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_similarity',
                                      f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_similarity'))
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_key_rpe/embeddings'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._key_rpe._embeddings.weight')
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_query_rpe/embeddings'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._query_rpe._embeddings.weight')
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_value_rpe/embeddings'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._value_rpe._embeddings.weight')
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/qkv_kernel'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis.qkv_transform.conv.weight')
            else:
                res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv2_bn_act/_conv',
                                    f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv2_bn_act.conv'))
                res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv2_bn_act/_batch_norm',
                                  f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv2_bn_act.norm'))
    return res


def tf_2_torch_mapping_predcitor(prefix_tf, prefix_torch):
    res = {}
    res.update(map_bn(prefix_tf + 'pixel_space_feature_batch_norm',
                      prefix_torch + '_pixel_space_head_last_convbn.norm'))
    res[prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel'] = (
        prefix_torch + '_pixel_space_head_conv0bnact.conv.weight'
    )
    res.update(map_bn(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_depthwise/_batch_norm',
                      prefix_torch + '_pixel_space_head_conv0bnact.norm'))
    res.update(map_conv(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_pointwise/_conv',
                        prefix_torch + '_pixel_space_head_conv1bnact.conv'))
    res.update(map_bn(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_pointwise/_batch_norm',
                      prefix_torch + '_pixel_space_head_conv1bnact.norm'))
    res.update(map_conv(prefix_tf + 'pixel_space_head/final_conv',
                        prefix_torch + '_pixel_space_head_last_convbn.conv', bias=True))
    res.update(map_bn(prefix_tf + 'pixel_space_mask_batch_norm',
                      prefix_torch + '_pixel_space_mask_batch_norm'))
    res.update(map_conv(prefix_tf + 'transformer_class_head/_conv',
                        prefix_torch + '_transformer_class_head.conv', bias=True))
    res.update(map_conv(prefix_tf + 'transformer_mask_head/_conv',
                        prefix_torch + '_transformer_mask_head.conv'))
    res.update(map_bn(prefix_tf + 'transformer_mask_head/_batch_norm',
                      prefix_torch + '_transformer_mask_head.norm'))

    return res


def tf_2_torch_mapping_trans_dec():
    res = {}

    res.update(map_bn('transformer_decoder/_class_embedding_projection/_batch_norm',
                      'sem_seg_head.predictor._class_embedding_projection.norm'))
    res.update(map_conv('transformer_decoder/_class_embedding_projection/_conv',
                        'sem_seg_head.predictor._class_embedding_projection.conv'))
    res.update(map_bn('transformer_decoder/_mask_embedding_projection/_batch_norm',
                      'sem_seg_head.predictor._mask_embedding_projection.norm'))
    res.update(map_conv('transformer_decoder/_mask_embedding_projection/_conv',
                        'sem_seg_head.predictor._mask_embedding_projection.conv'))

    res['transformer_decoder/cluster_centers'] = 'sem_seg_head.predictor._cluster_centers.weight'

    res.update(tf_2_torch_mapping_predcitor(
        prefix_tf='',
        prefix_torch='sem_seg_head.predictor._predcitor.'
    ))
    for kmax_idx in range(6):
        res.update(tf_2_torch_mapping_predcitor(
            prefix_tf=f'transformer_decoder/_kmax_decoder/{kmax_idx}/_block1_transformer/_auxiliary_clustering_predictor/_',
            prefix_torch=f'sem_seg_head.predictor._kmax_transformer_layers.{kmax_idx}._predcitor.'
        ))
        common_prefix_tf = f'transformer_decoder/_kmax_decoder/{kmax_idx}/_block1_transformer/'
        common_prefix_torch = f'sem_seg_head.predictor._kmax_transformer_layers.{kmax_idx}.'
        res.update(map_bn(common_prefix_tf + '_kmeans_memory_batch_norm_retrieved_value',
                          common_prefix_torch + '_kmeans_query_batch_norm_retrieved_value'))
        res.update(map_bn(common_prefix_tf + '_kmeans_memory_conv3_bn/_batch_norm',
                          common_prefix_torch + '_kmeans_query_conv3_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_kmeans_memory_conv3_bn/_conv',
                            common_prefix_torch + '_kmeans_query_conv3_bn.conv'))
        res.update(map_bn(common_prefix_tf + '_memory_attention/_batch_norm_retrieved_value',
                          common_prefix_torch + '_query_self_attention._batch_norm_retrieved_value'))
        res.update(map_bn(common_prefix_tf + '_memory_attention/_batch_norm_similarity',
                          common_prefix_torch + '_query_self_attention._batch_norm_similarity'))

        res.update(map_bn(common_prefix_tf + '_memory_conv1_bn_act/_batch_norm',
                          common_prefix_torch + '_query_conv1_bn_act.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_conv1_bn_act/_conv',
                            common_prefix_torch + '_query_conv1_bn_act.conv'))

        res.update(map_bn(common_prefix_tf + '_memory_conv3_bn/_batch_norm',
                          common_prefix_torch + '_query_conv3_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_conv3_bn/_conv',
                            common_prefix_torch + '_query_conv3_bn.conv'))

        res.update(map_bn(common_prefix_tf + '_memory_ffn_conv1_bn_act/_batch_norm',
                          common_prefix_torch + '_query_ffn_conv1_bn_act.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_ffn_conv1_bn_act/_conv',
                            common_prefix_torch + '_query_ffn_conv1_bn_act.conv'))

        res.update(map_bn(common_prefix_tf + '_memory_ffn_conv2_bn/_batch_norm',
                          common_prefix_torch + '_query_ffn_conv2_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_ffn_conv2_bn/_conv',
                            common_prefix_torch + '_query_ffn_conv2_bn.conv'))

        res.update(map_bn(common_prefix_tf + '_memory_qkv_conv_bn/_batch_norm',
                          common_prefix_torch + '_query_qkv_conv_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_qkv_conv_bn/_conv',
                            common_prefix_torch + '_query_qkv_conv_bn.conv'))

        res.update(map_bn(common_prefix_tf + '_pixel_conv1_bn_act/_batch_norm',
                          common_prefix_torch + '_pixel_conv1_bn_act.norm'))
        res.update(map_conv(common_prefix_tf + '_pixel_conv1_bn_act/_conv',
                            common_prefix_torch + '_pixel_conv1_bn_act.conv'))

        res.update(map_bn(common_prefix_tf + '_pixel_v_conv_bn/_batch_norm',
                          common_prefix_torch + '_pixel_v_conv_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_pixel_v_conv_bn/_conv',
                            common_prefix_torch + '_pixel_v_conv_bn.conv'))

    return res


def tf_2_torch_mapping_aux_semanic_dec():
    res = {}
    res.update(map_conv('semantic_decoder/_aspp/_conv_bn_act/_conv',
                        'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv0.conv'))
    res.update(map_bn('semantic_decoder/_aspp/_conv_bn_act/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv0.norm'))

    res.update(map_conv('semantic_decoder/_aspp/_aspp_pool/_conv_bn_act/_conv',
                        'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_pool.conv'))
    res.update(map_bn('semantic_decoder/_aspp/_aspp_pool/_conv_bn_act/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_pool.norm'))

    res.update(map_conv('semantic_decoder/_aspp/_proj_conv_bn_act/_conv',
                        'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._proj_conv_bn_act.conv'))
    res.update(map_bn('semantic_decoder/_aspp/_proj_conv_bn_act/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._proj_conv_bn_act.norm'))
    for i in range(1, 4):
        res.update(map_conv(f'semantic_decoder/_aspp/_aspp_conv{i}/_conv_bn_act/_conv',
                            f'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv{i}.conv'))
        res.update(map_bn(f'semantic_decoder/_aspp/_aspp_conv{i}/_conv_bn_act/_batch_norm',
                          f'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv{i}.norm'))

    res.update({
        'semantic_decoder/_fusion_conv1/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
        'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv0_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv1/_conv1_bn_act/_depthwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv0_bn_act.norm'))
    res.update({
        'semantic_decoder/_fusion_conv1/_conv1_bn_act/_pointwise/_conv/kernel':
        'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv1_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv1/_conv1_bn_act/_pointwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv1_bn_act.norm'))

    res.update({
        'semantic_decoder/_fusion_conv2/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
        'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv0_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv2/_conv1_bn_act/_depthwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv0_bn_act.norm'))
    res.update({
        'semantic_decoder/_fusion_conv2/_conv1_bn_act/_pointwise/_conv/kernel':
        'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv1_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv2/_conv1_bn_act/_pointwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv1_bn_act.norm'))

    res.update({
        'semantic_decoder/_low_level_conv1/_conv/kernel':
        'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os8.conv.weight'})
    res.update(map_bn('semantic_decoder/_low_level_conv1/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os8.norm'))
    res.update({
        'semantic_decoder/_low_level_conv2/_conv/kernel':
        'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os4.conv.weight'})
    res.update(map_bn('semantic_decoder/_low_level_conv2/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os4.norm'))


    res.update({
        'semantic_head_without_last_layer/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
        'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.conv.weight'})
    res.update(map_bn('semantic_head_without_last_layer/_conv1_bn_act/_depthwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.norm'))
    res.update({
        'semantic_head_without_last_layer/_conv1_bn_act/_pointwise/_conv/kernel':
        'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_1.conv.weight'})
    res.update(map_bn('semantic_head_without_last_layer/_conv1_bn_act/_pointwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_1.norm'))

    res.update({
        'semantic_last_layer/kernel':
        'sem_seg_head.predictor._auxiliary_semantic_predictor.final_conv.conv.weight'})
    res.update({
        'semantic_last_layer/bias':
        'sem_seg_head.predictor._auxiliary_semantic_predictor.final_conv.conv.bias'})
    return res


# python3 convert-tf-weights-to-d2.py kmax_resnet50_coco_train/ckpt-150000 tf_kmax_r50.pkl

if __name__ == "__main__":
    input = sys.argv[1]

    state_dict = load_tf_weights(input)
    #exit()

    state_dict_torch = {}

    mapping_key = {}
    if 'resnet50' in input:
        mapping_key.update(tf_2_torch_mapping_r50())
    elif 'convnext' in input:
        mapping_key.update(tf_2_torch_mapping_convnext())
    mapping_key.update(tf_2_torch_mapping_pixel_dec())
    mapping_key.update(tf_2_torch_mapping_trans_dec())

    mapping_key.update(tf_2_torch_mapping_aux_semanic_dec())

    for k in state_dict.keys():
        value = state_dict[k]
        k2 = mapping_key[k]
        rank = len(value.shape)

        if '_batch_norm_retrieved_output' in k2 or '_batch_norm_similarity' in k2 or '_batch_norm_retrieved_value' in k2:
            value = np.reshape(value, [-1])
        elif 'qkv_transform.conv.weight' in k2:
            # (512, 1024) -> (1024, 512, 1)
            value = np.transpose(value, (1, 0))[:, :, None]
        elif '_cluster_centers.weight' in k2:
            # (1, 128, 256) -> (256, 128)
            value = np.transpose(value[0], (1, 0))
        elif '_pixel_conv1_bn_act.conv.weight' in k2:
            # (1, 512, 256) -> (256, 512, 1, 1)
            value = np.transpose(value, (2, 1, 0))[:, :, :, None]
        elif '_pixel_v_conv_bn.conv.weight' in k2:
            # (1, 256, 256) -> (256, 256, 1, 1)
            value = np.transpose(value, (2, 1, 0))[:, :, :, None]
        elif '_pixel_space_head_conv0bnact.conv.weight' in k2:
            # (5, 5, 256, 1) -> (256, 1, 5, 5)
            value = np.transpose(value, (2, 3, 0, 1))
        elif '/layer_scale' in k:
            value = np.reshape(value, [-1])
        elif 'pwconv1.weight' in k2 or 'pwconv2.weight' in k2:
            # (128, 512) -> (512, 128)
            value = np.transpose(value, (1, 0))
        elif ('_low_level_fusion_os4_conv0_bn_act.conv.weight' in k2
              or '_low_level_fusion_os8_conv0_bn_act.conv.weight' in k2
              or 'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.conv.weight' in k2):
            value = np.transpose(value, (2, 3, 0, 1))
        else:
            if rank == 1:  # bias, norm etc
                pass
            elif rank == 2:  # _query_rpe
                pass
            elif rank == 3:  # conv 1d kernel, etc
                value = np.transpose(value, (2, 1, 0))
            elif rank == 4:  # conv 2d kernel, etc
                value = np.transpose(value, (3, 2, 0, 1))

        state_dict_torch[k2] = value

    res = {"model": state_dict_torch, "__author__": "third_party", "matching_heuristics": True}

    with open(sys.argv[2], "wb") as f:
        pkl.dump(res, f)


# r50: 52.85 -> 52.71 w/ eps 1e-3
# convnext-base: 56.85 -> 56.97 w/ eps 1e-3
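Most of the work in the fallback branch is a kernel layout change: TF stores 2D conv kernels as (H, W, C_in, C_out) while PyTorch expects (C_out, C_in, H, W), hence the (3, 2, 0, 1) transpose, and depthwise kernels go from (H, W, C, 1) to (C, 1, H, W) via (2, 3, 0, 1). A small self-contained check of those two transposes (illustration only, not part of the commit):

import numpy as np

# Regular 2D conv kernel: TF (H, W, C_in, C_out) -> PyTorch (C_out, C_in, H, W).
tf_kernel = np.random.randn(3, 3, 64, 256)
assert np.transpose(tf_kernel, (3, 2, 0, 1)).shape == (256, 64, 3, 3)

# Depthwise kernel (e.g. the 5x5 pixel-space head conv): TF (H, W, C, 1) -> PyTorch (C, 1, H, W).
tf_dw = np.random.randn(5, 5, 256, 1)
assert np.transpose(tf_dw, (2, 3, 0, 1)).shape == (256, 1, 5, 5)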
demo/demo.ipynb
ADDED
@@ -0,0 +1,213 @@
{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# kMaX-DeepLab Demo\n",
        "This notebook is modified by Qihang Yu, with reference from [Mask2Former's script](https://colab.research.google.com/drive/1uIWE5KbGFSjrxey2aRd5pWkKNY1_SaNq)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Install detectron2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Install detectron2\n",
        "import torch\n",
        "TORCH_VERSION = \".\".join(torch.__version__.split(\".\")[:2])\n",
        "CUDA_VERSION = torch.__version__.split(\"+\")[-1]\n",
        "print(\"torch: \", TORCH_VERSION, \"; cuda: \", CUDA_VERSION)\n",
        "# Install detectron2 that matches the above pytorch version\n",
        "# See https://detectron2.readthedocs.io/tutorials/install.html for instructions\n",
        "!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/$CUDA_VERSION/torch$TORCH_VERSION/index.html"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Install kMaX-DeepLab"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# clone and install kMaX-DeepLab\n",
        "!git clone https://github.com/yucornetto/kmaxdeeplab_detectron2.git\n",
        "%cd kmaxdeeplab_detectron2\n",
        "!pip install -U opencv-python\n",
        "!pip install git+https://github.com/cocodataset/panopticapi.git\n",
        "!pip install -r requirements.txt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# You may need to restart your runtime prior to this, to let your installation take effect\n",
        "%cd /content/kmaxdeeplab_detectron2\n",
        "# Some basic setup:\n",
        "# Setup detectron2 logger\n",
        "import detectron2\n",
        "from detectron2.utils.logger import setup_logger\n",
        "setup_logger()\n",
        "setup_logger(name=\"kmax_deeplab\")\n",
        "\n",
        "# import some common libraries\n",
        "import numpy as np\n",
        "import cv2\n",
        "import torch\n",
        "from google.colab.patches import cv2_imshow\n",
        "\n",
        "# import some common detectron2 utilities\n",
        "from detectron2 import model_zoo\n",
        "from detectron2.engine import DefaultPredictor\n",
        "from detectron2.config import get_cfg\n",
        "from detectron2.utils.visualizer import Visualizer, ColorMode\n",
        "from detectron2.data import MetadataCatalog\n",
        "from detectron2.projects.deeplab import add_deeplab_config\n",
        "coco_metadata = MetadataCatalog.get(\"coco_2017_val_panoptic\")\n",
        "\n",
        "# import Mask2Former project\n",
        "from kmax_deeplab import add_kmax_deeplab_config"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Run a pre-trained Mask2Former model\n",
        "We first download an image from the COCO dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "!wget http://images.cocodataset.org/val2017/000000005477.jpg -q -O input.jpg\n",
        "im = cv2.imread(\"./input.jpg\")\n",
        "cv2_imshow(im)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then, we create a detectron2 config and a detectron2 `DefaultPredictor` to run inference on this image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "cfg = get_cfg()\n",
        "add_deeplab_config(cfg)\n",
        "add_kmax_deeplab_config(cfg)\n",
        "cfg.merge_from_file(\"configs/coco/panoptic-segmentation/kmax_convnext_large.yaml\")\n",
        "cfg.MODEL.WEIGHTS = 'https://drive.google.com/uc?id=1b6rEnKw4PNTdqSdWpmb0P9dsvN0pkOiN&export=download'\n",
        "cfg.MODEL.KMAX_DEEPLAB.TEST.SEMANTIC_ON = True\n",
        "cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON = True\n",
        "cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON = True\n",
        "predictor = DefaultPredictor(cfg)\n",
        "outputs = predictor(im)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Show panoptic/instance/semantic predictions: \n",
        "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
        "panoptic_result = v.draw_panoptic_seg(outputs[\"panoptic_seg\"][0].to(\"cpu\"), outputs[\"panoptic_seg\"][1]).get_image()\n",
        "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
        "instance_result = v.draw_instance_predictions(outputs[\"instances\"].to(\"cpu\")).get_image()\n",
        "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
        "semantic_result = v.draw_sem_seg(outputs[\"sem_seg\"].argmax(0).to(\"cpu\")).get_image()\n",
        "print(\"Panoptic segmentation (top), instance segmentation (middle), semantic segmentation (bottom)\")\n",
        "cv2_imshow(np.concatenate((panoptic_result, instance_result, semantic_result), axis=0)[:, :, ::-1])"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's try an image not from COCO as well:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Download a sample image and display. Replace path here to try your own images!\n",
        "!wget https://web.eecs.umich.edu/~fouhey/fun/desk/desk.jpg\n",
        "im = cv2.imread(\"./desk.jpg\")\n",
        "cv2_imshow(im)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "outputs = predictor(im)\n",
        "# Show panoptic/instance/semantic predictions: \n",
        "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
        "panoptic_result = v.draw_panoptic_seg(outputs[\"panoptic_seg\"][0].to(\"cpu\"), outputs[\"panoptic_seg\"][1]).get_image()\n",
        "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
        "instance_result = v.draw_instance_predictions(outputs[\"instances\"].to(\"cpu\")).get_image()\n",
        "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
        "semantic_result = v.draw_sem_seg(outputs[\"sem_seg\"].argmax(0).to(\"cpu\")).get_image()\n",
        "print(\"Panoptic segmentation (top), instance segmentation (middle), semantic segmentation (bottom)\")\n",
        "cv2_imshow(np.concatenate((panoptic_result, instance_result, semantic_result), axis=0)[:, :, ::-1])"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
    },
    "orig_nbformat": 4,
    "vscode": {
      "interpreter": {
        "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
demo/demo.py
ADDED
@@ -0,0 +1,156 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
import argparse
import glob
import multiprocessing as mp
import os

# fmt: off
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# fmt: on

import tempfile
import time
import warnings

import cv2
import numpy as np
import tqdm

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_logger

from kmax_deeplab import add_kmax_deeplab_config
from predictor import VisualizationDemo


# constants
WINDOW_NAME = "kmaxdeeplab demo"


def setup_cfg(args):
    # load config from file and command-line arguments
    cfg = get_cfg()
    add_deeplab_config(cfg)
    add_kmax_deeplab_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg


def get_parser():
    parser = argparse.ArgumentParser(description="kmaxdeeplab demo for builtin configs")
    parser.add_argument(
        "--config-file",
        default="configs/coco/panoptic-segmentation/kmax_convnext_large.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
    parser.add_argument("--video-input", help="Path to video file.")
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        help="A file or directory to save output visualizations. "
        "If not given, will show output in an OpenCV window.",
    )

    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.5,
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser


def test_opencv_video_format(codec, file_ext):
    with tempfile.TemporaryDirectory(prefix="video_format_test") as dir:
        filename = os.path.join(dir, "test_file" + file_ext)
        writer = cv2.VideoWriter(
            filename=filename,
            fourcc=cv2.VideoWriter_fourcc(*codec),
            fps=float(30),
            frameSize=(10, 10),
            isColor=True,
        )
        [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)]
        writer.release()
        if os.path.isfile(filename):
            return True
        return False


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)

    demo = VisualizationDemo(cfg)

    if args.input:
        if len(args.input) == 1:
            args.input = glob.glob(os.path.expanduser(args.input[0]))
            assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input, disable=not args.output):
            # use PIL, to be consistent with evaluation
            img = read_image(path, format="BGR")
            start_time = time.time()
            predictions, visualized_output = demo.run_on_image(img)
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )

            ## Below are raw outputs.
            # panoptic_seg, segments_info = predictions["panoptic_seg"]
            # print(panoptic_seg.shape, segments_info)

            if args.output:
                if os.path.isdir(args.output):
                    assert os.path.isdir(args.output), args.output
                    out_filename = os.path.join(args.output, os.path.basename(path))
                else:
                    assert len(args.input) == 1, "Please specify a directory with args.output"
                    out_filename = args.output
                visualized_output.save(out_filename)
            else:
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
                cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
                if cv2.waitKey(0) == 27:
                    break  # esc to quit
    elif args.webcam:
        assert args.input is None, "Cannot have both --input and --webcam!"
        assert args.output is None, "output not yet supported with --webcam!"
        cam = cv2.VideoCapture(0)
        for vis in tqdm.tqdm(demo.run_on_video(cam)):
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, vis)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
        cam.release()
        cv2.destroyAllWindows()
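demo.py is driven entirely by argparse, but the same pipeline can also be used programmatically through the VisualizationDemo class defined in demo/predictor.py below. A minimal sketch with assumed local paths and weights (not part of the commit):

import cv2
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from kmax_deeplab import add_kmax_deeplab_config
from predictor import VisualizationDemo  # demo/predictor.py

cfg = get_cfg()
add_deeplab_config(cfg)
add_kmax_deeplab_config(cfg)
cfg.merge_from_file("configs/coco/panoptic-segmentation/kmax_convnext_large.yaml")
cfg.MODEL.WEIGHTS = "./kmax_convnext_large.pth"  # assumed local checkpoint, as in app.py
cfg.freeze()

demo = VisualizationDemo(cfg)
predictions, vis = demo.run_on_image(cv2.imread("input.jpg"))  # hypothetical input path
vis.save("output.jpg")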
demo/predictor.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
# Copied from: https://github.com/facebookresearch/detectron2/blob/master/demo/predictor.py
import atexit
import bisect
import multiprocessing as mp
from collections import deque

import cv2
import torch

from detectron2.data import MetadataCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.utils.visualizer import ColorMode, Visualizer


class VisualizationDemo(object):
    def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
        """
        Args:
            cfg (CfgNode):
            instance_mode (ColorMode):
            parallel (bool): whether to run the model in different processes from visualization.
                Useful since the visualization logic can be slow.
        """
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        )
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode

        self.parallel = parallel
        if parallel:
            num_gpu = torch.cuda.device_count()
            self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
        else:
            self.predictor = DefaultPredictor(cfg)

    def run_on_image(self, image):
        """
        Args:
            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
                This is the format used by OpenCV.
        Returns:
            predictions (dict): the output of the model.
            vis_output (VisImage): the visualized image output.
        """
        vis_output = None
        predictions = self.predictor(image)
        # Convert image from OpenCV BGR format to Matplotlib RGB format.
        image = image[:, :, ::-1]
        visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
        if "panoptic_seg" in predictions:
            panoptic_seg, segments_info = predictions["panoptic_seg"]
            vis_output = visualizer.draw_panoptic_seg_predictions(
                panoptic_seg.to(self.cpu_device), segments_info
            )
        else:
            if "sem_seg" in predictions:
                vis_output = visualizer.draw_sem_seg(
                    predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )
            if "instances" in predictions:
                instances = predictions["instances"].to(self.cpu_device)
                vis_output = visualizer.draw_instance_predictions(predictions=instances)

        return predictions, vis_output

    def _frame_from_video(self, video):
        while video.isOpened():
            success, frame = video.read()
            if success:
                yield frame
            else:
                break


class AsyncPredictor:
    """
    A predictor that runs the model asynchronously, possibly on >1 GPUs.
    Because rendering the visualization takes a considerable amount of time,
    this helps improve throughput a little bit when rendering videos.
    """

    class _StopToken:
        pass

    class _PredictWorker(mp.Process):
        def __init__(self, cfg, task_queue, result_queue):
            self.cfg = cfg
            self.task_queue = task_queue
            self.result_queue = result_queue
            super().__init__()

        def run(self):
            predictor = DefaultPredictor(self.cfg)

            while True:
                task = self.task_queue.get()
                if isinstance(task, AsyncPredictor._StopToken):
                    break
                idx, data = task
                result = predictor(data)
                self.result_queue.put((idx, result))

    def __init__(self, cfg, num_gpus: int = 1):
        """
        Args:
            cfg (CfgNode):
            num_gpus (int): if 0, will run on CPU
        """
        num_workers = max(num_gpus, 1)
        self.task_queue = mp.Queue(maxsize=num_workers * 3)
        self.result_queue = mp.Queue(maxsize=num_workers * 3)
        self.procs = []
        for gpuid in range(max(num_gpus, 1)):
            cfg = cfg.clone()
            cfg.defrost()
            cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
            self.procs.append(
                AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
            )

        self.put_idx = 0
        self.get_idx = 0
        self.result_rank = []
        self.result_data = []

        for p in self.procs:
            p.start()
        atexit.register(self.shutdown)

    def put(self, image):
        self.put_idx += 1
        self.task_queue.put((self.put_idx, image))

    def get(self):
        self.get_idx += 1  # the index needed for this request
        if len(self.result_rank) and self.result_rank[0] == self.get_idx:
            res = self.result_data[0]
            del self.result_data[0], self.result_rank[0]
            return res

        while True:
            # make sure the results are returned in the correct order
            idx, res = self.result_queue.get()
            if idx == self.get_idx:
                return res
            insert = bisect.bisect(self.result_rank, idx)
            self.result_rank.insert(insert, idx)
            self.result_data.insert(insert, res)

    def __len__(self):
        return self.put_idx - self.get_idx

    def __call__(self, image):
        self.put(image)
        return self.get()

    def shutdown(self):
        for _ in self.procs:
            self.task_queue.put(AsyncPredictor._StopToken())

    @property
    def default_buffer_size(self):
        return len(self.procs) * 5
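The parallel branch of VisualizationDemo hands frames to AsyncPredictor, which runs one DefaultPredictor per GPU in a separate process and re-orders results by submission index in get(). A hedged usage sketch (cfg and the frame paths are placeholders, not part of this commit):

import cv2
import torch

predictor = AsyncPredictor(cfg, num_gpus=torch.cuda.device_count())
frames = [cv2.imread(p) for p in ["frame0.jpg", "frame1.jpg", "frame2.jpg"]]
for frame in frames:
    predictor.put(frame)                     # enqueue work for the worker processes
outputs = [predictor.get() for _ in frames]  # returned in the same order as submitted
predictor.shutdown()                         # sends a _StopToken to every worker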
docs/clustering_view_of_mask_transformer.png
ADDED
docs/kmax_decoder.png
ADDED
kmax_deeplab/__init__.py
ADDED
@@ -0,0 +1,15 @@
from . import data  # register all new datasets
from . import modeling

# config
from .config import add_kmax_deeplab_config

# dataset loading
from .data.dataset_mappers.coco_panoptic_kmaxdeeplab_dataset_mapper import COCOPanoptickMaXDeepLabDatasetMapper


# models
from .kmax_model import kMaXDeepLab

# evaluation
from .evaluation.instance_evaluation import InstanceSegEvaluator
kmax_deeplab/config.py
ADDED
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
from detectron2.config import CfgNode as CN


def add_kmax_deeplab_config(cfg):
    """
    Add config for KMAX_DEEPLAB.
    """
    # NOTE: configs from original maskformer
    # data config
    # select the dataset mapper
    cfg.INPUT.DATASET_MAPPER_NAME = "coco_panoptic_kmaxdeeplab"
    # Color augmentation
    # Pad image and segmentation GT in dataset mapper.
    cfg.INPUT.SIZE_DIVISIBILITY = -1

    # solver config
    # weight decay on embedding
    cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.05
    # optimizer
    cfg.SOLVER.OPTIMIZER = "ADAMW"
    cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1

    # kMaX-DeepLab model config
    cfg.MODEL.KMAX_DEEPLAB = CN()

    # whether to share matching results
    cfg.MODEL.KMAX_DEEPLAB.SHARE_FINAL_MATCHING = True

    # vis
    cfg.MODEL.KMAX_DEEPLAB.SAVE_VIS_NUM = 0

    # loss
    cfg.MODEL.KMAX_DEEPLAB.DEEP_SUPERVISION = True
    cfg.MODEL.KMAX_DEEPLAB.SKIP_CONN_INIT_VALUE = 0.0
    cfg.MODEL.KMAX_DEEPLAB.NO_OBJECT_WEIGHT = 1e-5
    cfg.MODEL.KMAX_DEEPLAB.CLASS_WEIGHT = 3.0
    cfg.MODEL.KMAX_DEEPLAB.DICE_WEIGHT = 3.0
    cfg.MODEL.KMAX_DEEPLAB.MASK_WEIGHT = 0.3
    cfg.MODEL.KMAX_DEEPLAB.INSDIS_WEIGHT = 1.0
    cfg.MODEL.KMAX_DEEPLAB.AUX_SEMANTIC_WEIGHT = 1.0

    cfg.MODEL.KMAX_DEEPLAB.PIXEL_INSDIS_TEMPERATURE = 1.5
    cfg.MODEL.KMAX_DEEPLAB.PIXEL_INSDIS_SAMPLE_K = 4096
    cfg.MODEL.KMAX_DEEPLAB.AUX_SEMANTIC_TEMPERATURE = 2.0
    cfg.MODEL.KMAX_DEEPLAB.UX_SEMANTIC_SAMPLE_K = 4096

    # pixel decoder config
    cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC = CN()
    cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.NAME = "kMaXPixelDecoder"
    cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.IN_FEATURES = ['res2', 'res3', 'res4', 'res5']
    cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_LAYERS = [1, 5, 1, 1]
    cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.LAYER_TYPES = ["axial", "axial", "bottleneck", "bottleneck"]
    cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_CHANNELS = [512, 256, 128, 64]
    cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DROP_PATH_PROB = 0.0

    # transformer decoder config
    cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC = CN()
    cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NAME = "kMaXTransformerDecoder"
    cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.DEC_LAYERS = [2, 2, 2]
    cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NUM_OBJECT_QUERIES = 128
    cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.IN_CHANNELS = [2048, 1024, 512]
    cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.DROP_PATH_PROB = 0.0

    # kMaX-DeepLab inference config
    cfg.MODEL.KMAX_DEEPLAB.TEST = CN()
    cfg.MODEL.KMAX_DEEPLAB.TEST.SEMANTIC_ON = False
    cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON = False
    cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON = True
    cfg.MODEL.KMAX_DEEPLAB.TEST.OBJECT_MASK_THRESHOLD = 0.4
    cfg.MODEL.KMAX_DEEPLAB.TEST.CLASS_THRESHOLD_THING = 0.7
    cfg.MODEL.KMAX_DEEPLAB.TEST.CLASS_THRESHOLD_STUFF = 0.5
    cfg.MODEL.KMAX_DEEPLAB.TEST.REORDER_CLASS_WEIGHT = 1.0
    cfg.MODEL.KMAX_DEEPLAB.TEST.REORDER_MASK_WEIGHT = 1.0
    cfg.MODEL.KMAX_DEEPLAB.TEST.OVERLAP_THRESHOLD = 0.8
    cfg.MODEL.KMAX_DEEPLAB.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False

    # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
    # you can use this config to override
    cfg.MODEL.KMAX_DEEPLAB.SIZE_DIVISIBILITY = -1

    # https://github.com/SHI-Labs/OneFormer/blob/main/oneformer/config.py#L197
    cfg.MODEL.CONVNEXT = CN()
    cfg.MODEL.CONVNEXT.IN_CHANNELS = 3
    cfg.MODEL.CONVNEXT.DEPTHS = [3, 3, 27, 3]
    cfg.MODEL.CONVNEXT.DIMS = [192, 384, 768, 1536]
    cfg.MODEL.CONVNEXT.DROP_PATH_RATE = 0.6
    cfg.MODEL.CONVNEXT.LSIT = 1e-6
    cfg.MODEL.CONVNEXT.OUT_INDICES = [0, 1, 2, 3]
    cfg.MODEL.CONVNEXT.OUT_FEATURES = ["res2", "res3", "res4", "res5"]

    cfg.INPUT.IMAGE_SIZE = [1281, 1281]
    cfg.INPUT.MIN_SCALE = 0.2
    cfg.INPUT.MAX_SCALE = 2.0
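A typical way to consume these keys is to stack the DeepLab and kMaX-DeepLab additions on top of the detectron2 defaults before merging one of the YAML files shipped with this commit; a minimal sketch (the checkpoint path is a placeholder):

from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from kmax_deeplab import add_kmax_deeplab_config

cfg = get_cfg()
add_deeplab_config(cfg)       # base DeepLab keys
add_kmax_deeplab_config(cfg)  # the keys defined above
cfg.merge_from_file("configs/coco/panoptic-segmentation/kmax_r50.yaml")
cfg.MODEL.WEIGHTS = "path/to/kmax_deeplab_checkpoint.pth"  # placeholder
cfg.freeze()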
kmax_deeplab/data/__init__.py
ADDED
@@ -0,0 +1 @@
from . import datasets
kmax_deeplab/data/dataset_mappers/__init__.py
ADDED
File without changes
|
kmax_deeplab/data/dataset_mappers/coco_panoptic_kmaxdeeplab_dataset_mapper.py
ADDED
@@ -0,0 +1,326 @@
1 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/dataset_mappers/coco_panoptic_new_baseline_dataset_mapper.py
|
2 |
+
# modified by Qihang Yu
|
3 |
+
import copy
|
4 |
+
import logging
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import random
|
9 |
+
|
10 |
+
from detectron2.config import configurable
|
11 |
+
from detectron2.data import detection_utils as utils
|
12 |
+
from detectron2.data import transforms as T
|
13 |
+
from detectron2.projects.point_rend import ColorAugSSDTransform
|
14 |
+
from detectron2.structures import BitMasks, Boxes, Instances
|
15 |
+
|
16 |
+
import os
|
17 |
+
|
18 |
+
__all__ = ["COCOPanoptickMaXDeepLabDatasetMapper"]
|
19 |
+
|
20 |
+
|
21 |
+
def build_transform_gen(cfg, is_train, scale_ratio=1.0):
|
22 |
+
"""
|
23 |
+
Create a list of default :class:`Augmentation` from config.
|
24 |
+
Now it includes resizing and flipping.
|
25 |
+
Returns:
|
26 |
+
list[Augmentation]
|
27 |
+
"""
|
28 |
+
image_size = cfg.INPUT.IMAGE_SIZE
|
29 |
+
assert is_train
|
30 |
+
|
31 |
+
min_scale = cfg.INPUT.MIN_SCALE * scale_ratio
|
32 |
+
max_scale = cfg.INPUT.MAX_SCALE * scale_ratio
|
33 |
+
|
34 |
+
|
35 |
+
augmentation = [
|
36 |
+
T.ResizeScale(
|
37 |
+
min_scale=min_scale, max_scale=max_scale, target_height=image_size[0], target_width=image_size[1]
|
38 |
+
),
|
39 |
+
ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT),
|
40 |
+
T.RandomCrop(crop_type="absolute", crop_size=(image_size[0], image_size[1])),
|
41 |
+
T.RandomFlip(),
|
42 |
+
]
|
43 |
+
|
44 |
+
return augmentation
|
45 |
+
|
46 |
+
|
47 |
+
class COCOPanoptickMaXDeepLabDatasetMapper:
|
48 |
+
"""
|
49 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
50 |
+
and maps it into a format used by kMaX-DeepLab.
|
51 |
+
|
52 |
+
The callable currently does the following:
|
53 |
+
|
54 |
+
1. Read the image from "file_name"
|
55 |
+
2. Applies geometric transforms to the image and annotation
|
56 |
+
3. Find and applies suitable cropping to the image and annotation
|
57 |
+
4. Prepare image and annotation to Tensors
|
58 |
+
"""
|
59 |
+
|
60 |
+
@configurable
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
is_train=True,
|
64 |
+
*,
|
65 |
+
tfm_gens,
|
66 |
+
tfm_gens_copy_paste,
|
67 |
+
image_format,
|
68 |
+
image_size,
|
69 |
+
):
|
70 |
+
"""
|
71 |
+
NOTE: this interface is experimental.
|
72 |
+
Args:
|
73 |
+
is_train: for training or inference
|
74 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
75 |
+
tfm_gens: data augmentation
|
76 |
+
tfm_gens_copy_paste: data augmentation
|
77 |
+
image_format: an image format supported by :func:`detection_utils.read_image`
|
78 |
+
image_size: expected image size
|
79 |
+
"""
|
80 |
+
self.tfm_gens = tfm_gens
|
81 |
+
self.tfm_gens_copy_paste = tfm_gens_copy_paste
|
82 |
+
if is_train:
|
83 |
+
logging.getLogger(__name__).info(
|
84 |
+
"[COCOPanopticDeepLab2DatasetMapper] Full TransformGens used in training: {}, {}".format(
|
85 |
+
str(self.tfm_gens), str(self.tfm_gens_copy_paste)
|
86 |
+
)
|
87 |
+
)
|
88 |
+
else:
|
89 |
+
logging.getLogger(__name__).info(
|
90 |
+
"[COCOPanopticDeepLab2DatasetMapper] Full TransformGens used in testing: {}".format(
|
91 |
+
str(self.tfm_gens)
|
92 |
+
)
|
93 |
+
)
|
94 |
+
self.img_format = image_format
|
95 |
+
self.is_train = is_train
|
96 |
+
self.image_size = image_size
|
97 |
+
|
98 |
+
dataset_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
99 |
+
image_dir = os.path.join(dataset_root, "coco/train2017")
|
100 |
+
gt_dir = os.path.join(dataset_root, "coco/panoptic_train2017")
|
101 |
+
semseg_dir = os.path.join(dataset_root, "coco/panoptic_semseg_train2017")
|
102 |
+
json_file = os.path.join(dataset_root, "coco/annotations/panoptic_train2017.json")
|
103 |
+
from ..datasets import register_coco_panoptic_annos_semseg
|
104 |
+
meta_data = register_coco_panoptic_annos_semseg.get_metadata()
|
105 |
+
self.dataset_dict_all = register_coco_panoptic_annos_semseg.load_coco_panoptic_json(
|
106 |
+
json_file, image_dir, gt_dir, semseg_dir, meta_data
|
107 |
+
)
|
108 |
+
self.filename2idx = {}
|
109 |
+
for idx, dataset_dict in enumerate(self.dataset_dict_all):
|
110 |
+
self.filename2idx[dataset_dict["file_name"].split('/')[-1].replace('.jpg', '')] = idx
|
111 |
+
|
112 |
+
|
113 |
+
@classmethod
|
114 |
+
def from_config(cls, cfg, is_train=True):
|
115 |
+
# Build augmentation
|
116 |
+
tfm_gens = build_transform_gen(cfg, is_train)
|
117 |
+
tfm_gens_copy_paste = build_transform_gen(cfg, is_train, scale_ratio=0.5)
|
118 |
+
ret = {
|
119 |
+
"is_train": is_train,
|
120 |
+
"tfm_gens": tfm_gens,
|
121 |
+
"tfm_gens_copy_paste": tfm_gens_copy_paste,
|
122 |
+
"image_format": cfg.INPUT.FORMAT,
|
123 |
+
"image_size": cfg.INPUT.IMAGE_SIZE
|
124 |
+
}
|
125 |
+
return ret
|
126 |
+
|
127 |
+
def read_dataset_dict(self, dataset_dict, is_copy_paste=False):
|
128 |
+
"""
|
129 |
+
Args:
|
130 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
dict: a format that builtin models in detectron2 accept
|
134 |
+
"""
|
135 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
136 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
|
137 |
+
utils.check_image_size(dataset_dict, image)
|
138 |
+
|
139 |
+
if not is_copy_paste:
|
140 |
+
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
|
141 |
+
else:
|
142 |
+
image, transforms = T.apply_transform_gens(self.tfm_gens_copy_paste, image)
|
143 |
+
|
144 |
+
dataset_dict["image"] = np.ascontiguousarray(image.transpose(2, 0, 1))
|
145 |
+
|
146 |
+
if not self.is_train:
|
147 |
+
dataset_dict.pop("annotations", None)
|
148 |
+
return dataset_dict, None
|
149 |
+
|
150 |
+
# We pad the image manually, for copy-paste purpose.
|
151 |
+
padded_image = np.zeros((3, self.image_size[0], self.image_size[1]), dtype=dataset_dict["image"].dtype)
|
152 |
+
new_h, new_w = dataset_dict["image"].shape[1:]
|
153 |
+
offset_h, offset_w = 0, 0 # following the d2 panoptic deeplab implementation to only perform bottom/right padding.
|
154 |
+
padded_image[:, offset_h:offset_h+new_h, offset_w:offset_w+new_w] = dataset_dict["image"]
|
155 |
+
dataset_dict["image"] = padded_image
|
156 |
+
if "pan_seg_file_name" in dataset_dict:
|
157 |
+
pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
|
158 |
+
|
159 |
+
# apply the same transformation to panoptic segmentation
|
160 |
+
pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
|
161 |
+
|
162 |
+
from panopticapi.utils import rgb2id
|
163 |
+
|
164 |
+
pan_seg_gt = rgb2id(pan_seg_gt) # int32 # H x W
|
165 |
+
# Similarly, we manually pad the label, and we use label -1 to indicate those padded pixels.
|
166 |
+
# In this way, we can mask out the padded pixel values to 0 after normalization, which aligns the
|
167 |
+
# behavior between training and testing.
|
168 |
+
padded_pan_seg_gt = np.zeros((self.image_size[0], self.image_size[1]), dtype=pan_seg_gt.dtype)
|
169 |
+
is_real_pixels = np.zeros((self.image_size[0], self.image_size[1]), dtype=bool)  # builtin bool; np.bool was removed in recent NumPy
|
170 |
+
padded_pan_seg_gt[offset_h:offset_h+new_h, offset_w:offset_w+new_w] = pan_seg_gt
|
171 |
+
is_real_pixels[offset_h:offset_h+new_h, offset_w:offset_w+new_w] = True
|
172 |
+
dataset_dict["is_real_pixels"] = is_real_pixels
|
173 |
+
pan_seg_gt = padded_pan_seg_gt
|
174 |
+
return dataset_dict, pan_seg_gt
|
175 |
+
|
176 |
+
# This should never happen.
|
177 |
+
raise NotImplementedError
|
178 |
+
|
179 |
+
def call_copypaste(self, dataset_dict):
|
180 |
+
"""
|
181 |
+
Args:
|
182 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
dict: a format that builtin models in detectron2 accept
|
186 |
+
"""
|
187 |
+
# Read main image.
|
188 |
+
dataset_dict, pan_seg_gt = self.read_dataset_dict(dataset_dict, is_copy_paste=False)
|
189 |
+
# Read copy-paste image.
|
190 |
+
# We use the last number as a bias to the random number, in case the same random numbers are generated across devices.
|
191 |
+
main_image_idx = self.filename2idx[dataset_dict["file_name"].split('/')[-1].replace('.jpg', '')]
|
192 |
+
random_image_idx = main_image_idx + random.randint(0, len(self.dataset_dict_all) - 1)
|
193 |
+
random_image_idx = random_image_idx % len(self.dataset_dict_all)
|
194 |
+
dataset_dict_copy_paste = copy.deepcopy(self.dataset_dict_all[random_image_idx])
|
195 |
+
dataset_dict_copy_paste, pan_seg_gt_copy_paste = self.read_dataset_dict(dataset_dict_copy_paste, is_copy_paste=True)
|
196 |
+
|
197 |
+
# Copy data_dict_copy_paste onto data_dict. 0 means keep original pixel, 1 means use copy-paste pixel.
|
198 |
+
copy_paste_masks = np.zeros((pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
|
199 |
+
|
200 |
+
segments_info_copy_paste = dataset_dict_copy_paste["segments_info"]
|
201 |
+
all_ids = []
|
202 |
+
thing_ids = []
|
203 |
+
for segment_info_copy_paste in segments_info_copy_paste:
|
204 |
+
class_id = segment_info_copy_paste["category_id"]
|
205 |
+
if not segment_info_copy_paste["iscrowd"]:
|
206 |
+
# -1 is reserved for padded pixels.
|
207 |
+
if segment_info_copy_paste["id"] in [-1, 0]:
|
208 |
+
print(segment_info_copy_paste)
|
209 |
+
raise ValueError("id should not be -1, 0")
|
210 |
+
all_ids.append(segment_info_copy_paste["id"])
|
211 |
+
if segment_info_copy_paste["isthing"]: # All thing classes are copy-pasted.
|
212 |
+
thing_ids.append(segment_info_copy_paste["id"])
|
213 |
+
|
214 |
+
# Shuffle and randomly select kept label ids.
|
215 |
+
random.shuffle(all_ids)
|
216 |
+
keep_number = random.randint(0, len(all_ids))
|
217 |
+
|
218 |
+
for index, label_id in enumerate(all_ids):
|
219 |
+
# randomly copy labels, but keep all thing classes.
|
220 |
+
if index < keep_number or label_id in thing_ids:
|
221 |
+
copy_paste_masks[pan_seg_gt_copy_paste == label_id] = 1
|
222 |
+
|
223 |
+
# We merge the image and copy-paste image based on the copy-paste mask.
|
224 |
+
dataset_dict["image"] = (dataset_dict["image"] * (1.0 - copy_paste_masks).astype(dataset_dict["image"].dtype) +
|
225 |
+
dataset_dict_copy_paste["image"] * copy_paste_masks.astype(dataset_dict["image"].dtype))
|
226 |
+
dataset_dict["image"] = torch.as_tensor(dataset_dict["image"])
|
227 |
+
|
228 |
+
dataset_dict["is_real_pixels"] = (dataset_dict["is_real_pixels"] * (1.0 - copy_paste_masks).astype(dataset_dict["is_real_pixels"].dtype) +
|
229 |
+
dataset_dict_copy_paste["is_real_pixels"] * copy_paste_masks.astype(dataset_dict["is_real_pixels"].dtype))
|
230 |
+
dataset_dict["is_real_pixels"] = torch.as_tensor(dataset_dict["is_real_pixels"])
|
231 |
+
# We set all ids in copy-paste image to be negative, so that there will be no overlap between original id and copy-paste id.
|
232 |
+
pan_seg_gt_copy_paste = -pan_seg_gt_copy_paste
|
233 |
+
pan_seg_gt = (pan_seg_gt * (1.0 - copy_paste_masks).astype(pan_seg_gt.dtype) +
|
234 |
+
pan_seg_gt_copy_paste * copy_paste_masks.astype(pan_seg_gt.dtype))
|
235 |
+
|
236 |
+
# We use 4x downsampled gt for final supervision.
|
237 |
+
pan_seg_gt = pan_seg_gt[::4, ::4]
|
238 |
+
sem_seg_gt = -np.ones_like(pan_seg_gt) # H x W, init with -1
|
239 |
+
|
240 |
+
# We then process the obtained pan_seg_gt to training format.
|
241 |
+
image_shape = dataset_dict["image"].shape[1:] # h, w
|
242 |
+
segments_info = dataset_dict["segments_info"]
|
243 |
+
instances = Instances(image_shape)
|
244 |
+
classes = []
|
245 |
+
masks = []
|
246 |
+
valid_pixel_num = 0
|
247 |
+
# As the two images may share same stuff classes, we use a dict to track existing stuff and merge them.
|
248 |
+
stuff_class_to_idx = {}
|
249 |
+
for segment_info in segments_info:
|
250 |
+
class_id = segment_info["category_id"]
|
251 |
+
if not segment_info["iscrowd"]:
|
252 |
+
# -1 is reserved to indicate padded pixels.
|
253 |
+
if segment_info["id"] in [-1, 0]:
|
254 |
+
print(segment_info)
|
255 |
+
raise ValueError("id should not be -1, 0")
|
256 |
+
binary_mask = (pan_seg_gt == segment_info["id"])
|
257 |
+
# As it is possible that some masks are removed during the copy-paste process, we need
|
258 |
+
# to double-check if the mask exists.
|
259 |
+
valid_pixel_num_ = binary_mask.sum()
|
260 |
+
valid_pixel_num += valid_pixel_num_
|
261 |
+
if valid_pixel_num_ > 0:
|
262 |
+
sem_seg_gt[binary_mask] = class_id
|
263 |
+
if not segment_info["isthing"]:
|
264 |
+
# For original image, stuff should only appear once.
|
265 |
+
if class_id in stuff_class_to_idx:
|
266 |
+
raise ValueError('class_id should not already be in stuff_class_to_idx!')
|
267 |
+
else:
|
268 |
+
stuff_class_to_idx[class_id] = len(masks)
|
269 |
+
classes.append(class_id)
|
270 |
+
masks.append(binary_mask)
|
271 |
+
|
272 |
+
for segment_info in segments_info_copy_paste:
|
273 |
+
class_id = segment_info["category_id"]
|
274 |
+
if not segment_info["iscrowd"]:
|
275 |
+
# -1 is reserved to indicate padded pixels.
|
276 |
+
if segment_info["id"] in [-1, 0]:
|
277 |
+
print(segment_info)
|
278 |
+
raise ValueError("id should not be -1, 0")
|
279 |
+
# Note that copy-paste id is negative.
|
280 |
+
binary_mask = (pan_seg_gt == -segment_info["id"])
|
281 |
+
valid_pixel_num_ = binary_mask.sum()
|
282 |
+
valid_pixel_num += valid_pixel_num_
|
283 |
+
if valid_pixel_num_ > 0:
|
284 |
+
sem_seg_gt[binary_mask] = class_id
|
285 |
+
if not segment_info["isthing"]:
|
286 |
+
# The stuff in copy-paste image already appeared in original image.
|
287 |
+
if class_id in stuff_class_to_idx:
|
288 |
+
# Merge into original stuff masks.
|
289 |
+
masks[stuff_class_to_idx[class_id]] = np.logical_or(masks[stuff_class_to_idx[class_id]], binary_mask)
|
290 |
+
continue
|
291 |
+
else:
|
292 |
+
stuff_class_to_idx[class_id] = len(masks)
|
293 |
+
classes.append(class_id)
|
294 |
+
masks.append(binary_mask)
|
295 |
+
|
296 |
+
classes = np.array(classes)
|
297 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
298 |
+
sem_seg_gt = torch.tensor(sem_seg_gt, dtype=torch.int64)
|
299 |
+
|
300 |
+
if len(masks) == 0:
|
301 |
+
# Some image does not have annotation (all ignored)
|
302 |
+
instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
|
303 |
+
instances.gt_boxes = Boxes(torch.zeros((0, 4)))
|
304 |
+
else:
|
305 |
+
masks = BitMasks(
|
306 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
|
307 |
+
)
|
308 |
+
instances.gt_masks = masks.tensor
|
309 |
+
instances.gt_boxes = masks.get_bounding_boxes()
|
310 |
+
|
311 |
+
dataset_dict["instances"] = instances
|
312 |
+
dataset_dict["sem_seg_gt"] = sem_seg_gt
|
313 |
+
dataset_dict["valid_pixel_num"] = valid_pixel_num
|
314 |
+
return dataset_dict
|
315 |
+
|
316 |
+
def __call__(self, dataset_dict):
|
317 |
+
res = self.call_copypaste(dataset_dict)
|
318 |
+
while ("instances" in res and res["instances"].gt_masks.shape[0] == 0) or ("valid_pixel_num" in res and res["valid_pixel_num"] <= 4096):
|
319 |
+
# this gt is empty or contains too many void pixels, let's re-generate one.
|
320 |
+
main_image_idx = self.filename2idx[dataset_dict["file_name"].split('/')[-1].replace('.jpg', '')]
|
321 |
+
random_image_idx = main_image_idx + random.randint(0, len(self.dataset_dict_all) - 1)
|
322 |
+
random_image_idx = random_image_idx % len(self.dataset_dict_all)
|
323 |
+
dataset_dict = self.dataset_dict_all[random_image_idx]
|
324 |
+
res = self.call_copypaste(dataset_dict)
|
325 |
+
|
326 |
+
return res
|
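The core of the copy-paste augmentation above is a per-pixel blend driven by a binary mask, with the pasted image's segment ids negated so they can never collide with the main image's ids. A stripped-down numpy illustration of just that merge step (values are made up, not taken from the dataset):

import numpy as np

pan_main  = np.array([[7, 7, 3, 3],
                      [7, 7, 3, 3]], dtype=np.int64)    # panoptic ids of the main image
pan_paste = np.array([[5, 5, 5, 5],
                      [9, 9, 9, 9]], dtype=np.int64)    # panoptic ids of the copy-paste image
copy_mask = np.array([[0, 1, 1, 0],
                      [0, 0, 1, 1]], dtype=np.float64)  # 1 = take the copy-paste pixel

pan_paste = -pan_paste  # same trick as call_copypaste: pasted ids become negative
merged = pan_main * (1.0 - copy_mask) + pan_paste * copy_mask
# merged == [[ 7, -5, -5,  3],
#            [ 7,  7, -9, -9]]  -> positive ids kept, negative ids came from the paste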
kmax_deeplab/data/datasets/__init__.py
ADDED
@@ -0,0 +1,3 @@
from . import (
    register_coco_panoptic_annos_semseg,
)
kmax_deeplab/data/datasets/register_coco_panoptic_annos_semseg.py
ADDED
@@ -0,0 +1,182 @@
1 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_coco_panoptic_annos_semseg.py
|
2 |
+
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
|
6 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
7 |
+
from detectron2.data.datasets import load_sem_seg
|
8 |
+
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
|
9 |
+
from detectron2.utils.file_io import PathManager
|
10 |
+
|
11 |
+
|
12 |
+
_PREDEFINED_SPLITS_COCO_PANOPTIC = {
|
13 |
+
"coco_2017_train_panoptic": (
|
14 |
+
# This is the original panoptic annotation directory
|
15 |
+
"coco/panoptic_train2017",
|
16 |
+
"coco/annotations/panoptic_train2017.json",
|
17 |
+
# This directory contains semantic annotations that are
|
18 |
+
# converted from panoptic annotations.
|
19 |
+
# It is used by PanopticFPN.
|
20 |
+
# You can use the script at detectron2/datasets/prepare_panoptic_fpn.py
|
21 |
+
# to create these directories.
|
22 |
+
"coco/panoptic_semseg_train2017",
|
23 |
+
),
|
24 |
+
"coco_2017_val_panoptic": (
|
25 |
+
"coco/panoptic_val2017",
|
26 |
+
"coco/annotations/panoptic_val2017.json",
|
27 |
+
"coco/panoptic_semseg_val2017",
|
28 |
+
),
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
def get_metadata():
|
33 |
+
meta = {}
|
34 |
+
# The following metadata maps contiguous id from [0, #thing categories +
|
35 |
+
# #stuff categories) to their names and colors. We have two replicas of the
|
36 |
+
# same name and color under "thing_*" and "stuff_*" because the current
|
37 |
+
# visualization function in D2 handles thing and stuff classes differently
|
38 |
+
# due to some heuristic used in Panoptic FPN. We keep the same naming to
|
39 |
+
# enable reusing existing visualization functions.
|
40 |
+
thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
|
41 |
+
thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
|
42 |
+
stuff_classes = [k["name"] for k in COCO_CATEGORIES]
|
43 |
+
stuff_colors = [k["color"] for k in COCO_CATEGORIES]
|
44 |
+
|
45 |
+
meta["thing_classes"] = thing_classes
|
46 |
+
meta["thing_colors"] = thing_colors
|
47 |
+
meta["stuff_classes"] = stuff_classes
|
48 |
+
meta["stuff_colors"] = stuff_colors
|
49 |
+
|
50 |
+
# Convert category id for training:
|
51 |
+
# category id: like semantic segmentation, it is the class id for each
|
52 |
+
# pixel. Since there are some classes not used in evaluation, the category
|
53 |
+
# id is not always contiguous and thus we have two sets of category ids:
|
54 |
+
# - original category id: category id in the original dataset, mainly
|
55 |
+
# used for evaluation.
|
56 |
+
# - contiguous category id: [0, #classes), in order to train the linear
|
57 |
+
# softmax classifier.
|
58 |
+
thing_dataset_id_to_contiguous_id = {}
|
59 |
+
stuff_dataset_id_to_contiguous_id = {}
|
60 |
+
|
61 |
+
for i, cat in enumerate(COCO_CATEGORIES):
|
62 |
+
if cat["isthing"]:
|
63 |
+
thing_dataset_id_to_contiguous_id[cat["id"]] = i
|
64 |
+
# else:
|
65 |
+
# stuff_dataset_id_to_contiguous_id[cat["id"]] = i
|
66 |
+
|
67 |
+
# in order to use sem_seg evaluator
|
68 |
+
stuff_dataset_id_to_contiguous_id[cat["id"]] = i
|
69 |
+
|
70 |
+
meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
|
71 |
+
meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
|
72 |
+
|
73 |
+
return meta
|
74 |
+
|
75 |
+
|
76 |
+
def load_coco_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, meta):
|
77 |
+
"""
|
78 |
+
Args:
|
79 |
+
image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
|
80 |
+
gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
|
81 |
+
json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
|
82 |
+
Returns:
|
83 |
+
list[dict]: a list of dicts in Detectron2 standard format. (See
|
84 |
+
`Using Custom Datasets </tutorials/datasets.html>`_ )
|
85 |
+
"""
|
86 |
+
|
87 |
+
def _convert_category_id(segment_info, meta):
|
88 |
+
if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
|
89 |
+
segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
|
90 |
+
segment_info["category_id"]
|
91 |
+
]
|
92 |
+
segment_info["isthing"] = True
|
93 |
+
else:
|
94 |
+
segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
|
95 |
+
segment_info["category_id"]
|
96 |
+
]
|
97 |
+
segment_info["isthing"] = False
|
98 |
+
return segment_info
|
99 |
+
|
100 |
+
with PathManager.open(json_file) as f:
|
101 |
+
json_info = json.load(f)
|
102 |
+
|
103 |
+
ret = []
|
104 |
+
for ann in json_info["annotations"]:
|
105 |
+
image_id = int(ann["image_id"])
|
106 |
+
# TODO: currently we assume image and label has the same filename but
|
107 |
+
# different extension, and images have extension ".jpg" for COCO. Need
|
108 |
+
# to make image extension a user-provided argument if we extend this
|
109 |
+
# function to support other COCO-like datasets.
|
110 |
+
image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
|
111 |
+
label_file = os.path.join(gt_dir, ann["file_name"])
|
112 |
+
sem_label_file = os.path.join(semseg_dir, ann["file_name"])
|
113 |
+
segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
|
114 |
+
ret.append(
|
115 |
+
{
|
116 |
+
"file_name": image_file,
|
117 |
+
"image_id": image_id,
|
118 |
+
"pan_seg_file_name": label_file,
|
119 |
+
"sem_seg_file_name": sem_label_file,
|
120 |
+
"segments_info": segments_info,
|
121 |
+
}
|
122 |
+
)
|
123 |
+
assert len(ret), f"No images found in {image_dir}!"
|
124 |
+
assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
|
125 |
+
assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
|
126 |
+
assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"]
|
127 |
+
return ret
|
128 |
+
|
129 |
+
|
130 |
+
def register_coco_panoptic_annos_sem_seg(
|
131 |
+
name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json
|
132 |
+
):
|
133 |
+
panoptic_name = name
|
134 |
+
delattr(MetadataCatalog.get(panoptic_name), "thing_classes")
|
135 |
+
delattr(MetadataCatalog.get(panoptic_name), "thing_colors")
|
136 |
+
MetadataCatalog.get(panoptic_name).set(
|
137 |
+
thing_classes=metadata["thing_classes"],
|
138 |
+
thing_colors=metadata["thing_colors"],
|
139 |
+
# thing_dataset_id_to_contiguous_id=metadata["thing_dataset_id_to_contiguous_id"],
|
140 |
+
)
|
141 |
+
|
142 |
+
# the name is "coco_2017_train_panoptic_with_sem_seg" and "coco_2017_val_panoptic_with_sem_seg"
|
143 |
+
semantic_name = name + "_with_sem_seg"
|
144 |
+
DatasetCatalog.register(
|
145 |
+
semantic_name,
|
146 |
+
lambda: load_coco_panoptic_json(panoptic_json, image_root, panoptic_root, sem_seg_root, metadata),
|
147 |
+
)
|
148 |
+
MetadataCatalog.get(semantic_name).set(
|
149 |
+
sem_seg_root=sem_seg_root,
|
150 |
+
panoptic_root=panoptic_root,
|
151 |
+
image_root=image_root,
|
152 |
+
panoptic_json=panoptic_json,
|
153 |
+
json_file=instances_json,
|
154 |
+
evaluator_type="coco_panoptic_seg",
|
155 |
+
ignore_label=255,
|
156 |
+
label_divisor=1000,
|
157 |
+
**metadata,
|
158 |
+
)
|
159 |
+
|
160 |
+
|
161 |
+
def register_all_coco_panoptic_annos_sem_seg(root):
|
162 |
+
for (
|
163 |
+
prefix,
|
164 |
+
(panoptic_root, panoptic_json, semantic_root),
|
165 |
+
) in _PREDEFINED_SPLITS_COCO_PANOPTIC.items():
|
166 |
+
prefix_instances = prefix[: -len("_panoptic")]
|
167 |
+
instances_meta = MetadataCatalog.get(prefix_instances)
|
168 |
+
image_root, instances_json = instances_meta.image_root, instances_meta.json_file
|
169 |
+
|
170 |
+
register_coco_panoptic_annos_sem_seg(
|
171 |
+
prefix,
|
172 |
+
get_metadata(),
|
173 |
+
image_root,
|
174 |
+
os.path.join(root, panoptic_root),
|
175 |
+
os.path.join(root, panoptic_json),
|
176 |
+
os.path.join(root, semantic_root),
|
177 |
+
instances_json,
|
178 |
+
)
|
179 |
+
|
180 |
+
|
181 |
+
_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
182 |
+
register_all_coco_panoptic_annos_sem_seg(_root)
|
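Because register_all_coco_panoptic_annos_sem_seg runs at import time, the "*_panoptic_with_sem_seg" splits become available through the standard detectron2 catalogs as soon as the package is imported; a small sketch, assuming the COCO panoptic files are laid out under $DETECTRON2_DATASETS as described above:

from detectron2.data import DatasetCatalog, MetadataCatalog
import kmax_deeplab  # importing the package triggers the registration in this file

dicts = DatasetCatalog.get("coco_2017_val_panoptic_with_sem_seg")
meta = MetadataCatalog.get("coco_2017_val_panoptic_with_sem_seg")
print(len(dicts))                     # number of annotated images
print(dicts[0]["pan_seg_file_name"])  # per-image panoptic PNG path
print(len(meta.thing_classes), len(meta.stuff_classes))  # 80 thing names, 133 thing+stuff names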
kmax_deeplab/evaluation/__init__.py
ADDED
File without changes
|
kmax_deeplab/evaluation/instance_evaluation.py
ADDED
@@ -0,0 +1,107 @@
1 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/evaluation/instance_evaluation.py
|
2 |
+
import contextlib
|
3 |
+
import copy
|
4 |
+
import io
|
5 |
+
import itertools
|
6 |
+
import json
|
7 |
+
import logging
|
8 |
+
import numpy as np
|
9 |
+
import os
|
10 |
+
import pickle
|
11 |
+
from collections import OrderedDict
|
12 |
+
import pycocotools.mask as mask_util
|
13 |
+
import torch
|
14 |
+
from pycocotools.coco import COCO
|
15 |
+
from pycocotools.cocoeval import COCOeval
|
16 |
+
from tabulate import tabulate
|
17 |
+
|
18 |
+
import detectron2.utils.comm as comm
|
19 |
+
from detectron2.config import CfgNode
|
20 |
+
from detectron2.data import MetadataCatalog
|
21 |
+
from detectron2.data.datasets.coco import convert_to_coco_json
|
22 |
+
from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco
|
23 |
+
from detectron2.evaluation.fast_eval_api import COCOeval_opt
|
24 |
+
from detectron2.structures import Boxes, BoxMode, pairwise_iou
|
25 |
+
from detectron2.utils.file_io import PathManager
|
26 |
+
from detectron2.utils.logger import create_small_table
|
27 |
+
|
28 |
+
|
29 |
+
# modified from COCOEvaluator for instance segmentation
|
30 |
+
class InstanceSegEvaluator(COCOEvaluator):
|
31 |
+
"""
|
32 |
+
Evaluate AR for object proposals, AP for instance detection/segmentation, AP
|
33 |
+
for keypoint detection outputs using COCO's metrics.
|
34 |
+
See http://cocodataset.org/#detection-eval and
|
35 |
+
http://cocodataset.org/#keypoints-eval to understand its metrics.
|
36 |
+
The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
|
37 |
+
the metric cannot be computed (e.g. due to no predictions made).
|
38 |
+
|
39 |
+
In addition to COCO, this evaluator is able to support any bounding box detection,
|
40 |
+
instance segmentation, or keypoint detection dataset.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def _eval_predictions(self, predictions, img_ids=None):
|
44 |
+
"""
|
45 |
+
Evaluate predictions. Fill self._results with the metrics of the tasks.
|
46 |
+
"""
|
47 |
+
self._logger.info("Preparing results for COCO format ...")
|
48 |
+
coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
|
49 |
+
tasks = self._tasks or self._tasks_from_predictions(coco_results)
|
50 |
+
|
51 |
+
# unmap the category ids for COCO
|
52 |
+
if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
|
53 |
+
dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
|
54 |
+
# all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
|
55 |
+
# num_classes = len(all_contiguous_ids)
|
56 |
+
# assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
|
57 |
+
|
58 |
+
reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
|
59 |
+
for result in coco_results:
|
60 |
+
category_id = result["category_id"]
|
61 |
+
# assert category_id < num_classes, (
|
62 |
+
# f"A prediction has class={category_id}, "
|
63 |
+
# f"but the dataset only has {num_classes} classes and "
|
64 |
+
# f"predicted class id should be in [0, {num_classes - 1}]."
|
65 |
+
# )
|
66 |
+
assert category_id in reverse_id_mapping, (
|
67 |
+
f"A prediction has class={category_id}, "
|
68 |
+
f"but the dataset only has class ids in {dataset_id_to_contiguous_id}."
|
69 |
+
)
|
70 |
+
result["category_id"] = reverse_id_mapping[category_id]
|
71 |
+
|
72 |
+
if self._output_dir:
|
73 |
+
file_path = os.path.join(self._output_dir, "coco_instances_results.json")
|
74 |
+
self._logger.info("Saving results to {}".format(file_path))
|
75 |
+
with PathManager.open(file_path, "w") as f:
|
76 |
+
f.write(json.dumps(coco_results))
|
77 |
+
f.flush()
|
78 |
+
|
79 |
+
if not self._do_evaluation:
|
80 |
+
self._logger.info("Annotations are not available for evaluation.")
|
81 |
+
return
|
82 |
+
|
83 |
+
self._logger.info(
|
84 |
+
"Evaluating predictions with {} COCO API...".format(
|
85 |
+
"unofficial" if self._use_fast_impl else "official"
|
86 |
+
)
|
87 |
+
)
|
88 |
+
for task in sorted(tasks):
|
89 |
+
assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
|
90 |
+
coco_eval = (
|
91 |
+
_evaluate_predictions_on_coco(
|
92 |
+
self._coco_api,
|
93 |
+
coco_results,
|
94 |
+
task,
|
95 |
+
kpt_oks_sigmas=self._kpt_oks_sigmas,
|
96 |
+
use_fast_impl=self._use_fast_impl,
|
97 |
+
img_ids=img_ids,
|
98 |
+
max_dets_per_image=self._max_dets_per_image,
|
99 |
+
)
|
100 |
+
if len(coco_results) > 0
|
101 |
+
else None # cocoapi does not handle empty results very well
|
102 |
+
)
|
103 |
+
|
104 |
+
res = self._derive_coco_results(
|
105 |
+
coco_eval, task, class_names=self._metadata.get("thing_classes")
|
106 |
+
)
|
107 |
+
self._results[task] = res
|
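A hedged sketch of where InstanceSegEvaluator would plug in, e.g. a build_evaluator helper that combines it with panoptic evaluation depending on the TEST flags defined in config.py; everything other than InstanceSegEvaluator is a standard detectron2 utility, and the exact wiring in train_net.py may differ:

from detectron2.evaluation import COCOPanopticEvaluator, DatasetEvaluators
from kmax_deeplab import InstanceSegEvaluator

def build_evaluator(cfg, dataset_name, output_folder):
    # Sketch only: pick evaluators according to the enabled inference heads.
    evaluators = []
    if cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON:
        evaluators.append(COCOPanopticEvaluator(dataset_name, output_folder))
    if cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON:
        evaluators.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
    return DatasetEvaluators(evaluators)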
kmax_deeplab/evaluation/panoptic_evaluation.py
ADDED
@@ -0,0 +1,269 @@
1 |
+
# Reference: https://github.com/cocodataset/panopticapi/blob/master/panopticapi/evaluation.py
|
2 |
+
# Reference: https://github.com/open-mmlab/mmdetection/pull/7538
|
3 |
+
|
4 |
+
#!/usr/bin/env python
|
5 |
+
from __future__ import absolute_import
|
6 |
+
from __future__ import division
|
7 |
+
from __future__ import print_function
|
8 |
+
from __future__ import unicode_literals
|
9 |
+
import os, sys
|
10 |
+
import numpy as np
|
11 |
+
import json
|
12 |
+
import time
|
13 |
+
from datetime import timedelta
|
14 |
+
from collections import defaultdict
|
15 |
+
import argparse
|
16 |
+
import multiprocessing
|
17 |
+
|
18 |
+
import PIL.Image as Image
|
19 |
+
|
20 |
+
from panopticapi.utils import get_traceback, rgb2id
|
21 |
+
|
22 |
+
OFFSET = 256 * 256 * 256
|
23 |
+
VOID = 0
|
24 |
+
|
25 |
+
class PQStatCat():
|
26 |
+
def __init__(self):
|
27 |
+
self.iou = 0.0
|
28 |
+
self.tp = 0
|
29 |
+
self.fp = 0
|
30 |
+
self.fn = 0
|
31 |
+
|
32 |
+
def __iadd__(self, pq_stat_cat):
|
33 |
+
self.iou += pq_stat_cat.iou
|
34 |
+
self.tp += pq_stat_cat.tp
|
35 |
+
self.fp += pq_stat_cat.fp
|
36 |
+
self.fn += pq_stat_cat.fn
|
37 |
+
return self
|
38 |
+
|
39 |
+
|
40 |
+
class PQStat():
|
41 |
+
def __init__(self):
|
42 |
+
self.pq_per_cat = defaultdict(PQStatCat)
|
43 |
+
|
44 |
+
def __getitem__(self, i):
|
45 |
+
return self.pq_per_cat[i]
|
46 |
+
|
47 |
+
def __iadd__(self, pq_stat):
|
48 |
+
for label, pq_stat_cat in pq_stat.pq_per_cat.items():
|
49 |
+
self.pq_per_cat[label] += pq_stat_cat
|
50 |
+
return self
|
51 |
+
|
52 |
+
def pq_average(self, categories, isthing):
|
53 |
+
pq, sq, rq, n = 0, 0, 0, 0
|
54 |
+
per_class_results = {}
|
55 |
+
for label, label_info in categories.items():
|
56 |
+
if isthing is not None:
|
57 |
+
cat_isthing = label_info['isthing'] == 1
|
58 |
+
if isthing != cat_isthing:
|
59 |
+
continue
|
60 |
+
iou = self.pq_per_cat[label].iou
|
61 |
+
tp = self.pq_per_cat[label].tp
|
62 |
+
fp = self.pq_per_cat[label].fp
|
63 |
+
fn = self.pq_per_cat[label].fn
|
64 |
+
if tp + fp + fn == 0:
|
65 |
+
per_class_results[label] = {'pq': 0.0, 'sq': 0.0, 'rq': 0.0}
|
66 |
+
continue
|
67 |
+
n += 1
|
68 |
+
pq_class = iou / (tp + 0.5 * fp + 0.5 * fn)
|
69 |
+
sq_class = iou / tp if tp != 0 else 0
|
70 |
+
rq_class = tp / (tp + 0.5 * fp + 0.5 * fn)
|
71 |
+
per_class_results[label] = {'pq': pq_class, 'sq': sq_class, 'rq': rq_class}
|
72 |
+
pq += pq_class
|
73 |
+
sq += sq_class
|
74 |
+
rq += rq_class
|
75 |
+
|
76 |
+
return {'pq': pq / n, 'sq': sq / n, 'rq': rq / n, 'n': n}, per_class_results
|
77 |
+
|
78 |
+
|
79 |
+
@get_traceback
|
80 |
+
def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, categories):
|
81 |
+
pq_stat = PQStat()
|
82 |
+
|
83 |
+
idx = 0
|
84 |
+
for gt_ann, pred_ann in annotation_set:
|
85 |
+
if idx % 100 == 0:
|
86 |
+
print('Core: {}, {} from {} images processed'.format(proc_id, idx, len(annotation_set)))
|
87 |
+
idx += 1
|
88 |
+
|
89 |
+
pan_gt = np.array(Image.open(os.path.join(gt_folder, gt_ann['file_name'])), dtype=np.uint32)
|
90 |
+
pan_gt = rgb2id(pan_gt)
|
91 |
+
pan_pred = np.array(Image.open(os.path.join(pred_folder, pred_ann['file_name'])), dtype=np.uint32)
|
92 |
+
pan_pred = rgb2id(pan_pred)
|
93 |
+
|
94 |
+
gt_segms = {el['id']: el for el in gt_ann['segments_info']}
|
95 |
+
pred_segms = {el['id']: el for el in pred_ann['segments_info']}
|
96 |
+
|
97 |
+
# predicted segments area calculation + prediction sanity checks
|
98 |
+
pred_labels_set = set(el['id'] for el in pred_ann['segments_info'])
|
99 |
+
labels, labels_cnt = np.unique(pan_pred, return_counts=True)
|
100 |
+
for label, label_cnt in zip(labels, labels_cnt):
|
101 |
+
if label not in pred_segms:
|
102 |
+
if label == VOID:
|
103 |
+
continue
|
104 |
+
raise KeyError('In the image with ID {} segment with ID {} is presented in PNG and not presented in JSON.'.format(gt_ann['image_id'], label))
|
105 |
+
pred_segms[label]['area'] = label_cnt
|
106 |
+
pred_labels_set.remove(label)
|
107 |
+
if pred_segms[label]['category_id'] not in categories:
|
108 |
+
raise KeyError('In the image with ID {} segment with ID {} has unknown category_id {}.'.format(gt_ann['image_id'], label, pred_segms[label]['category_id']))
|
109 |
+
if len(pred_labels_set) != 0:
|
110 |
+
raise KeyError('In the image with ID {} the following segment IDs {} are presented in JSON and not presented in PNG.'.format(gt_ann['image_id'], list(pred_labels_set)))
|
111 |
+
|
112 |
+
# confusion matrix calculation
|
113 |
+
pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype(np.uint64)
|
114 |
+
gt_pred_map = {}
|
115 |
+
labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
|
116 |
+
for label, intersection in zip(labels, labels_cnt):
|
117 |
+
gt_id = label // OFFSET
|
118 |
+
pred_id = label % OFFSET
|
119 |
+
gt_pred_map[(gt_id, pred_id)] = intersection
|
120 |
+
|
121 |
+
# count all matched pairs
|
122 |
+
gt_matched = set()
|
123 |
+
pred_matched = set()
|
124 |
+
for label_tuple, intersection in gt_pred_map.items():
|
125 |
+
gt_label, pred_label = label_tuple
|
126 |
+
if gt_label not in gt_segms:
|
127 |
+
continue
|
128 |
+
if pred_label not in pred_segms:
|
129 |
+
continue
|
130 |
+
if gt_segms[gt_label]['iscrowd'] == 1:
|
131 |
+
continue
|
132 |
+
if gt_segms[gt_label]['category_id'] != pred_segms[pred_label]['category_id']:
|
133 |
+
continue
|
134 |
+
|
135 |
+
union = pred_segms[pred_label]['area'] + gt_segms[gt_label]['area'] - intersection - gt_pred_map.get((VOID, pred_label), 0)
|
136 |
+
iou = intersection / union
|
137 |
+
if iou > 0.5:
|
138 |
+
pq_stat[gt_segms[gt_label]['category_id']].tp += 1
|
139 |
+
pq_stat[gt_segms[gt_label]['category_id']].iou += iou
|
140 |
+
gt_matched.add(gt_label)
|
141 |
+
pred_matched.add(pred_label)
|
142 |
+
|
143 |
+
# count false negatives
|
144 |
+
crowd_labels_dict = {}
|
145 |
+
for gt_label, gt_info in gt_segms.items():
|
146 |
+
if gt_label in gt_matched:
|
147 |
+
continue
|
148 |
+
# crowd segments are ignored
|
149 |
+
if gt_info['iscrowd'] == 1:
|
150 |
+
crowd_labels_dict[gt_info['category_id']] = gt_label
|
151 |
+
continue
|
152 |
+
pq_stat[gt_info['category_id']].fn += 1
|
153 |
+
|
154 |
+
# count false positives
|
155 |
+
for pred_label, pred_info in pred_segms.items():
|
156 |
+
if pred_label in pred_matched:
|
157 |
+
continue
|
158 |
+
# intersection of the segment with VOID
|
159 |
+
intersection = gt_pred_map.get((VOID, pred_label), 0)
|
160 |
+
# plus intersection with corresponding CROWD region if it exists
|
161 |
+
if pred_info['category_id'] in crowd_labels_dict:
|
162 |
+
intersection += gt_pred_map.get((crowd_labels_dict[pred_info['category_id']], pred_label), 0)
|
163 |
+
# predicted segment is ignored if more than half of the segment correspond to VOID and CROWD regions
|
164 |
+
if intersection / pred_info['area'] > 0.5:
|
165 |
+
continue
|
166 |
+
pq_stat[pred_info['category_id']].fp += 1
|
167 |
+
print('Core: {}, all {} images processed'.format(proc_id, len(annotation_set)))
|
168 |
+
return pq_stat
|
169 |
+
|
170 |
+
|
171 |
+
def pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories):
|
172 |
+
cpu_num = multiprocessing.cpu_count()
|
173 |
+
annotations_split = np.array_split(matched_annotations_list, cpu_num)
|
174 |
+
print("Number of cores: {}, images per core: {}".format(cpu_num, len(annotations_split[0])))
|
175 |
+
workers = multiprocessing.Pool(processes=cpu_num)
|
176 |
+
processes = []
|
177 |
+
for proc_id, annotation_set in enumerate(annotations_split):
|
178 |
+
p = workers.apply_async(pq_compute_single_core,
|
179 |
+
(proc_id, annotation_set, gt_folder, pred_folder, categories))
|
180 |
+
processes.append(p)
|
181 |
+
|
182 |
+
# https://github.com/open-mmlab/mmdetection/pull/7538
|
183 |
+
# Close the process pool, otherwise it will lead to memory
|
184 |
+
# leaking problems.
|
185 |
+
workers.close()
|
186 |
+
workers.join()
|
187 |
+
|
188 |
+
|
189 |
+
pq_stat = PQStat()
|
190 |
+
for p in processes:
|
191 |
+
pq_stat += p.get()
|
192 |
+
return pq_stat
|
193 |
+
|
194 |
+
|
195 |
+
def pq_compute(gt_json_file, pred_json_file, gt_folder=None, pred_folder=None):
|
196 |
+
|
197 |
+
start_time = time.time()
|
198 |
+
with open(gt_json_file, 'r') as f:
|
199 |
+
gt_json = json.load(f)
|
200 |
+
with open(pred_json_file, 'r') as f:
|
201 |
+
pred_json = json.load(f)
|
202 |
+
|
203 |
+
if gt_folder is None:
|
204 |
+
gt_folder = gt_json_file.replace('.json', '')
|
205 |
+
if pred_folder is None:
|
206 |
+
pred_folder = pred_json_file.replace('.json', '')
|
207 |
+
categories = {el['id']: el for el in gt_json['categories']}
|
208 |
+
|
209 |
+
print("Evaluation panoptic segmentation metrics:")
|
210 |
+
print("Ground truth:")
|
211 |
+
print("\tSegmentation folder: {}".format(gt_folder))
|
212 |
+
print("\tJSON file: {}".format(gt_json_file))
|
213 |
+
print("Prediction:")
|
214 |
+
print("\tSegmentation folder: {}".format(pred_folder))
|
215 |
+
print("\tJSON file: {}".format(pred_json_file))
|
216 |
+
|
217 |
+
if not os.path.isdir(gt_folder):
|
218 |
+
raise Exception("Folder {} with ground truth segmentations doesn't exist".format(gt_folder))
|
219 |
+
if not os.path.isdir(pred_folder):
|
220 |
+
raise Exception("Folder {} with predicted segmentations doesn't exist".format(pred_folder))
|
221 |
+
|
222 |
+
pred_annotations = {el['image_id']: el for el in pred_json['annotations']}
|
223 |
+
matched_annotations_list = []
|
224 |
+
for gt_ann in gt_json['annotations']:
|
225 |
+
image_id = gt_ann['image_id']
|
226 |
+
if image_id not in pred_annotations:
|
227 |
+
raise Exception('no prediction for the image with id: {}'.format(image_id))
|
228 |
+
matched_annotations_list.append((gt_ann, pred_annotations[image_id]))
|
229 |
+
|
230 |
+
pq_stat = pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories)
|
231 |
+
|
232 |
+
metrics = [("All", None), ("Things", True), ("Stuff", False)]
|
233 |
+
results = {}
|
234 |
+
for name, isthing in metrics:
|
235 |
+
results[name], per_class_results = pq_stat.pq_average(categories, isthing=isthing)
|
236 |
+
if name == 'All':
|
237 |
+
results['per_class'] = per_class_results
|
238 |
+
print("{:10s}| {:>5s} {:>5s} {:>5s} {:>5s}".format("", "PQ", "SQ", "RQ", "N"))
|
239 |
+
print("-" * (10 + 7 * 4))
|
240 |
+
|
241 |
+
for name, _isthing in metrics:
|
242 |
+
print("{:10s}| {:5.1f} {:5.1f} {:5.1f} {:5d}".format(
|
243 |
+
name,
|
244 |
+
100 * results[name]['pq'],
|
245 |
+
100 * results[name]['sq'],
|
246 |
+
100 * results[name]['rq'],
|
247 |
+
results[name]['n'])
|
248 |
+
)
|
249 |
+
|
250 |
+
t_delta = time.time() - start_time
|
251 |
+
print("Time elapsed: {:0.2f} seconds".format(t_delta))
|
252 |
+
|
253 |
+
return results
|
254 |
+
|
255 |
+
|
256 |
+
if __name__ == "__main__":
|
257 |
+
parser = argparse.ArgumentParser()
|
258 |
+
parser.add_argument('--gt_json_file', type=str,
|
259 |
+
help="JSON file with ground truth data")
|
260 |
+
parser.add_argument('--pred_json_file', type=str,
|
261 |
+
help="JSON file with predictions data")
|
262 |
+
parser.add_argument('--gt_folder', type=str, default=None,
|
263 |
+
help="Folder with ground turth COCO format segmentations. \
|
264 |
+
Default: X if the corresponding json file is X.json")
|
265 |
+
parser.add_argument('--pred_folder', type=str, default=None,
|
266 |
+
help="Folder with prediction COCO format segmentations. \
|
267 |
+
Default: X if the corresponding json file is X.json")
|
268 |
+
args = parser.parse_args()
|
269 |
+
pq_compute(args.gt_json_file, args.pred_json_file, args.gt_folder, args.pred_folder)
|
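As a sanity check on pq_average above: PQ = sum(IoU) / (TP + 0.5 FP + 0.5 FN), SQ = sum(IoU) / TP and RQ = TP / (TP + 0.5 FP + 0.5 FN), so PQ = SQ * RQ. A small worked example with made-up counts for a single thing category, using the PQStat class defined in this file:

stat = PQStat()
stat[1].tp = 8            # 8 matched segments ...
stat[1].iou = 8 * 0.75    # ... with an average IoU of 0.75 (summed IoU over TPs)
stat[1].fp = 2
stat[1].fn = 2

categories = {1: {'isthing': 1}}
results, per_class = stat.pq_average(categories, isthing=True)
# RQ = 8 / (8 + 1 + 1) = 0.8, SQ = 6.0 / 8 = 0.75, PQ = 0.75 * 0.8 = 0.6
print(results)  # {'pq': 0.6, 'sq': 0.75, 'rq': 0.8, 'n': 1}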
kmax_deeplab/kmax_model.py
ADDED
@@ -0,0 +1,446 @@
1 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/maskformer_model.py
|
2 |
+
# Reference: https://github.com/google-research/deeplab2/blob/main/model/kmax_deeplab.py
|
3 |
+
# Reference: https://github.com/google-research/deeplab2/blob/main/model/post_processor/max_deeplab.py
|
4 |
+
# Modified by Qihang Yu
|
5 |
+
|
6 |
+
from typing import Tuple, List
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
from detectron2.config import configurable
|
13 |
+
from detectron2.data import MetadataCatalog
|
14 |
+
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
|
15 |
+
from detectron2.modeling.backbone import Backbone
|
16 |
+
from detectron2.modeling.postprocessing import sem_seg_postprocess
|
17 |
+
from detectron2.structures import Boxes, ImageList, Instances
|
18 |
+
from detectron2.utils.memory import retry_if_cuda_oom
|
19 |
+
|
20 |
+
from .modeling.criterion import SetCriterion
|
21 |
+
from .modeling.matcher import HungarianMatcher
|
22 |
+
from torch.cuda.amp import autocast
|
23 |
+
|
24 |
+
|
25 |
+
@META_ARCH_REGISTRY.register()
|
26 |
+
class kMaXDeepLab(nn.Module):
|
27 |
+
"""
|
28 |
+
Main class for mask classification semantic segmentation architectures.
|
29 |
+
"""
|
30 |
+
|
31 |
+
@configurable
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
*,
|
35 |
+
backbone: Backbone,
|
36 |
+
sem_seg_head: nn.Module,
|
37 |
+
criterion: nn.Module,
|
38 |
+
num_queries: int,
|
39 |
+
object_mask_threshold: float,
|
40 |
+
class_threshold_thing: float,
|
41 |
+
class_threshold_stuff: float,
|
42 |
+
overlap_threshold: float,
|
43 |
+
reorder_class_weight: float,
|
44 |
+
reorder_mask_weight: float,
|
45 |
+
metadata,
|
46 |
+
size_divisibility: int,
|
47 |
+
sem_seg_postprocess_before_inference: bool,
|
48 |
+
pixel_mean: Tuple[float],
|
49 |
+
pixel_std: Tuple[float],
|
50 |
+
# inference
|
51 |
+
semantic_on: bool,
|
52 |
+
panoptic_on: bool,
|
53 |
+
instance_on: bool,
|
54 |
+
test_topk_per_image: int,
|
55 |
+
input_shape: List[int]
|
56 |
+
):
|
57 |
+
"""
|
58 |
+
Args:
|
59 |
+
backbone: a backbone module, must follow detectron2's backbone interface
|
60 |
+
sem_seg_head: a module that predicts semantic segmentation from backbone features
|
61 |
+
criterion: a module that defines the loss
|
62 |
+
num_queries: int, number of queries
|
63 |
+
object_mask_threshold: float, threshold to filter query based on classification score
|
64 |
+
for panoptic segmentation inference
|
65 |
+
overlap_threshold: overlap threshold used in general inference for panoptic segmentation
|
66 |
+
metadata: dataset meta, get `thing` and `stuff` category names for panoptic
|
67 |
+
segmentation inference
|
68 |
+
size_divisibility: Some backbones require the input height and width to be divisible by a
|
69 |
+
specific integer. We can use this to override such requirement.
|
70 |
+
sem_seg_postprocess_before_inference: whether to resize the prediction back
|
71 |
+
to original input size before semantic segmentation inference or after.
|
72 |
+
For high-resolution dataset like Mapillary, resizing predictions before
|
73 |
+
inference will cause OOM error.
|
74 |
+
pixel_mean, pixel_std: list or tuple with #channels element, representing
|
75 |
+
the per-channel mean and std to be used to normalize the input image
|
76 |
+
semantic_on: bool, whether to output semantic segmentation prediction
|
77 |
+
instance_on: bool, whether to output instance segmentation prediction
|
78 |
+
panoptic_on: bool, whether to output panoptic segmentation prediction
|
79 |
+
test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
|
80 |
+
"""
|
81 |
+
super().__init__()
|
82 |
+
self.backbone = backbone
|
83 |
+
self.sem_seg_head = sem_seg_head
|
84 |
+
self.criterion = criterion
|
85 |
+
self.num_queries = num_queries
|
86 |
+
self.overlap_threshold = overlap_threshold
|
87 |
+
self.object_mask_threshold = object_mask_threshold
|
88 |
+
self.class_threshold_thing = class_threshold_thing
|
89 |
+
self.class_threshold_stuff = class_threshold_stuff
|
90 |
+
self.reorder_class_weight = reorder_class_weight
|
91 |
+
self.reorder_mask_weight = reorder_mask_weight
|
92 |
+
self.metadata = metadata
|
93 |
+
if size_divisibility < 0:
|
94 |
+
# use backbone size_divisibility if not set
|
95 |
+
size_divisibility = self.backbone.size_divisibility
|
96 |
+
self.size_divisibility = size_divisibility
|
97 |
+
self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
|
98 |
+
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
99 |
+
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
100 |
+
|
101 |
+
# additional args
|
102 |
+
self.semantic_on = semantic_on
|
103 |
+
self.instance_on = instance_on
|
104 |
+
self.panoptic_on = panoptic_on
|
105 |
+
self.test_topk_per_image = test_topk_per_image
|
106 |
+
|
107 |
+
if not self.semantic_on:
|
108 |
+
assert self.sem_seg_postprocess_before_inference
|
109 |
+
|
110 |
+
self.input_shape = input_shape
|
111 |
+
|
112 |
+
@classmethod
|
113 |
+
def from_config(cls, cfg):
|
114 |
+
backbone = build_backbone(cfg)
|
115 |
+
sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
|
116 |
+
|
117 |
+
# Loss parameters:
|
118 |
+
deep_supervision = cfg.MODEL.KMAX_DEEPLAB.DEEP_SUPERVISION
|
119 |
+
no_object_weight = cfg.MODEL.KMAX_DEEPLAB.NO_OBJECT_WEIGHT
|
120 |
+
share_final_matching = cfg.MODEL.KMAX_DEEPLAB.SHARE_FINAL_MATCHING
|
121 |
+
|
122 |
+
# loss weights
|
123 |
+
class_weight = cfg.MODEL.KMAX_DEEPLAB.CLASS_WEIGHT
|
124 |
+
dice_weight = cfg.MODEL.KMAX_DEEPLAB.DICE_WEIGHT
|
125 |
+
mask_weight = cfg.MODEL.KMAX_DEEPLAB.MASK_WEIGHT
|
126 |
+
insdis_weight = cfg.MODEL.KMAX_DEEPLAB.INSDIS_WEIGHT
|
127 |
+
aux_semantic_weight = cfg.MODEL.KMAX_DEEPLAB.AUX_SEMANTIC_WEIGHT
|
128 |
+
|
129 |
+
# building criterion
|
130 |
+
matcher = HungarianMatcher()
|
131 |
+
|
132 |
+
weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight,
|
133 |
+
"loss_pixel_insdis": insdis_weight, "loss_aux_semantic": aux_semantic_weight}
|
134 |
+
|
135 |
+
if deep_supervision:
|
136 |
+
dec_layers = sum(cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.DEC_LAYERS)
|
137 |
+
aux_weight_dict = {}
|
138 |
+
for i in range(dec_layers):
|
139 |
+
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
140 |
+
weight_dict.update(aux_weight_dict)
|
141 |
+
|
142 |
+
losses = ["labels", "masks"]
|
143 |
+
if insdis_weight > 0:
|
144 |
+
losses += ["pixels"]
|
145 |
+
if aux_semantic_weight > 0:
|
146 |
+
losses += ["aux_semantic"]
|
147 |
+
|
148 |
+
criterion = SetCriterion(
|
149 |
+
sem_seg_head.num_classes,
|
150 |
+
matcher=matcher,
|
151 |
+
weight_dict=weight_dict,
|
152 |
+
eos_coef=no_object_weight,
|
153 |
+
losses=losses,
|
154 |
+
share_final_matching=share_final_matching,
|
155 |
+
pixel_insdis_temperature=cfg.MODEL.KMAX_DEEPLAB.PIXEL_INSDIS_TEMPERATURE,
|
156 |
+
pixel_insdis_sample_k=cfg.MODEL.KMAX_DEEPLAB.PIXEL_INSDIS_SAMPLE_K,
|
157 |
+
aux_semantic_temperature=cfg.MODEL.KMAX_DEEPLAB.AUX_SEMANTIC_TEMPERATURE,
|
158 |
+
aux_semantic_sample_k=cfg.MODEL.KMAX_DEEPLAB.UX_SEMANTIC_SAMPLE_K
|
159 |
+
)
|
160 |
+
|
161 |
+
return {
|
162 |
+
"backbone": backbone,
|
163 |
+
"sem_seg_head": sem_seg_head,
|
164 |
+
"criterion": criterion,
|
165 |
+
"num_queries": cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NUM_OBJECT_QUERIES,
|
166 |
+
"object_mask_threshold": cfg.MODEL.KMAX_DEEPLAB.TEST.OBJECT_MASK_THRESHOLD,
|
167 |
+
"class_threshold_thing": cfg.MODEL.KMAX_DEEPLAB.TEST.CLASS_THRESHOLD_THING,
|
168 |
+
"class_threshold_stuff": cfg.MODEL.KMAX_DEEPLAB.TEST.CLASS_THRESHOLD_STUFF,
|
169 |
+
"overlap_threshold": cfg.MODEL.KMAX_DEEPLAB.TEST.OVERLAP_THRESHOLD,
|
170 |
+
"reorder_class_weight": cfg.MODEL.KMAX_DEEPLAB.TEST.REORDER_CLASS_WEIGHT,
|
171 |
+
"reorder_mask_weight": cfg.MODEL.KMAX_DEEPLAB.TEST.REORDER_MASK_WEIGHT,
|
172 |
+
"metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
|
173 |
+
"size_divisibility": cfg.MODEL.KMAX_DEEPLAB.SIZE_DIVISIBILITY,
|
174 |
+
"sem_seg_postprocess_before_inference": (
|
175 |
+
cfg.MODEL.KMAX_DEEPLAB.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
|
176 |
+
or cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON
|
177 |
+
or cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON
|
178 |
+
),
|
179 |
+
"pixel_mean": cfg.MODEL.PIXEL_MEAN,
|
180 |
+
"pixel_std": cfg.MODEL.PIXEL_STD,
|
181 |
+
# inference
|
182 |
+
"semantic_on": cfg.MODEL.KMAX_DEEPLAB.TEST.SEMANTIC_ON,
|
183 |
+
"instance_on": cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON,
|
184 |
+
"panoptic_on": cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON,
|
185 |
+
"test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
|
186 |
+
"input_shape": cfg.INPUT.IMAGE_SIZE
|
187 |
+
}
|
188 |
+
|
189 |
+
@property
|
190 |
+
def device(self):
|
191 |
+
return self.pixel_mean.device
|
192 |
+
|
193 |
+
def forward(self, batched_inputs):
|
194 |
+
"""
|
195 |
+
Args:
|
196 |
+
batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
|
197 |
+
Each item in the list contains the inputs for one image.
|
198 |
+
For now, each item in the list is a dict that contains:
|
199 |
+
* "image": Tensor, image in (C, H, W) format.
|
200 |
+
* "instances": per-region ground truth
|
201 |
+
* Other information that's included in the original dicts, such as:
|
202 |
+
"height", "width" (int): the output resolution of the model (may be different
|
203 |
+
from input resolution), used in inference.
|
204 |
+
Returns:
|
205 |
+
list[dict]:
|
206 |
+
each dict has the results for one image. The dict contains the following keys:
|
207 |
+
|
208 |
+
* "sem_seg":
|
209 |
+
A Tensor that represents the
|
210 |
+
per-pixel segmentation prediced by the head.
|
211 |
+
The prediction has shape KxHxW that represents the logits of
|
212 |
+
each class for each pixel.
|
213 |
+
* "panoptic_seg":
|
214 |
+
A tuple that represent panoptic output
|
215 |
+
panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
|
216 |
+
segments_info (list[dict]): Describe each segment in `panoptic_seg`.
|
217 |
+
Each dict contains keys "id", "category_id", "isthing".
|
218 |
+
"""
|
219 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
220 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
221 |
+
if "is_real_pixels" in batched_inputs[0]:
|
222 |
+
is_real_pixels = [x["is_real_pixels"] for x in batched_inputs]
|
223 |
+
# Set all padded pixel values to 0.
|
224 |
+
images = [x * y.to(x) for x, y in zip(images, is_real_pixels)]
|
225 |
+
|
226 |
+
# We perform zero padding to ensure input shape equal to self.input_shape.
|
227 |
+
# The padding is done on the right and bottom sides.
|
228 |
+
for idx in range(len(images)):
|
229 |
+
cur_height, cur_width = images[idx].shape[-2:]
|
230 |
+
padding = (0, max(0, self.input_shape[1] - cur_width), 0, max(0, self.input_shape[0] - cur_height), 0, 0)
|
231 |
+
images[idx] = F.pad(images[idx], padding, value=0)
|
232 |
+
images = ImageList.from_tensors(images, -1)
|
233 |
+
|
234 |
+
if self.training:
|
235 |
+
# mask classification target
|
236 |
+
if "instances" in batched_inputs[0]:
|
237 |
+
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
|
238 |
+
gt_semantic = [x["sem_seg_gt"].to(self.device) for x in batched_inputs]
|
239 |
+
targets = self.prepare_targets(gt_instances, gt_semantic, images)
|
240 |
+
else:
|
241 |
+
targets = None
|
242 |
+
|
243 |
+
features = self.backbone(images.tensor)
|
244 |
+
outputs = self.sem_seg_head(features)
|
245 |
+
|
246 |
+
if self.training:
|
247 |
+
|
248 |
+
with autocast(enabled=False):
|
249 |
+
# bipartite matching-based loss
|
250 |
+
for output_key in ["pixel_feature", "pred_masks", "pred_logits", "aux_semantic_pred"]:
|
251 |
+
if output_key in outputs:
|
252 |
+
outputs[output_key] = outputs[output_key].float()
|
253 |
+
for i in range(len(outputs["aux_outputs"])):
|
254 |
+
for output_key in ["pixel_feature", "pred_masks", "pred_logits"]:
|
255 |
+
outputs["aux_outputs"][i][output_key] = outputs["aux_outputs"][i][output_key].float()
|
256 |
+
|
257 |
+
losses = self.criterion(outputs, targets)
|
258 |
+
|
259 |
+
for k in list(losses.keys()):
|
260 |
+
if k in self.criterion.weight_dict:
|
261 |
+
losses[k] *= self.criterion.weight_dict[k]
|
262 |
+
else:
|
263 |
+
# remove this loss if not specified in `weight_dict`
|
264 |
+
losses.pop(k)
|
265 |
+
return losses
|
266 |
+
else:
|
267 |
+
mask_cls_results = outputs["pred_logits"]
|
268 |
+
mask_pred_results = outputs["pred_masks"]
|
269 |
+
|
270 |
+
align_corners = (images.tensor.shape[-1] % 2 == 1)
|
271 |
+
# upsample masks
|
272 |
+
mask_pred_results = F.interpolate(
|
273 |
+
mask_pred_results,
|
274 |
+
size=(images.tensor.shape[-2], images.tensor.shape[-1]),
|
275 |
+
mode="bilinear",
|
276 |
+
align_corners=align_corners,
|
277 |
+
)
|
278 |
+
|
279 |
+
del outputs
|
280 |
+
|
281 |
+
processed_results = []
|
282 |
+
for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
|
283 |
+
mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
|
284 |
+
):
|
285 |
+
height = input_per_image.get("height", image_size[0])
|
286 |
+
width = input_per_image.get("width", image_size[1])
|
287 |
+
cur_image = input_per_image["image"].to(self.device)
|
288 |
+
processed_results.append({})
|
289 |
+
scale_factor = max(images.tensor.shape[-2:]) / max(height, width)
|
290 |
+
ori_height, ori_width = round(height * scale_factor), round(width * scale_factor)
|
291 |
+
mask_pred_result = mask_pred_result[:, :ori_height, :ori_width].expand(1, -1, -1, -1)
|
292 |
+
cur_image = cur_image[:, :ori_height, :ori_width].expand(1, -1, -1, -1)
|
293 |
+
mask_pred_result = F.interpolate(
|
294 |
+
mask_pred_result, size=(height, width), mode="bilinear", align_corners=align_corners
|
295 |
+
)[0]
|
296 |
+
cur_image = F.interpolate(
|
297 |
+
cur_image.float(), size=(height, width), mode="bilinear", align_corners=align_corners
|
298 |
+
)[0].to(torch.uint8)
|
299 |
+
|
300 |
+
if self.sem_seg_postprocess_before_inference:
|
301 |
+
mask_cls_result = mask_cls_result.to(mask_pred_result)
|
302 |
+
|
303 |
+
# semantic segmentation inference
|
304 |
+
if self.semantic_on:
|
305 |
+
r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
|
306 |
+
if not self.sem_seg_postprocess_before_inference:
|
307 |
+
r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
|
308 |
+
processed_results[-1]["sem_seg"] = r
|
309 |
+
|
310 |
+
# panoptic segmentation inference
|
311 |
+
if self.panoptic_on:
|
312 |
+
panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
|
313 |
+
processed_results[-1]["panoptic_seg"] = panoptic_r
|
314 |
+
processed_results[-1]["original_image"] = cur_image
|
315 |
+
|
316 |
+
# instance segmentation inference
|
317 |
+
if self.instance_on:
|
318 |
+
instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result)
|
319 |
+
processed_results[-1]["instances"] = instance_r
|
320 |
+
|
321 |
+
return processed_results
|
322 |
+
|
323 |
+
def prepare_targets(self, targets, targets_semantic, images):
|
324 |
+
new_targets = []
|
325 |
+
for targets_per_image, semantic_gt_mask in zip(targets, targets_semantic):
|
326 |
+
gt_masks = targets_per_image.gt_masks
|
327 |
+
new_targets.append(
|
328 |
+
{
|
329 |
+
"labels": targets_per_image.gt_classes,
|
330 |
+
"masks": gt_masks,
|
331 |
+
"semantic_masks": semantic_gt_mask
|
332 |
+
}
|
333 |
+
)
|
334 |
+
return new_targets
|
335 |
+
|
336 |
+
def semantic_inference(self, mask_cls, mask_pred):
|
337 |
+
# For cls prob, we exluced the void class following
|
338 |
+
# https://github.com/google-research/deeplab2/blob/main/model/post_processor/max_deeplab.py#L199
|
339 |
+
mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
|
340 |
+
mask_pred = F.softmax(mask_pred, dim=0)
|
341 |
+
semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
|
342 |
+
return semseg
|
343 |
+
|
344 |
+
def panoptic_inference(self, mask_cls, mask_pred):
|
345 |
+
# mask_cls: N x C
|
346 |
+
# mask_pred: N x H x W
|
347 |
+
# some hyper-params
|
348 |
+
num_mask_slots = mask_pred.shape[0]
|
349 |
+
cls_threshold_thing = self.class_threshold_thing
|
350 |
+
cls_threshold_stuff = self.class_threshold_stuff
|
351 |
+
object_mask_threshold = self.object_mask_threshold
|
352 |
+
overlap_threshold = self.overlap_threshold
|
353 |
+
reorder_class_weight = self.reorder_class_weight
|
354 |
+
reorder_mask_weight = self.reorder_mask_weight
|
355 |
+
|
356 |
+
# https://github.com/google-research/deeplab2/blob/main/model/post_processor/max_deeplab.py#L675
|
357 |
+
# https://github.com/google-research/deeplab2/blob/main/model/post_processor/max_deeplab.py#L199
|
358 |
+
cls_scores, cls_labels = F.softmax(mask_cls, dim=-1)[..., :-1].max(-1) # N
|
359 |
+
mask_scores = F.softmax(mask_pred, dim=0)
|
360 |
+
binary_masks = mask_scores > object_mask_threshold # N x H x W
|
361 |
+
mask_scores_flat = mask_scores.flatten(1) # N x HW
|
362 |
+
binary_masks_flat = binary_masks.flatten(1).float() # N x HW
|
363 |
+
pixel_number_flat = binary_masks_flat.sum(1) # N
|
364 |
+
mask_scores_flat = (mask_scores_flat * binary_masks_flat).sum(1) / torch.clamp(pixel_number_flat, min=1.0) # N
|
365 |
+
|
366 |
+
reorder_score = (cls_scores ** reorder_class_weight) * (mask_scores_flat ** reorder_mask_weight) # N
|
367 |
+
reorder_indices = torch.argsort(reorder_score, dim=-1, descending=True)
|
368 |
+
|
369 |
+
panoptic_seg = torch.zeros((mask_pred.shape[1], mask_pred.shape[2]),
|
370 |
+
dtype=torch.int32, device=mask_pred.device)
|
371 |
+
segments_info = []
|
372 |
+
|
373 |
+
current_segment_id = 0
|
374 |
+
stuff_memory_list = {}
|
375 |
+
for i in range(num_mask_slots):
|
376 |
+
cur_idx = reorder_indices[i].item() # 1
|
377 |
+
cur_binary_mask = binary_masks[cur_idx] # H x W
|
378 |
+
cur_cls_score = cls_scores[cur_idx].item() # 1
|
379 |
+
cur_cls_label = cls_labels[cur_idx].item() # 1
|
380 |
+
is_thing = cur_cls_label in self.metadata.thing_dataset_id_to_contiguous_id.values()
|
381 |
+
is_confident = (is_thing and cur_cls_score > cls_threshold_thing) or (
|
382 |
+
(not is_thing) and cur_cls_score > cls_threshold_stuff)
|
383 |
+
|
384 |
+
original_pixel_number = cur_binary_mask.float().sum()
|
385 |
+
new_binary_mask = torch.logical_and(cur_binary_mask, (panoptic_seg == 0))
|
386 |
+
new_pixel_number = new_binary_mask.float().sum()
|
387 |
+
is_not_overlap_too_much = new_pixel_number > (original_pixel_number * overlap_threshold)
|
388 |
+
|
389 |
+
if is_confident and is_not_overlap_too_much:
|
390 |
+
# merge stuff regions
|
391 |
+
if not is_thing:
|
392 |
+
if int(cur_cls_label) in stuff_memory_list.keys():
|
393 |
+
panoptic_seg[new_binary_mask] = stuff_memory_list[int(cur_cls_label)]
|
394 |
+
continue
|
395 |
+
else:
|
396 |
+
stuff_memory_list[int(cur_cls_label)] = current_segment_id + 1
|
397 |
+
|
398 |
+
current_segment_id += 1
|
399 |
+
panoptic_seg[new_binary_mask] = current_segment_id
|
400 |
+
|
401 |
+
segments_info.append(
|
402 |
+
{
|
403 |
+
"id": current_segment_id,
|
404 |
+
"isthing": bool(is_thing),
|
405 |
+
"category_id": int(cur_cls_label),
|
406 |
+
}
|
407 |
+
)
|
408 |
+
|
409 |
+
return panoptic_seg, segments_info
|
410 |
+
|
411 |
+
|
412 |
+
def instance_inference(self, mask_cls, mask_pred):
|
413 |
+
# mask_pred is already processed to have the same shape as original input
|
414 |
+
image_size = mask_pred.shape[-2:]
|
415 |
+
|
416 |
+
mask_pred = mask_pred.softmax(dim=0)
|
417 |
+
# [Q, K]
|
418 |
+
scores = F.softmax(mask_cls[:, :-1], dim=-1)
|
419 |
+
labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
|
420 |
+
scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
|
421 |
+
labels_per_image = labels[topk_indices]
|
422 |
+
|
423 |
+
topk_indices = topk_indices // self.sem_seg_head.num_classes
|
424 |
+
mask_pred = mask_pred[topk_indices]
|
425 |
+
|
426 |
+
# if this is panoptic segmentation, we only keep the "thing" classes
|
427 |
+
if self.panoptic_on:
|
428 |
+
keep = torch.zeros_like(scores_per_image).bool()
|
429 |
+
for i, lab in enumerate(labels_per_image):
|
430 |
+
keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
|
431 |
+
|
432 |
+
scores_per_image = scores_per_image[keep]
|
433 |
+
labels_per_image = labels_per_image[keep]
|
434 |
+
mask_pred = mask_pred[keep]
|
435 |
+
|
436 |
+
result = Instances(image_size)
|
437 |
+
result.pred_masks = (mask_pred > self.object_mask_threshold).float()
|
438 |
+
result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
|
439 |
+
# Uncomment the following to get boxes from masks (this is slow)
|
440 |
+
# result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
|
441 |
+
|
442 |
+
# calculate average mask prob
|
443 |
+
mask_scores_per_image = (mask_pred.flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
|
444 |
+
result.scores = scores_per_image * mask_scores_per_image
|
445 |
+
result.pred_classes = labels_per_image
|
446 |
+
return result
|
kmax_deeplab/modeling/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .backbone.convnext import D2ConvNeXt
|
2 |
+
from .backbone.resnet import custom_bn_build_resnet_backbone
|
3 |
+
from .pixel_decoder.kmax_pixel_decoder import kMaXPixelDecoder
|
4 |
+
from .meta_arch.kmax_deeplab_head import kMaXDeepLabHead
|
kmax_deeplab/modeling/backbone/__init__.py
ADDED
File without changes
|
kmax_deeplab/modeling/backbone/convnext.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# reference: https://github.com/SHI-Labs/OneFormer/blob/main/oneformer/modeling/backbone/convnext.py
|
2 |
+
|
3 |
+
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from timm.models.layers import DropPath
|
10 |
+
|
11 |
+
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
|
12 |
+
from torch.cuda.amp import autocast
|
13 |
+
|
14 |
+
|
15 |
+
class Block(nn.Module):
|
16 |
+
r""" ConvNeXt Block. There are two equivalent implementations:
|
17 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
18 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
19 |
+
We use (2) as we find it slightly faster in PyTorch
|
20 |
+
|
21 |
+
Args:
|
22 |
+
dim (int): Number of input channels.
|
23 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
24 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
25 |
+
"""
|
26 |
+
def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
|
27 |
+
super().__init__()
|
28 |
+
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
29 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
30 |
+
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
|
31 |
+
self.act = nn.GELU()
|
32 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
33 |
+
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
|
34 |
+
requires_grad=True) if layer_scale_init_value > 0 else None
|
35 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
input = x
|
39 |
+
x = self.dwconv(x)
|
40 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
41 |
+
x = self.norm(x)
|
42 |
+
x = self.pwconv1(x)
|
43 |
+
x = self.act(x)
|
44 |
+
x = self.pwconv2(x)
|
45 |
+
if self.gamma is not None:
|
46 |
+
x = self.gamma * x
|
47 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
48 |
+
|
49 |
+
x = input + self.drop_path(x)
|
50 |
+
return x
|
51 |
+
|
52 |
+
class LayerNorm(nn.Module):
|
53 |
+
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
54 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
55 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
56 |
+
with shape (batch_size, channels, height, width).
|
57 |
+
"""
|
58 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
59 |
+
super().__init__()
|
60 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
61 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
62 |
+
self.eps = eps
|
63 |
+
self.data_format = data_format
|
64 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
65 |
+
raise NotImplementedError
|
66 |
+
self.normalized_shape = (normalized_shape, )
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
with autocast(enabled=False):
|
70 |
+
x = x.float()
|
71 |
+
if self.data_format == "channels_last":
|
72 |
+
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
73 |
+
elif self.data_format == "channels_first":
|
74 |
+
u = x.mean(1, keepdim=True)
|
75 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
76 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
77 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class ConvNeXt(nn.Module):
|
82 |
+
r""" ConvNeXt
|
83 |
+
A PyTorch impl of : `A ConvNet for the 2020s` -
|
84 |
+
https://arxiv.org/pdf/2201.03545.pdf
|
85 |
+
|
86 |
+
Args:
|
87 |
+
in_chans (int): Number of input image channels. Default: 3
|
88 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
89 |
+
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
|
90 |
+
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
|
91 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
92 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
93 |
+
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
|
94 |
+
"""
|
95 |
+
def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
|
96 |
+
drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3],
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
|
100 |
+
self.num_features = dims
|
101 |
+
|
102 |
+
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
|
103 |
+
stem = nn.Sequential(
|
104 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
105 |
+
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
|
106 |
+
)
|
107 |
+
self.downsample_layers.append(stem)
|
108 |
+
for i in range(3):
|
109 |
+
downsample_layer = nn.Sequential(
|
110 |
+
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
111 |
+
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
|
112 |
+
)
|
113 |
+
self.downsample_layers.append(downsample_layer)
|
114 |
+
|
115 |
+
self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
|
116 |
+
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
117 |
+
cur = 0
|
118 |
+
for i in range(4):
|
119 |
+
stage = nn.Sequential(
|
120 |
+
*[Block(dim=dims[i], drop_path=dp_rates[cur + j],
|
121 |
+
layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
|
122 |
+
)
|
123 |
+
self.stages.append(stage)
|
124 |
+
cur += depths[i]
|
125 |
+
|
126 |
+
self.out_indices = out_indices
|
127 |
+
|
128 |
+
def forward_features(self, x):
|
129 |
+
outs = {}
|
130 |
+
for i in range(4):
|
131 |
+
# We add zero padding here for downstream tasks.
|
132 |
+
# ref: https://github.com/google-research/deeplab2/blob/main/model/pixel_encoder/convnext.py#L128
|
133 |
+
if i == 0:
|
134 |
+
x = F.pad(x, (1, 2, 1, 2, 0, 0, 0, 0), "constant", 0)
|
135 |
+
else:
|
136 |
+
x = F.pad(x, (0, 1, 0, 1, 0, 0, 0, 0), "constant", 0)
|
137 |
+
x = self.downsample_layers[i](x)
|
138 |
+
x = self.stages[i](x)
|
139 |
+
if i in self.out_indices:
|
140 |
+
outs["res{}".format(i + 2)] = x
|
141 |
+
|
142 |
+
return outs
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
x = self.forward_features(x)
|
146 |
+
return x
|
147 |
+
|
148 |
+
@BACKBONE_REGISTRY.register()
|
149 |
+
class D2ConvNeXt(ConvNeXt, Backbone):
|
150 |
+
def __init__(self, cfg, input_shape):
|
151 |
+
|
152 |
+
in_chans = cfg.MODEL.CONVNEXT.IN_CHANNELS
|
153 |
+
depths = cfg.MODEL.CONVNEXT.DEPTHS
|
154 |
+
dims = cfg.MODEL.CONVNEXT.DIMS
|
155 |
+
drop_path_rate = cfg.MODEL.CONVNEXT.DROP_PATH_RATE
|
156 |
+
layer_scale_init_value = cfg.MODEL.CONVNEXT.LSIT
|
157 |
+
out_indices = cfg.MODEL.CONVNEXT.OUT_INDICES
|
158 |
+
|
159 |
+
super().__init__(
|
160 |
+
in_chans=in_chans,
|
161 |
+
depths=depths,
|
162 |
+
dims=dims,
|
163 |
+
drop_path_rate=drop_path_rate,
|
164 |
+
layer_scale_init_value=layer_scale_init_value,
|
165 |
+
out_indices=out_indices,
|
166 |
+
)
|
167 |
+
|
168 |
+
self._out_features = cfg.MODEL.CONVNEXT.OUT_FEATURES
|
169 |
+
|
170 |
+
self._out_feature_strides = {
|
171 |
+
"res2": 4,
|
172 |
+
"res3": 8,
|
173 |
+
"res4": 16,
|
174 |
+
"res5": 32,
|
175 |
+
}
|
176 |
+
self._out_feature_channels = {
|
177 |
+
"res2": self.num_features[0],
|
178 |
+
"res3": self.num_features[1],
|
179 |
+
"res4": self.num_features[2],
|
180 |
+
"res5": self.num_features[3],
|
181 |
+
}
|
182 |
+
|
183 |
+
def forward(self, x):
|
184 |
+
"""
|
185 |
+
Args:
|
186 |
+
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
|
187 |
+
Returns:
|
188 |
+
dict[str->Tensor]: names and the corresponding features
|
189 |
+
"""
|
190 |
+
assert (
|
191 |
+
x.dim() == 4
|
192 |
+
), f"ConvNeXt takes an input of shape (N, C, H, W). Got {x.shape} instead!"
|
193 |
+
outputs = {}
|
194 |
+
y = super().forward(x)
|
195 |
+
for k in y.keys():
|
196 |
+
if k in self._out_features:
|
197 |
+
outputs[k] = y[k]
|
198 |
+
return outputs
|
199 |
+
|
200 |
+
def output_shape(self):
|
201 |
+
return {
|
202 |
+
name: ShapeSpec(
|
203 |
+
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
|
204 |
+
)
|
205 |
+
for name in self._out_features
|
206 |
+
}
|
207 |
+
|
208 |
+
@property
|
209 |
+
def size_divisibility(self):
|
210 |
+
return -1
|
kmax_deeplab/modeling/backbone/resnet.py
ADDED
@@ -0,0 +1,697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/resnet.py
|
2 |
+
# Modified by Qihang Yu
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import fvcore.nn.weight_init as weight_init
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from detectron2.layers import (
|
11 |
+
CNNBlockBase,
|
12 |
+
Conv2d,
|
13 |
+
DeformConv,
|
14 |
+
ModulatedDeformConv,
|
15 |
+
#ShapeSpec,
|
16 |
+
#get_norm,
|
17 |
+
)
|
18 |
+
|
19 |
+
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
|
20 |
+
|
21 |
+
from ..pixel_decoder.kmax_pixel_decoder import get_norm
|
22 |
+
|
23 |
+
__all__ = [
|
24 |
+
"ResNetBlockBase",
|
25 |
+
"BasicBlock",
|
26 |
+
"BottleneckBlock",
|
27 |
+
"DeformBottleneckBlock",
|
28 |
+
"BasicStem",
|
29 |
+
"ResNet",
|
30 |
+
"make_stage",
|
31 |
+
"custom_bn_build_resnet_backbone",
|
32 |
+
]
|
33 |
+
|
34 |
+
|
35 |
+
class BasicBlock(CNNBlockBase):
|
36 |
+
"""
|
37 |
+
The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`,
|
38 |
+
with two 3x3 conv layers and a projection shortcut if needed.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"):
|
42 |
+
"""
|
43 |
+
Args:
|
44 |
+
in_channels (int): Number of input channels.
|
45 |
+
out_channels (int): Number of output channels.
|
46 |
+
stride (int): Stride for the first conv.
|
47 |
+
norm (str or callable): normalization for all conv layers.
|
48 |
+
See :func:`layers.get_norm` for supported format.
|
49 |
+
"""
|
50 |
+
super().__init__(in_channels, out_channels, stride)
|
51 |
+
|
52 |
+
if in_channels != out_channels:
|
53 |
+
self.shortcut = Conv2d(
|
54 |
+
in_channels,
|
55 |
+
out_channels,
|
56 |
+
kernel_size=1,
|
57 |
+
stride=stride,
|
58 |
+
bias=False,
|
59 |
+
norm=get_norm(norm, out_channels),
|
60 |
+
)
|
61 |
+
else:
|
62 |
+
self.shortcut = None
|
63 |
+
|
64 |
+
self.conv1 = Conv2d(
|
65 |
+
in_channels,
|
66 |
+
out_channels,
|
67 |
+
kernel_size=3,
|
68 |
+
stride=stride,
|
69 |
+
padding=1,
|
70 |
+
bias=False,
|
71 |
+
norm=get_norm(norm, out_channels),
|
72 |
+
)
|
73 |
+
|
74 |
+
self.conv2 = Conv2d(
|
75 |
+
out_channels,
|
76 |
+
out_channels,
|
77 |
+
kernel_size=3,
|
78 |
+
stride=1,
|
79 |
+
padding=1,
|
80 |
+
bias=False,
|
81 |
+
norm=get_norm(norm, out_channels),
|
82 |
+
)
|
83 |
+
|
84 |
+
for layer in [self.conv1, self.conv2, self.shortcut]:
|
85 |
+
if layer is not None: # shortcut can be None
|
86 |
+
weight_init.c2_msra_fill(layer)
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
out = self.conv1(x)
|
90 |
+
out = F.relu_(out)
|
91 |
+
out = self.conv2(out)
|
92 |
+
|
93 |
+
if self.shortcut is not None:
|
94 |
+
shortcut = self.shortcut(x)
|
95 |
+
else:
|
96 |
+
shortcut = x
|
97 |
+
|
98 |
+
out += shortcut
|
99 |
+
out = F.relu_(out)
|
100 |
+
return out
|
101 |
+
|
102 |
+
|
103 |
+
class BottleneckBlock(CNNBlockBase):
|
104 |
+
"""
|
105 |
+
The standard bottleneck residual block used by ResNet-50, 101 and 152
|
106 |
+
defined in :paper:`ResNet`. It contains 3 conv layers with kernels
|
107 |
+
1x1, 3x3, 1x1, and a projection shortcut if needed.
|
108 |
+
"""
|
109 |
+
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
in_channels,
|
113 |
+
out_channels,
|
114 |
+
*,
|
115 |
+
bottleneck_channels,
|
116 |
+
stride=1,
|
117 |
+
num_groups=1,
|
118 |
+
norm="BN",
|
119 |
+
stride_in_1x1=False,
|
120 |
+
dilation=1,
|
121 |
+
):
|
122 |
+
"""
|
123 |
+
Args:
|
124 |
+
bottleneck_channels (int): number of output channels for the 3x3
|
125 |
+
"bottleneck" conv layers.
|
126 |
+
num_groups (int): number of groups for the 3x3 conv layer.
|
127 |
+
norm (str or callable): normalization for all conv layers.
|
128 |
+
See :func:`layers.get_norm` for supported format.
|
129 |
+
stride_in_1x1 (bool): when stride>1, whether to put stride in the
|
130 |
+
first 1x1 convolution or the bottleneck 3x3 convolution.
|
131 |
+
dilation (int): the dilation rate of the 3x3 conv layer.
|
132 |
+
"""
|
133 |
+
super().__init__(in_channels, out_channels, stride)
|
134 |
+
|
135 |
+
if in_channels != out_channels:
|
136 |
+
self.shortcut = Conv2d(
|
137 |
+
in_channels,
|
138 |
+
out_channels,
|
139 |
+
kernel_size=1,
|
140 |
+
stride=stride,
|
141 |
+
bias=False,
|
142 |
+
norm=get_norm(norm, out_channels),
|
143 |
+
)
|
144 |
+
else:
|
145 |
+
self.shortcut = None
|
146 |
+
|
147 |
+
# The original MSRA ResNet models have stride in the first 1x1 conv
|
148 |
+
# The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
|
149 |
+
# stride in the 3x3 conv
|
150 |
+
stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
|
151 |
+
|
152 |
+
self.conv1 = Conv2d(
|
153 |
+
in_channels,
|
154 |
+
bottleneck_channels,
|
155 |
+
kernel_size=1,
|
156 |
+
stride=stride_1x1,
|
157 |
+
bias=False,
|
158 |
+
norm=get_norm(norm, bottleneck_channels),
|
159 |
+
)
|
160 |
+
|
161 |
+
self.conv2 = Conv2d(
|
162 |
+
bottleneck_channels,
|
163 |
+
bottleneck_channels,
|
164 |
+
kernel_size=3,
|
165 |
+
stride=stride_3x3,
|
166 |
+
padding=1 * dilation,
|
167 |
+
bias=False,
|
168 |
+
groups=num_groups,
|
169 |
+
dilation=dilation,
|
170 |
+
norm=get_norm(norm, bottleneck_channels),
|
171 |
+
)
|
172 |
+
|
173 |
+
self.conv3 = Conv2d(
|
174 |
+
bottleneck_channels,
|
175 |
+
out_channels,
|
176 |
+
kernel_size=1,
|
177 |
+
bias=False,
|
178 |
+
norm=get_norm(norm, out_channels),
|
179 |
+
)
|
180 |
+
|
181 |
+
for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
|
182 |
+
if layer is not None: # shortcut can be None
|
183 |
+
weight_init.c2_msra_fill(layer)
|
184 |
+
|
185 |
+
# Zero-initialize the last normalization in each residual branch,
|
186 |
+
# so that at the beginning, the residual branch starts with zeros,
|
187 |
+
# and each residual block behaves like an identity.
|
188 |
+
# See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
|
189 |
+
# "For BN layers, the learnable scaling coefficient γ is initialized
|
190 |
+
# to be 1, except for each residual block's last BN
|
191 |
+
# where γ is initialized to be 0."
|
192 |
+
|
193 |
+
# nn.init.constant_(self.conv3.norm.weight, 0)
|
194 |
+
# TODO this somehow hurts performance when training GN models from scratch.
|
195 |
+
# Add it as an option when we need to use this code to train a backbone.
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
out = self.conv1(x)
|
199 |
+
out = F.relu_(out)
|
200 |
+
|
201 |
+
out = self.conv2(out)
|
202 |
+
out = F.relu_(out)
|
203 |
+
|
204 |
+
out = self.conv3(out)
|
205 |
+
|
206 |
+
if self.shortcut is not None:
|
207 |
+
shortcut = self.shortcut(x)
|
208 |
+
else:
|
209 |
+
shortcut = x
|
210 |
+
|
211 |
+
out += shortcut
|
212 |
+
out = F.relu_(out)
|
213 |
+
return out
|
214 |
+
|
215 |
+
|
216 |
+
class DeformBottleneckBlock(CNNBlockBase):
|
217 |
+
"""
|
218 |
+
Similar to :class:`BottleneckBlock`, but with :paper:`deformable conv <deformconv>`
|
219 |
+
in the 3x3 convolution.
|
220 |
+
"""
|
221 |
+
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
in_channels,
|
225 |
+
out_channels,
|
226 |
+
*,
|
227 |
+
bottleneck_channels,
|
228 |
+
stride=1,
|
229 |
+
num_groups=1,
|
230 |
+
norm="BN",
|
231 |
+
stride_in_1x1=False,
|
232 |
+
dilation=1,
|
233 |
+
deform_modulated=False,
|
234 |
+
deform_num_groups=1,
|
235 |
+
):
|
236 |
+
super().__init__(in_channels, out_channels, stride)
|
237 |
+
self.deform_modulated = deform_modulated
|
238 |
+
|
239 |
+
if in_channels != out_channels:
|
240 |
+
self.shortcut = Conv2d(
|
241 |
+
in_channels,
|
242 |
+
out_channels,
|
243 |
+
kernel_size=1,
|
244 |
+
stride=stride,
|
245 |
+
bias=False,
|
246 |
+
norm=get_norm(norm, out_channels),
|
247 |
+
)
|
248 |
+
else:
|
249 |
+
self.shortcut = None
|
250 |
+
|
251 |
+
stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
|
252 |
+
|
253 |
+
self.conv1 = Conv2d(
|
254 |
+
in_channels,
|
255 |
+
bottleneck_channels,
|
256 |
+
kernel_size=1,
|
257 |
+
stride=stride_1x1,
|
258 |
+
bias=False,
|
259 |
+
norm=get_norm(norm, bottleneck_channels),
|
260 |
+
)
|
261 |
+
|
262 |
+
if deform_modulated:
|
263 |
+
deform_conv_op = ModulatedDeformConv
|
264 |
+
# offset channels are 2 or 3 (if with modulated) * kernel_size * kernel_size
|
265 |
+
offset_channels = 27
|
266 |
+
else:
|
267 |
+
deform_conv_op = DeformConv
|
268 |
+
offset_channels = 18
|
269 |
+
|
270 |
+
self.conv2_offset = Conv2d(
|
271 |
+
bottleneck_channels,
|
272 |
+
offset_channels * deform_num_groups,
|
273 |
+
kernel_size=3,
|
274 |
+
stride=stride_3x3,
|
275 |
+
padding=1 * dilation,
|
276 |
+
dilation=dilation,
|
277 |
+
)
|
278 |
+
self.conv2 = deform_conv_op(
|
279 |
+
bottleneck_channels,
|
280 |
+
bottleneck_channels,
|
281 |
+
kernel_size=3,
|
282 |
+
stride=stride_3x3,
|
283 |
+
padding=1 * dilation,
|
284 |
+
bias=False,
|
285 |
+
groups=num_groups,
|
286 |
+
dilation=dilation,
|
287 |
+
deformable_groups=deform_num_groups,
|
288 |
+
norm=get_norm(norm, bottleneck_channels),
|
289 |
+
)
|
290 |
+
|
291 |
+
self.conv3 = Conv2d(
|
292 |
+
bottleneck_channels,
|
293 |
+
out_channels,
|
294 |
+
kernel_size=1,
|
295 |
+
bias=False,
|
296 |
+
norm=get_norm(norm, out_channels),
|
297 |
+
)
|
298 |
+
|
299 |
+
for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
|
300 |
+
if layer is not None: # shortcut can be None
|
301 |
+
weight_init.c2_msra_fill(layer)
|
302 |
+
|
303 |
+
nn.init.constant_(self.conv2_offset.weight, 0)
|
304 |
+
nn.init.constant_(self.conv2_offset.bias, 0)
|
305 |
+
|
306 |
+
def forward(self, x):
|
307 |
+
out = self.conv1(x)
|
308 |
+
out = F.relu_(out)
|
309 |
+
|
310 |
+
if self.deform_modulated:
|
311 |
+
offset_mask = self.conv2_offset(out)
|
312 |
+
offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1)
|
313 |
+
offset = torch.cat((offset_x, offset_y), dim=1)
|
314 |
+
mask = mask.sigmoid()
|
315 |
+
out = self.conv2(out, offset, mask)
|
316 |
+
else:
|
317 |
+
offset = self.conv2_offset(out)
|
318 |
+
out = self.conv2(out, offset)
|
319 |
+
out = F.relu_(out)
|
320 |
+
|
321 |
+
out = self.conv3(out)
|
322 |
+
|
323 |
+
if self.shortcut is not None:
|
324 |
+
shortcut = self.shortcut(x)
|
325 |
+
else:
|
326 |
+
shortcut = x
|
327 |
+
|
328 |
+
out += shortcut
|
329 |
+
out = F.relu_(out)
|
330 |
+
return out
|
331 |
+
|
332 |
+
|
333 |
+
class BasicStem(CNNBlockBase):
|
334 |
+
"""
|
335 |
+
The standard ResNet stem (layers before the first residual block),
|
336 |
+
with a conv, relu and max_pool.
|
337 |
+
"""
|
338 |
+
|
339 |
+
def __init__(self, in_channels=3, out_channels=64, norm="BN"):
|
340 |
+
"""
|
341 |
+
Args:
|
342 |
+
norm (str or callable): norm after the first conv layer.
|
343 |
+
See :func:`layers.get_norm` for supported format.
|
344 |
+
"""
|
345 |
+
super().__init__(in_channels, out_channels, 4)
|
346 |
+
self.in_channels = in_channels
|
347 |
+
self.conv1 = Conv2d(
|
348 |
+
in_channels,
|
349 |
+
out_channels,
|
350 |
+
kernel_size=7,
|
351 |
+
stride=2,
|
352 |
+
padding=3,
|
353 |
+
bias=False,
|
354 |
+
norm=get_norm(norm, out_channels),
|
355 |
+
)
|
356 |
+
weight_init.c2_msra_fill(self.conv1)
|
357 |
+
|
358 |
+
def forward(self, x):
|
359 |
+
x = self.conv1(x)
|
360 |
+
x = F.relu_(x)
|
361 |
+
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
|
362 |
+
return x
|
363 |
+
|
364 |
+
|
365 |
+
class ResNet(Backbone):
|
366 |
+
"""
|
367 |
+
Implement :paper:`ResNet`.
|
368 |
+
"""
|
369 |
+
|
370 |
+
def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
|
371 |
+
"""
|
372 |
+
Args:
|
373 |
+
stem (nn.Module): a stem module
|
374 |
+
stages (list[list[CNNBlockBase]]): several (typically 4) stages,
|
375 |
+
each contains multiple :class:`CNNBlockBase`.
|
376 |
+
num_classes (None or int): if None, will not perform classification.
|
377 |
+
Otherwise, will create a linear layer.
|
378 |
+
out_features (list[str]): name of the layers whose outputs should
|
379 |
+
be returned in forward. Can be anything in "stem", "linear", or "res2" ...
|
380 |
+
If None, will return the output of the last layer.
|
381 |
+
freeze_at (int): The number of stages at the beginning to freeze.
|
382 |
+
see :meth:`freeze` for detailed explanation.
|
383 |
+
"""
|
384 |
+
super().__init__()
|
385 |
+
self.stem = stem
|
386 |
+
self.num_classes = num_classes
|
387 |
+
|
388 |
+
current_stride = self.stem.stride
|
389 |
+
self._out_feature_strides = {"stem": current_stride}
|
390 |
+
self._out_feature_channels = {"stem": self.stem.out_channels}
|
391 |
+
|
392 |
+
self.stage_names, self.stages = [], []
|
393 |
+
|
394 |
+
if out_features is not None:
|
395 |
+
# Avoid keeping unused layers in this module. They consume extra memory
|
396 |
+
# and may cause allreduce to fail
|
397 |
+
num_stages = max(
|
398 |
+
[{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
|
399 |
+
)
|
400 |
+
stages = stages[:num_stages]
|
401 |
+
for i, blocks in enumerate(stages):
|
402 |
+
assert len(blocks) > 0, len(blocks)
|
403 |
+
for block in blocks:
|
404 |
+
assert isinstance(block, CNNBlockBase), block
|
405 |
+
|
406 |
+
name = "res" + str(i + 2)
|
407 |
+
stage = nn.Sequential(*blocks)
|
408 |
+
|
409 |
+
self.add_module(name, stage)
|
410 |
+
self.stage_names.append(name)
|
411 |
+
self.stages.append(stage)
|
412 |
+
|
413 |
+
self._out_feature_strides[name] = current_stride = int(
|
414 |
+
current_stride * np.prod([k.stride for k in blocks])
|
415 |
+
)
|
416 |
+
self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
|
417 |
+
self.stage_names = tuple(self.stage_names) # Make it static for scripting
|
418 |
+
|
419 |
+
if num_classes is not None:
|
420 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
421 |
+
self.linear = nn.Linear(curr_channels, num_classes)
|
422 |
+
|
423 |
+
# Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
|
424 |
+
# "The 1000-way fully-connected layer is initialized by
|
425 |
+
# drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
|
426 |
+
nn.init.normal_(self.linear.weight, std=0.01)
|
427 |
+
name = "linear"
|
428 |
+
|
429 |
+
if out_features is None:
|
430 |
+
out_features = [name]
|
431 |
+
self._out_features = out_features
|
432 |
+
assert len(self._out_features)
|
433 |
+
children = [x[0] for x in self.named_children()]
|
434 |
+
for out_feature in self._out_features:
|
435 |
+
assert out_feature in children, "Available children: {}".format(", ".join(children))
|
436 |
+
self.freeze(freeze_at)
|
437 |
+
|
438 |
+
def forward(self, x):
|
439 |
+
"""
|
440 |
+
Args:
|
441 |
+
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
|
442 |
+
|
443 |
+
Returns:
|
444 |
+
dict[str->Tensor]: names and the corresponding features
|
445 |
+
"""
|
446 |
+
assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
|
447 |
+
outputs = {}
|
448 |
+
x = self.stem(x)
|
449 |
+
if "stem" in self._out_features:
|
450 |
+
outputs["stem"] = x
|
451 |
+
for name, stage in zip(self.stage_names, self.stages):
|
452 |
+
x = stage(x)
|
453 |
+
if name in self._out_features:
|
454 |
+
outputs[name] = x
|
455 |
+
if self.num_classes is not None:
|
456 |
+
x = self.avgpool(x)
|
457 |
+
x = torch.flatten(x, 1)
|
458 |
+
x = self.linear(x)
|
459 |
+
if "linear" in self._out_features:
|
460 |
+
outputs["linear"] = x
|
461 |
+
return outputs
|
462 |
+
|
463 |
+
def output_shape(self):
|
464 |
+
return {
|
465 |
+
name: ShapeSpec(
|
466 |
+
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
|
467 |
+
)
|
468 |
+
for name in self._out_features
|
469 |
+
}
|
470 |
+
|
471 |
+
def freeze(self, freeze_at=0):
|
472 |
+
"""
|
473 |
+
Freeze the first several stages of the ResNet. Commonly used in
|
474 |
+
fine-tuning.
|
475 |
+
|
476 |
+
Layers that produce the same feature map spatial size are defined as one
|
477 |
+
"stage" by :paper:`FPN`.
|
478 |
+
|
479 |
+
Args:
|
480 |
+
freeze_at (int): number of stages to freeze.
|
481 |
+
`1` means freezing the stem. `2` means freezing the stem and
|
482 |
+
one residual stage, etc.
|
483 |
+
|
484 |
+
Returns:
|
485 |
+
nn.Module: this ResNet itself
|
486 |
+
"""
|
487 |
+
if freeze_at >= 1:
|
488 |
+
self.stem.freeze()
|
489 |
+
for idx, stage in enumerate(self.stages, start=2):
|
490 |
+
if freeze_at >= idx:
|
491 |
+
for block in stage.children():
|
492 |
+
block.freeze()
|
493 |
+
return self
|
494 |
+
|
495 |
+
@staticmethod
|
496 |
+
def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
|
497 |
+
"""
|
498 |
+
Create a list of blocks of the same type that forms one ResNet stage.
|
499 |
+
|
500 |
+
Args:
|
501 |
+
block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
|
502 |
+
stage. A module of this type must not change spatial resolution of inputs unless its
|
503 |
+
stride != 1.
|
504 |
+
num_blocks (int): number of blocks in this stage
|
505 |
+
in_channels (int): input channels of the entire stage.
|
506 |
+
out_channels (int): output channels of **every block** in the stage.
|
507 |
+
kwargs: other arguments passed to the constructor of
|
508 |
+
`block_class`. If the argument name is "xx_per_block", the
|
509 |
+
argument is a list of values to be passed to each block in the
|
510 |
+
stage. Otherwise, the same argument is passed to every block
|
511 |
+
in the stage.
|
512 |
+
|
513 |
+
Returns:
|
514 |
+
list[CNNBlockBase]: a list of block module.
|
515 |
+
|
516 |
+
Examples:
|
517 |
+
::
|
518 |
+
stage = ResNet.make_stage(
|
519 |
+
BottleneckBlock, 3, in_channels=16, out_channels=64,
|
520 |
+
bottleneck_channels=16, num_groups=1,
|
521 |
+
stride_per_block=[2, 1, 1],
|
522 |
+
dilations_per_block=[1, 1, 2]
|
523 |
+
)
|
524 |
+
|
525 |
+
Usually, layers that produce the same feature map spatial size are defined as one
|
526 |
+
"stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
|
527 |
+
all be 1.
|
528 |
+
"""
|
529 |
+
blocks = []
|
530 |
+
for i in range(num_blocks):
|
531 |
+
curr_kwargs = {}
|
532 |
+
for k, v in kwargs.items():
|
533 |
+
if k.endswith("_per_block"):
|
534 |
+
assert len(v) == num_blocks, (
|
535 |
+
f"Argument '{k}' of make_stage should have the "
|
536 |
+
f"same length as num_blocks={num_blocks}."
|
537 |
+
)
|
538 |
+
newk = k[: -len("_per_block")]
|
539 |
+
assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
|
540 |
+
curr_kwargs[newk] = v[i]
|
541 |
+
else:
|
542 |
+
curr_kwargs[k] = v
|
543 |
+
|
544 |
+
blocks.append(
|
545 |
+
block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
|
546 |
+
)
|
547 |
+
in_channels = out_channels
|
548 |
+
return blocks
|
549 |
+
|
550 |
+
@staticmethod
|
551 |
+
def make_default_stages(depth, block_class=None, **kwargs):
|
552 |
+
"""
|
553 |
+
Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152).
|
554 |
+
If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
|
555 |
+
instead for fine-grained customization.
|
556 |
+
|
557 |
+
Args:
|
558 |
+
depth (int): depth of ResNet
|
559 |
+
block_class (type): the CNN block class. Has to accept
|
560 |
+
`bottleneck_channels` argument for depth > 50.
|
561 |
+
By default it is BasicBlock or BottleneckBlock, based on the
|
562 |
+
depth.
|
563 |
+
kwargs:
|
564 |
+
other arguments to pass to `make_stage`. Should not contain
|
565 |
+
stride and channels, as they are predefined for each depth.
|
566 |
+
|
567 |
+
Returns:
|
568 |
+
list[list[CNNBlockBase]]: modules in all stages; see arguments of
|
569 |
+
:class:`ResNet.__init__`.
|
570 |
+
"""
|
571 |
+
num_blocks_per_stage = {
|
572 |
+
18: [2, 2, 2, 2],
|
573 |
+
34: [3, 4, 6, 3],
|
574 |
+
50: [3, 4, 6, 3],
|
575 |
+
101: [3, 4, 23, 3],
|
576 |
+
152: [3, 8, 36, 3],
|
577 |
+
}[depth]
|
578 |
+
if block_class is None:
|
579 |
+
block_class = BasicBlock if depth < 50 else BottleneckBlock
|
580 |
+
if depth < 50:
|
581 |
+
in_channels = [64, 64, 128, 256]
|
582 |
+
out_channels = [64, 128, 256, 512]
|
583 |
+
else:
|
584 |
+
in_channels = [64, 256, 512, 1024]
|
585 |
+
out_channels = [256, 512, 1024, 2048]
|
586 |
+
ret = []
|
587 |
+
for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
|
588 |
+
if depth >= 50:
|
589 |
+
kwargs["bottleneck_channels"] = o // 4
|
590 |
+
ret.append(
|
591 |
+
ResNet.make_stage(
|
592 |
+
block_class=block_class,
|
593 |
+
num_blocks=n,
|
594 |
+
stride_per_block=[s] + [1] * (n - 1),
|
595 |
+
in_channels=i,
|
596 |
+
out_channels=o,
|
597 |
+
**kwargs,
|
598 |
+
)
|
599 |
+
)
|
600 |
+
return ret
|
601 |
+
|
602 |
+
|
603 |
+
ResNetBlockBase = CNNBlockBase
|
604 |
+
"""
|
605 |
+
Alias for backward compatibiltiy.
|
606 |
+
"""
|
607 |
+
|
608 |
+
|
609 |
+
def make_stage(*args, **kwargs):
|
610 |
+
"""
|
611 |
+
Deprecated alias for backward compatibiltiy.
|
612 |
+
"""
|
613 |
+
return ResNet.make_stage(*args, **kwargs)
|
614 |
+
|
615 |
+
|
616 |
+
@BACKBONE_REGISTRY.register()
|
617 |
+
def custom_bn_build_resnet_backbone(cfg, input_shape):
|
618 |
+
"""
|
619 |
+
Create a ResNet instance from config.
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
ResNet: a :class:`ResNet` instance.
|
623 |
+
"""
|
624 |
+
# need registration of new blocks/stems?
|
625 |
+
norm = cfg.MODEL.RESNETS.NORM
|
626 |
+
stem = BasicStem(
|
627 |
+
in_channels=input_shape.channels,
|
628 |
+
out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
|
629 |
+
norm=norm,
|
630 |
+
)
|
631 |
+
|
632 |
+
# fmt: off
|
633 |
+
freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
|
634 |
+
out_features = cfg.MODEL.RESNETS.OUT_FEATURES
|
635 |
+
depth = cfg.MODEL.RESNETS.DEPTH
|
636 |
+
num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
|
637 |
+
width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
|
638 |
+
bottleneck_channels = num_groups * width_per_group
|
639 |
+
in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
|
640 |
+
out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
|
641 |
+
stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
|
642 |
+
res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION
|
643 |
+
deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
|
644 |
+
deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED
|
645 |
+
deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
|
646 |
+
# fmt: on
|
647 |
+
assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
|
648 |
+
|
649 |
+
num_blocks_per_stage = {
|
650 |
+
18: [2, 2, 2, 2],
|
651 |
+
34: [3, 4, 6, 3],
|
652 |
+
50: [3, 4, 6, 3],
|
653 |
+
101: [3, 4, 23, 3],
|
654 |
+
152: [3, 8, 36, 3],
|
655 |
+
}[depth]
|
656 |
+
|
657 |
+
if depth in [18, 34]:
|
658 |
+
assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
|
659 |
+
assert not any(
|
660 |
+
deform_on_per_stage
|
661 |
+
), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
|
662 |
+
assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
|
663 |
+
assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"
|
664 |
+
|
665 |
+
stages = []
|
666 |
+
|
667 |
+
for idx, stage_idx in enumerate(range(2, 6)):
|
668 |
+
# res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper
|
669 |
+
dilation = res5_dilation if stage_idx == 5 else 1
|
670 |
+
first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
|
671 |
+
stage_kargs = {
|
672 |
+
"num_blocks": num_blocks_per_stage[idx],
|
673 |
+
"stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1),
|
674 |
+
"in_channels": in_channels,
|
675 |
+
"out_channels": out_channels,
|
676 |
+
"norm": norm,
|
677 |
+
}
|
678 |
+
# Use BasicBlock for R18 and R34.
|
679 |
+
if depth in [18, 34]:
|
680 |
+
stage_kargs["block_class"] = BasicBlock
|
681 |
+
else:
|
682 |
+
stage_kargs["bottleneck_channels"] = bottleneck_channels
|
683 |
+
stage_kargs["stride_in_1x1"] = stride_in_1x1
|
684 |
+
stage_kargs["dilation"] = dilation
|
685 |
+
stage_kargs["num_groups"] = num_groups
|
686 |
+
if deform_on_per_stage[idx]:
|
687 |
+
stage_kargs["block_class"] = DeformBottleneckBlock
|
688 |
+
stage_kargs["deform_modulated"] = deform_modulated
|
689 |
+
stage_kargs["deform_num_groups"] = deform_num_groups
|
690 |
+
else:
|
691 |
+
stage_kargs["block_class"] = BottleneckBlock
|
692 |
+
blocks = ResNet.make_stage(**stage_kargs)
|
693 |
+
in_channels = out_channels
|
694 |
+
out_channels *= 2
|
695 |
+
bottleneck_channels *= 2
|
696 |
+
stages.append(blocks)
|
697 |
+
return ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at)
|
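Since `custom_bn_build_resnet_backbone` is registered in `BACKBONE_REGISTRY`, detectron2 can select it by name from the config. A minimal sketch of that wiring, not part of this commit; the import path, norm value and output-feature list below are illustrative assumptions rather than the shipped kMaX configs:

import torch
from detectron2.config import get_cfg
from detectron2.layers import ShapeSpec
from detectron2.modeling import build_backbone

import kmax_deeplab  # assumed to trigger the registry decorators defined above

cfg = get_cfg()
cfg.MODEL.BACKBONE.NAME = "custom_bn_build_resnet_backbone"  # the function registered above
cfg.MODEL.RESNETS.DEPTH = 50
cfg.MODEL.RESNETS.NORM = "syncbn"  # assumed to be a value accepted by this file's get_norm helper
cfg.MODEL.RESNETS.OUT_FEATURES = ["res2", "res3", "res4", "res5"]

backbone = build_backbone(cfg, ShapeSpec(channels=3))
features = backbone(torch.randn(1, 3, 224, 224))
print({k: tuple(v.shape) for k, v in features.items()})  # res2..res5 feature maps, strides 4 to 32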
kmax_deeplab/modeling/criterion.py
ADDED
@@ -0,0 +1,432 @@
|
1 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
|
2 |
+
# Reference: https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py
|
3 |
+
# Modified by Qihang Yu
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
_SOFTMAX_MASKING_CONSTANT = -99999.0
|
10 |
+
|
11 |
+
# https://www.tensorflow.org/api_docs/python/tf/math/divide_no_nan
|
12 |
+
def divide_no_nan(x: torch.Tensor, y: torch.Tensor):
|
13 |
+
return torch.nan_to_num(x / y, nan=0.0, posinf=0.0, neginf=0.0)
|
14 |
+
|
15 |
+
|
16 |
+
# https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L393
|
17 |
+
def focal_cross_entropy_loss(
|
18 |
+
pred: torch.Tensor,
|
19 |
+
gt: torch.Tensor,
|
20 |
+
weight: torch.Tensor, # This is for PQ-loss weighting
|
21 |
+
focal_loss_alpha: float = 0.75,
|
22 |
+
focal_loss_gamma: float = 0.0,
|
23 |
+
background_channel_index: int = -1):
|
24 |
+
"""
|
25 |
+
pred: B x N x C
|
26 |
+
gt: B x N
|
27 |
+
weight: B x N
|
28 |
+
"""
|
29 |
+
pred = pred.transpose(1, 2) # B x C x N
|
30 |
+
gt = F.one_hot(gt, num_classes=pred.shape[1]).transpose(1, 2).to(pred) # B x C x N
|
31 |
+
loss = F.cross_entropy(pred, gt, reduction="none") # B x N
|
32 |
+
if focal_loss_gamma == 0.0:
|
33 |
+
focal_loss = loss
|
34 |
+
else:
|
35 |
+
pred = F.softmax(pred, dim=1) # B x C x N
|
36 |
+
pt = (pred * gt).sum(1) # B x N
|
37 |
+
focal_loss = torch.pow(1.0 - pt, focal_loss_gamma) * loss # B x N
|
38 |
+
|
39 |
+
if focal_loss_alpha >= 0:
|
40 |
+
alpha_weights = (
|
41 |
+
focal_loss_alpha * (1.0 - gt[:, background_channel_index])
|
42 |
+
+ (1 - focal_loss_alpha) * gt[:, background_channel_index]) # B x N
|
43 |
+
focal_loss = alpha_weights * focal_loss # B x N
|
44 |
+
|
45 |
+
focal_loss = focal_loss * weight # B x N
|
46 |
+
focal_loss = focal_loss.flatten(1)
|
47 |
+
num_non_zero = (focal_loss != 0.0).to(focal_loss).sum(-1) # B
|
48 |
+
num_non_zero = torch.clamp(num_non_zero, min=1.0)
|
49 |
+
loss_sum_per_sample = focal_loss.sum(-1) # B
|
50 |
+
return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1
|
51 |
+
|
52 |
+
|
53 |
+
# https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L50
|
54 |
+
def _gumbel_topk_sample(logits: torch.Tensor, k: int):
|
55 |
+
"""Samples k points from the softmax distribution with Gumbel-Top-k trick."""
|
56 |
+
# Note that torch.rand is [0, 1), we need to make it (0, 1) to ensure the log is valid.
|
57 |
+
gumbel_noise = torch.rand(size=logits.shape, dtype=logits.dtype, device=logits.device)
|
58 |
+
gumbel_noise = -torch.log(-torch.log(gumbel_noise))
|
59 |
+
_, indices = torch.topk(logits + gumbel_noise, k)
|
60 |
+
return indices
|
61 |
+
|
62 |
+
|
63 |
+
# https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L576
|
64 |
+
def pixelwise_insdis_loss(
|
65 |
+
pixel_feature: torch.Tensor,
|
66 |
+
gt_mask: torch.Tensor,
|
67 |
+
sample_temperature: float,
|
68 |
+
sample_k: int,
|
69 |
+
instance_discrimination_temperature: float,
|
70 |
+
pixel_gt_void_mask: torch.Tensor,
|
71 |
+
inverse_gt_mask_area: torch.Tensor
|
72 |
+
):
|
73 |
+
|
74 |
+
# pixel_feature: B x C x H x W
|
75 |
+
# gt_mask: B x N x H x W
|
76 |
+
pixel_feature = pixel_feature.flatten(2) # B x C x HW
|
77 |
+
gt_mask = gt_mask.flatten(2) # B x N x HW
|
78 |
+
pixel_gt_void_mask = pixel_gt_void_mask.flatten(1) # B x HW
|
79 |
+
inverse_gt_mask_area = inverse_gt_mask_area.flatten(1) # B x HW
|
80 |
+
|
81 |
+
sample_logits = torch.log(inverse_gt_mask_area) * sample_temperature # B x HW
|
82 |
+
# sample_logits.masked_fill_(pixel_gt_void_mask, float('-inf'))
|
83 |
+
sample_logits += pixel_gt_void_mask.to(sample_logits) * _SOFTMAX_MASKING_CONSTANT
|
84 |
+
|
85 |
+
sample_indices = _gumbel_topk_sample(sample_logits, sample_k) # B x K
|
86 |
+
# Sample ground truth one-hot encodings and compute gt_similarity.
|
87 |
+
pixel_gt_sampled_feature = torch.gather(gt_mask, dim=2, index=sample_indices.unsqueeze(1).repeat(1, gt_mask.shape[1], 1)) # B x N x K
|
88 |
+
sampled_gt_similarity = torch.einsum('bnk,bnj->bkj', pixel_gt_sampled_feature, pixel_gt_sampled_feature) # B x K x K
|
89 |
+
|
90 |
+
# Normalize the ground truth similarity into a distribution (sum to 1).
|
91 |
+
pixel_normalizing_constant = sampled_gt_similarity.sum(dim=1, keepdim=True) # B x 1 x K
|
92 |
+
sampled_gt_similarity /= torch.clamp(pixel_normalizing_constant, min=1.0) # B x K x K
|
93 |
+
|
94 |
+
# Sample predicted features and compute pred_similarity.
|
95 |
+
pixel_pred_sampled_feature = torch.gather(pixel_feature, dim=2, index=sample_indices.unsqueeze(1).repeat(1, pixel_feature.shape[1], 1)) # B x C x K
|
96 |
+
sampled_pred_similarity = torch.einsum('bck,bcj->bkj', pixel_pred_sampled_feature, pixel_pred_sampled_feature) # B x K x K
|
97 |
+
sampled_pred_similarity /= instance_discrimination_temperature # B x K x K
|
98 |
+
loss = F.cross_entropy(sampled_pred_similarity, sampled_gt_similarity, reduction="none") # B x K
|
99 |
+
|
100 |
+
num_non_zero = (loss != 0.0).to(loss).sum(-1) # B
|
101 |
+
num_non_zero = torch.clamp(num_non_zero, min=1.0)
|
102 |
+
loss_sum_per_sample = loss.sum(-1) # B
|
103 |
+
return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1
|
104 |
+
|
105 |
+
|
106 |
+
def aux_semantic_loss(
|
107 |
+
pred_semantic_logits: torch.Tensor,
|
108 |
+
ground_truth_semantic: torch.Tensor,
|
109 |
+
sample_temperature: float,
|
110 |
+
sample_k: int,
|
111 |
+
pixel_gt_void_mask: torch.Tensor,
|
112 |
+
inverse_gt_mask_area: torch.Tensor,
|
113 |
+
num_classes: int):
|
114 |
+
|
115 |
+
pred_semantic_logits = pred_semantic_logits.flatten(2) # B x C x HW
|
116 |
+
ground_truth_semantic = ground_truth_semantic.flatten(1) # B x HW
|
117 |
+
pixel_gt_void_mask = pixel_gt_void_mask.flatten(1) # B x HW
|
118 |
+
inverse_gt_mask_area = inverse_gt_mask_area.flatten(1) # B x HW
|
119 |
+
|
120 |
+
sample_logits = torch.log(inverse_gt_mask_area) * sample_temperature # B x HW
|
121 |
+
sample_logits += pixel_gt_void_mask.to(sample_logits) * _SOFTMAX_MASKING_CONSTANT
|
122 |
+
|
123 |
+
sample_indices = _gumbel_topk_sample(sample_logits, sample_k) # B x K
|
124 |
+
sampled_ground_truth_semantic = torch.gather(ground_truth_semantic, dim=1, index=sample_indices) # B x K
|
125 |
+
sampled_pred_semantic_logits = torch.gather(pred_semantic_logits, dim=2, index=sample_indices.unsqueeze(1).repeat(1, pred_semantic_logits.shape[1], 1)) # B x C x K
|
126 |
+
# ignore the class index num_classes.
|
127 |
+
keep_mask = (sampled_ground_truth_semantic != num_classes) # B x K
|
128 |
+
loss = F.cross_entropy(sampled_pred_semantic_logits, sampled_ground_truth_semantic, ignore_index=num_classes, reduction='none') # B x K
|
129 |
+
loss = loss * keep_mask.to(loss)
|
130 |
+
num_non_zero = (loss != 0.0).to(loss).sum(-1) # B
|
131 |
+
num_non_zero = torch.clamp(num_non_zero, min=1.0)
|
132 |
+
loss_sum_per_sample = loss.sum(-1) # B
|
133 |
+
return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1
|
134 |
+
|
135 |
+
|
136 |
+
# https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L56
|
137 |
+
# https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L510
|
138 |
+
def dice_loss(
|
139 |
+
inputs: torch.Tensor,
|
140 |
+
targets: torch.Tensor,
|
141 |
+
pixel_gt_void_mask: torch.Tensor,
|
142 |
+
matched_cls_prob: torch.Tensor
|
143 |
+
):
|
144 |
+
"""
|
145 |
+
Compute the DICE loss, similar to generalized IOU for masks
|
146 |
+
Args:
|
147 |
+
inputs: A float tensor of arbitrary shape.
|
148 |
+
The predictions for each example.
|
149 |
+
targets: A float tensor with the same shape as inputs. Stores the binary
|
150 |
+
classification label for each element in inputs
|
151 |
+
(0 for the negative class and 1 for the positive class).
|
152 |
+
"""
|
153 |
+
inputs = inputs.softmax(1) # B N HW
|
154 |
+
# https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L111
|
155 |
+
inputs = inputs.masked_fill(pixel_gt_void_mask.unsqueeze(1), 0) # remove void pixels.
|
156 |
+
smooth = 1.0
|
157 |
+
intersection = 2 * (inputs * targets).sum(-1) + smooth # B x N
|
158 |
+
denominator = inputs.sum(-1) + targets.sum(-1) + smooth # B x N
|
159 |
+
loss = 1.0 - divide_no_nan(intersection, denominator)
|
160 |
+
loss *= matched_cls_prob
|
161 |
+
# Note: kMaX-DeepLab sums over num_masks and averages over batches, but here batch and num_mask are one.
|
162 |
+
# https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L559
|
163 |
+
# https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/max_deeplab_loss.py#L402
|
164 |
+
# Because of the modifier, this is equivalent to multiplying by 0.75.
|
165 |
+
return (loss.sum(1) * 0.75/128).mean() # sum over masks and mean over batches.
|
166 |
+
|
167 |
+
|
168 |
+
def softmax_ce_loss(
|
169 |
+
inputs: torch.Tensor,
|
170 |
+
targets: torch.Tensor,
|
171 |
+
pixel_gt_void_mask: torch.Tensor,
|
172 |
+
):
|
173 |
+
"""
|
174 |
+
Args:
|
175 |
+
inputs: A float tensor of arbitrary shape.
|
176 |
+
The predictions for each example.
|
177 |
+
targets: A float tensor with the same shape as inputs. Stores the binary
|
178 |
+
classification label for each element in inputs
|
179 |
+
(0 for the negative class and 1 for the positive class).
|
180 |
+
Returns:
|
181 |
+
Loss tensor
|
182 |
+
"""
|
183 |
+
loss = F.cross_entropy(inputs, targets, reduction="none") # B x HW
|
184 |
+
loss = loss.masked_fill(pixel_gt_void_mask, 0) # remove void pixels.
|
185 |
+
|
186 |
+
num_non_zero = (loss != 0.0).to(loss).sum(-1) # B
|
187 |
+
num_non_zero = torch.clamp(num_non_zero, min=1.0)
|
188 |
+
loss_sum_per_sample = loss.sum(-1) # B
|
189 |
+
return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1
|
190 |
+
|
191 |
+
|
192 |
+
class SetCriterion(nn.Module):
|
193 |
+
"""This class computes the loss for DETR.
|
194 |
+
The process happens in two steps:
|
195 |
+
1) we compute a Hungarian assignment between the ground-truth masks and the outputs of the model
|
196 |
+
2) we supervise each pair of matched ground truth / prediction (supervising class and mask)
|
197 |
+
"""
|
198 |
+
|
199 |
+
def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, share_final_matching,
|
200 |
+
pixel_insdis_temperature=1.5, pixel_insdis_sample_k=4096,
|
201 |
+
aux_semantic_temperature=2.0, aux_semantic_sample_k=4096):
|
202 |
+
"""Create the criterion.
|
203 |
+
Parameters:
|
204 |
+
num_classes: number of object categories, omitting the special no-object category
|
205 |
+
matcher: module able to compute a matching between targets and proposals
|
206 |
+
eos_coef: relative classification weight applied to the no-object category
|
207 |
+
losses: list of all the losses to be applied. See get_loss for list of available losses.
|
208 |
+
"""
|
209 |
+
super().__init__()
|
210 |
+
self.num_classes = num_classes
|
211 |
+
self.matcher = matcher
|
212 |
+
self.weight_dict = weight_dict
|
213 |
+
self.eos_coef = eos_coef
|
214 |
+
self.losses = losses
|
215 |
+
self.share_final_matching = share_final_matching
|
216 |
+
self.pixel_insdis_temperature = pixel_insdis_temperature
|
217 |
+
self.pixel_insdis_sample_k = pixel_insdis_sample_k
|
218 |
+
self.aux_semantic_temperature = aux_semantic_temperature
|
219 |
+
self.aux_semantic_sample_k = aux_semantic_sample_k
|
220 |
+
|
221 |
+
def loss_labels(self, outputs, targets):
|
222 |
+
"""Classification loss (NLL)
|
223 |
+
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
|
224 |
+
"""
|
225 |
+
assert "pred_logits" in outputs
|
226 |
+
src_logits = outputs["pred_logits"] # B x N x C
|
227 |
+
target_classes = targets["labels"] # B x N
|
228 |
+
pq_loss_class_weight = targets["pq_loss_class_weight"]
|
229 |
+
losses = {"loss_ce": focal_cross_entropy_loss(src_logits, target_classes, pq_loss_class_weight)}
|
230 |
+
return losses
|
231 |
+
|
232 |
+
def loss_masks(self, outputs, targets):
|
233 |
+
"""Compute the losses related to the masks: the focal loss and the dice loss.
|
234 |
+
targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
|
235 |
+
"""
|
236 |
+
src_masks = outputs["pred_masks"] # B x N x H x W
|
237 |
+
target_masks = targets["masks"]
|
238 |
+
pq_loss_mask_weight = targets["pq_loss_mask_weight"]
|
239 |
+
pixel_gt_void_mask = targets["pixel_gt_void_mask"]
|
240 |
+
|
241 |
+
src_masks = src_masks.flatten(2) # B x N x HW
|
242 |
+
target_masks = target_masks.flatten(2) # B x N x HW
|
243 |
+
pixel_gt_void_mask = pixel_gt_void_mask.flatten(1) # B x HW
|
244 |
+
|
245 |
+
losses = {
|
246 |
+
"loss_mask": softmax_ce_loss(src_masks, target_masks, pixel_gt_void_mask),
|
247 |
+
"loss_dice": dice_loss(src_masks, target_masks, pixel_gt_void_mask, pq_loss_mask_weight),
|
248 |
+
}
|
249 |
+
|
250 |
+
return losses
|
251 |
+
|
252 |
+
def loss_pixels(self, outputs, targets):
|
253 |
+
pixel_feature = outputs["pixel_feature"]
|
254 |
+
target_masks = targets["masks"]
|
255 |
+
pixel_gt_void_mask = targets["pixel_gt_void_mask"]
|
256 |
+
inverse_gt_mask_area = targets["inverse_gt_mask_area"]
|
257 |
+
|
258 |
+
losses = {"loss_pixel_insdis": pixelwise_insdis_loss(
|
259 |
+
pixel_feature=pixel_feature,
|
260 |
+
gt_mask=target_masks,
|
261 |
+
sample_temperature=self.pixel_insdis_temperature,
|
262 |
+
sample_k=self.pixel_insdis_sample_k,
|
263 |
+
instance_discrimination_temperature=0.3,
|
264 |
+
pixel_gt_void_mask=pixel_gt_void_mask,
|
265 |
+
inverse_gt_mask_area=inverse_gt_mask_area
|
266 |
+
)}
|
267 |
+
|
268 |
+
del target_masks
|
269 |
+
return losses
|
270 |
+
|
271 |
+
def loss_semantic(self, outputs, targets):
|
272 |
+
pred_semantic_logits = outputs["aux_semantic_pred"]
|
273 |
+
ground_truth_semantic = targets["ground_truth_semantic"]
|
274 |
+
pixel_gt_void_mask = targets["pixel_gt_void_mask"].flatten(1)
|
275 |
+
inverse_gt_mask_area = targets["inverse_gt_mask_area"].flatten(1)
|
276 |
+
|
277 |
+
losses = {"loss_aux_semantic": aux_semantic_loss(
|
278 |
+
pred_semantic_logits=pred_semantic_logits,
|
279 |
+
ground_truth_semantic=ground_truth_semantic,
|
280 |
+
sample_temperature=self.aux_semantic_temperature,
|
281 |
+
sample_k=self.aux_semantic_sample_k,
|
282 |
+
pixel_gt_void_mask=pixel_gt_void_mask,
|
283 |
+
inverse_gt_mask_area=inverse_gt_mask_area,
|
284 |
+
num_classes=self.num_classes
|
285 |
+
)}
|
286 |
+
return losses
|
287 |
+
|
288 |
+
@torch.no_grad()
|
289 |
+
def _get_src_permutation_idx(self, indices):
|
290 |
+
# permute predictions following indices
|
291 |
+
# torch.full_like gives a tensor full of i in shape of src.shape
|
292 |
+
# at each iter, i is the index, src is the src ind in shape of (N)
|
293 |
+
# so batch_idx is concat of (0,0,...), (1,1,...), with shape (N0+N1+N2+...+Nb)
|
294 |
+
# so if we flatten gt/pred across batches, this gives the batch_id of each sample
|
295 |
+
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
|
296 |
+
# src_idx is src_ind concated to shape (N0+N1+N2+...+Nb)
|
297 |
+
# it is a flattened concat of mask_id at each batch
|
298 |
+
src_idx = torch.cat([src for (src, _) in indices])
|
299 |
+
return batch_idx, src_idx
|
300 |
+
|
301 |
+
|
302 |
+
def get_loss(self, loss, outputs, targets):
|
303 |
+
loss_map = {
|
304 |
+
'labels': self.loss_labels,
|
305 |
+
'masks': self.loss_masks,
|
306 |
+
'pixels': self.loss_pixels,
|
307 |
+
'aux_semantic': self.loss_semantic,
|
308 |
+
}
|
309 |
+
assert loss in loss_map, f"do you really want to compute {loss} loss?"
|
310 |
+
return loss_map[loss](outputs, targets)
|
311 |
+
|
312 |
+
@torch.no_grad()
|
313 |
+
def process_gt(self, outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=False):
|
314 |
+
# Permute & pad predictions for loss computation.
|
315 |
+
# By controlling process_gt, we can share the matching results for all preds.
|
316 |
+
src_idx = self._get_src_permutation_idx(indices)
|
317 |
+
|
318 |
+
src_masks = outputs["pred_masks"].detach() # B x N x H x W
|
319 |
+
|
320 |
+
# Pad and permute the target_mask to B x N x H x W
|
321 |
+
target_masks = torch.zeros_like(src_masks)
|
322 |
+
target_masks_o = torch.cat([t["masks"][J] for t, (_, J) in zip(targets, indices)]).to(target_masks)
|
323 |
+
target_masks[src_idx] = target_masks_o
|
324 |
+
|
325 |
+
# Pad and permute the matched_cls_prob to B x N
|
326 |
+
matched_cls_prob_o = torch.cat([cls_prob for cls_prob in matched_cls_prob])
|
327 |
+
matched_cls_prob_o = torch.clamp(matched_cls_prob_o, min=self.eos_coef)
|
328 |
+
# https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L1034
|
329 |
+
# no penalty for unmatched masks.
|
330 |
+
matched_cls_prob = torch.full(
|
331 |
+
src_masks.shape[:2], 0, dtype=src_masks.dtype, device=src_masks.device
|
332 |
+
) # B x N
|
333 |
+
matched_cls_prob[src_idx] = matched_cls_prob_o.to(matched_cls_prob)
|
334 |
+
|
335 |
+
# pixel_gt_void_mask is used to indicate those pixels without labels.
|
336 |
+
pixel_gt_void_mask = (target_masks.sum(1) < 1) # B x H x W
|
337 |
+
|
338 |
+
# inverse_gt_mask_area is used to sample pixels.
|
339 |
+
mask_gt_area = target_masks.sum(2).sum(2) # B x N
|
340 |
+
pixel_gt_area = torch.einsum('bnhw,bn->bhw', target_masks, mask_gt_area) # B x H x W
|
341 |
+
inverse_gt_mask_area = (pixel_gt_area.shape[1] * pixel_gt_area.shape[2]) / torch.clamp(pixel_gt_area, min=1.0) # B x H x W
|
342 |
+
|
343 |
+
src_logits = outputs["pred_logits"] # B x N x C
|
344 |
+
# Pad and permute the target_classes to B x N
|
345 |
+
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
|
346 |
+
# This serves as a padding.
|
347 |
+
target_classes = torch.full(
|
348 |
+
src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
|
349 |
+
)
|
350 |
+
# We put real GT to those corresponds to src_idx, and put void into other places.
|
351 |
+
target_classes[src_idx] = target_classes_o
|
352 |
+
|
353 |
+
src_masks_prob = src_masks.softmax(1)
|
354 |
+
void_mask = pixel_gt_void_mask.to(src_masks_prob) # B x H x W
|
355 |
+
# compute iou instead of dice for void overlapping.
|
356 |
+
def computer_iou_score(x, y):
|
357 |
+
# x : B x N x H x W
|
358 |
+
# y : B x H x W
|
359 |
+
x = x.flatten(2) # B x N x L
|
360 |
+
y = y.flatten(1) # B x L
|
361 |
+
intersection = torch.einsum('bnl,bl->bn', x, y) # B x N
|
362 |
+
denominator = x.sum(-1) # B x N
|
363 |
+
return intersection / (denominator + 1e-5) # B x N
|
364 |
+
|
365 |
+
# Pad and permute the matched_dice to B x N
|
366 |
+
matched_dice_o = torch.cat([dice for dice in matched_dice])
|
367 |
+
matched_dice = computer_iou_score(src_masks_prob, void_mask) # unmatched masks use their dice with void
|
368 |
+
matched_dice[src_idx] = matched_dice_o.to(matched_dice)
|
369 |
+
matched_dice = torch.clamp(matched_dice, min=self.eos_coef)
|
370 |
+
|
371 |
+
|
372 |
+
processed_gt = {"masks": target_masks, "labels": target_classes,
|
373 |
+
"pq_loss_mask_weight": matched_cls_prob,
|
374 |
+
"pq_loss_class_weight": matched_dice,
|
375 |
+
"pixel_gt_void_mask": pixel_gt_void_mask,
|
376 |
+
"inverse_gt_mask_area": inverse_gt_mask_area,}
|
377 |
+
|
378 |
+
if process_semantic:
|
379 |
+
# To obtain semantic gt
|
380 |
+
ground_truth_semantic = [t["semantic_masks"] for t in targets]
|
381 |
+
ground_truth_semantic = torch.stack(ground_truth_semantic, dim=0) # B x H x W
|
382 |
+
# self.num_classes is set to ignore label
|
383 |
+
ground_truth_semantic[ground_truth_semantic==-1] = self.num_classes
|
384 |
+
processed_gt.update({"ground_truth_semantic": ground_truth_semantic})
|
385 |
+
|
386 |
+
return processed_gt
|
387 |
+
|
388 |
+
|
389 |
+
def forward(self, outputs, targets):
|
390 |
+
"""This performs the loss computation.
|
391 |
+
Parameters:
|
392 |
+
outputs: dict of tensors, see the output specification of the model for the format
|
393 |
+
targets: list of dicts, such that len(targets) == batch_size.
|
394 |
+
The expected keys in each dict depend on the losses applied, see each loss' doc
|
395 |
+
"""
|
396 |
+
outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
|
397 |
+
indices, matched_dice, matched_cls_prob = self.matcher(outputs_without_aux, targets)
|
398 |
+
# Pad GT to the same number of prediction.
|
399 |
+
processed_targets = self.process_gt(outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=True)
|
400 |
+
# Compute all the requested losses
|
401 |
+
losses = {}
|
402 |
+
for loss in self.losses:
|
403 |
+
losses.update(self.get_loss(loss, outputs, processed_targets))
|
404 |
+
|
405 |
+
if "aux_outputs" in outputs:
|
406 |
+
for i, aux_outputs in enumerate(outputs["aux_outputs"]):
|
407 |
+
# We share matching results across predictions.
|
408 |
+
if not self.share_final_matching:
|
409 |
+
indices, matched_dice, matched_cls_prob = self.matcher(aux_outputs, targets)
|
410 |
+
if not self.share_final_matching:
|
411 |
+
processed_targets = self.process_gt(aux_outputs, targets, indices, matched_dice, matched_cls_prob)
|
412 |
+
for loss in self.losses:
|
413 |
+
if loss in ['aux_semantic']:
|
414 |
+
# Only for final output.
|
415 |
+
continue
|
416 |
+
l_dict = self.get_loss(loss, aux_outputs, processed_targets)
|
417 |
+
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
|
418 |
+
losses.update(l_dict)
|
419 |
+
return losses
|
420 |
+
|
421 |
+
def __repr__(self):
|
422 |
+
head = "Criterion " + self.__class__.__name__
|
423 |
+
body = [
|
424 |
+
"matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
|
425 |
+
"losses: {}".format(self.losses),
|
426 |
+
"weight_dict: {}".format(self.weight_dict),
|
427 |
+
"num_classes: {}".format(self.num_classes),
|
428 |
+
"eos_coef: {}".format(self.eos_coef),
|
429 |
+
]
|
430 |
+
_repr_indent = 4
|
431 |
+
lines = [head] + [" " * _repr_indent + line for line in body]
|
432 |
+
return "\n".join(lines)
|
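A rough smoke test (not part of the commit) showing the tensor formats the criterion above expects: `outputs` holds batched predictions, `targets` is a per-image list of dicts. All sizes, the class count and the small `sample_k` values are made-up assumptions, chosen so the Gumbel top-k sampling never asks for more points than there are pixels; the import paths are assumed from the file layout of this commit.

import torch
from kmax_deeplab.modeling.matcher import HungarianMatcher   # path assumed
from kmax_deeplab.modeling.criterion import SetCriterion     # path assumed

B, Q, C, H, W = 2, 8, 5, 32, 32   # batch, queries, thing+stuff classes, mask resolution
criterion = SetCriterion(
    num_classes=C,
    matcher=HungarianMatcher(),
    weight_dict={},                  # per-loss weights are applied outside the criterion
    eos_coef=1e-5,
    losses=["labels", "masks", "pixels", "aux_semantic"],
    share_final_matching=True,
    pixel_insdis_sample_k=256,       # must not exceed H * W
    aux_semantic_sample_k=256,
)

outputs = {
    "pred_logits": torch.randn(B, Q, C + 1),          # last channel is the void class
    "pred_masks": torch.randn(B, Q, H, W),
    "pixel_feature": torch.randn(B, 128, H, W),
    "aux_semantic_pred": torch.randn(B, C + 1, H, W),
}
targets = [
    {
        "labels": torch.randint(0, C, (3,)),                # 3 ground-truth masks per image
        "masks": (torch.rand(3, H, W) > 0.5).float(),
        "semantic_masks": torch.randint(-1, C, (H, W)),     # -1 marks ignored pixels
    }
    for _ in range(B)
]

losses = criterion(outputs, targets)  # loss_ce, loss_mask, loss_dice, loss_pixel_insdis, loss_aux_semantic
print({k: round(float(v), 4) for k, v in losses.items()})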
kmax_deeplab/modeling/matcher.py
ADDED
@@ -0,0 +1,128 @@
|
1 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py
|
2 |
+
# Reference: https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py
|
3 |
+
# Modified by Qihang Yu
|
4 |
+
|
5 |
+
"""
|
6 |
+
Modules to compute the matching cost and solve the corresponding LSAP.
|
7 |
+
"""
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from scipy.optimize import linear_sum_assignment
|
11 |
+
from torch import nn
|
12 |
+
from torch.cuda.amp import autocast
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
|
16 |
+
# https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/max_deeplab_loss.py#L158
|
17 |
+
@torch.no_grad()
|
18 |
+
def compute_mask_similarity(inputs: torch.Tensor, targets: torch.Tensor):
|
19 |
+
"""
|
20 |
+
Compute a DICE-style mask similarity (a matching score, not a loss), akin to a soft IoU between predicted and ground-truth masks.
|
21 |
+
Args:
|
22 |
+
inputs: A float tensor of arbitrary shape.
|
23 |
+
The predictions for each example.
|
24 |
+
targets: A float tensor with the same shape as inputs. Stores the binary
|
25 |
+
classification label for each element in inputs
|
26 |
+
(0 for the negative class and 1 for the positive class).
|
27 |
+
"""
|
28 |
+
denominator_epsilon = 1e-5
|
29 |
+
inputs = F.softmax(inputs, dim=0)
|
30 |
+
inputs = inputs.flatten(1) # N x HW
|
31 |
+
|
32 |
+
pixel_gt_non_void_mask = (targets.sum(0, keepdim=True) > 0).to(inputs) # 1xHW
|
33 |
+
inputs = inputs * pixel_gt_non_void_mask
|
34 |
+
|
35 |
+
intersection = torch.einsum("nc,mc->nm", inputs, targets)
|
36 |
+
denominator = (inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]) / 2.0
|
37 |
+
return intersection / (denominator + denominator_epsilon)
|
38 |
+
|
39 |
+
|
40 |
+
# https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/max_deeplab_loss.py#L941
|
41 |
+
@torch.no_grad()
|
42 |
+
def compute_class_similarity(inputs: torch.Tensor, targets: torch.Tensor):
|
43 |
+
pred_class_prob = inputs.softmax(-1)[..., :-1] # exclude the void class
|
44 |
+
return pred_class_prob[:, targets]
|
45 |
+
|
46 |
+
|
47 |
+
class HungarianMatcher(nn.Module):
|
48 |
+
"""This class computes an assignment between the targets and the predictions of the network
|
49 |
+
|
50 |
+
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
|
51 |
+
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
|
52 |
+
while the others are un-matched (and thus treated as non-objects).
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self):
|
56 |
+
"""Creates the matcher
|
57 |
+
|
58 |
+
Params:
|
59 |
+
cost_class: This is the relative weight of the classification error in the matching cost
|
60 |
+
cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
|
61 |
+
cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
|
62 |
+
"""
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
@torch.no_grad()
|
66 |
+
def memory_efficient_forward(self, outputs, targets):
|
67 |
+
"""More memory-friendly matching"""
|
68 |
+
bs, num_queries = outputs["pred_logits"].shape[:2]
|
69 |
+
|
70 |
+
indices = []
|
71 |
+
matched_dice = []
|
72 |
+
matched_cls_prob = []
|
73 |
+
# Iterate through batch size
|
74 |
+
for b in range(bs):
|
75 |
+
with autocast(enabled=False):
|
76 |
+
class_similarity = compute_class_similarity(outputs["pred_logits"][b].float(), targets[b]["labels"])
|
77 |
+
out_mask = outputs["pred_masks"][b].flatten(1) # [num_queries, H_pred, W_pred]
|
78 |
+
# gt masks are already padded when preparing target
|
79 |
+
tgt_mask = targets[b]["masks"].to(out_mask).flatten(1)
|
80 |
+
with autocast(enabled=False):
|
81 |
+
mask_similarity = compute_mask_similarity(out_mask.float(), tgt_mask.float())
|
82 |
+
|
83 |
+
# Final cost matrix
|
84 |
+
C = - mask_similarity * class_similarity
|
85 |
+
C = C.reshape(num_queries, -1).cpu() # N x M , N = num_queries, M = num_gt
|
86 |
+
|
87 |
+
# the assignment will be truncated to a square matrix.
|
88 |
+
row_ind, col_ind = linear_sum_assignment(C)
|
89 |
+
matched_dice.append(mask_similarity[row_ind, col_ind].detach())
|
90 |
+
matched_cls_prob.append(class_similarity[row_ind, col_ind].detach())
|
91 |
+
indices.append((row_ind, col_ind)) # row_ind and col_ind, row_ind = 0,1,2,3,...,N-1, col_ind = a,b,c,d,...
|
92 |
+
|
93 |
+
indices = [
|
94 |
+
(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
|
95 |
+
for i, j in indices
|
96 |
+
]
|
97 |
+
|
98 |
+
return indices, matched_dice, matched_cls_prob
|
99 |
+
|
100 |
+
|
101 |
+
@torch.no_grad()
|
102 |
+
def forward(self, outputs, targets):
|
103 |
+
"""Performs the matching
|
104 |
+
|
105 |
+
Params:
|
106 |
+
outputs: This is a dict that contains at least these entries:
|
107 |
+
"pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
108 |
+
"pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
|
109 |
+
|
110 |
+
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
|
111 |
+
"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
|
112 |
+
objects in the target) containing the class labels
|
113 |
+
"masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
A list of size batch_size, containing tuples of (index_i, index_j) where:
|
117 |
+
- index_i is the indices of the selected predictions (in order)
|
118 |
+
- index_j is the indices of the corresponding selected targets (in order)
|
119 |
+
For each batch element, it holds:
|
120 |
+
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
121 |
+
"""
|
122 |
+
return self.memory_efficient_forward(outputs, targets)
|
123 |
+
|
124 |
+
def __repr__(self, _repr_indent=4):
|
125 |
+
head = "Matcher " + self.__class__.__name__
|
126 |
+
body = []
|
127 |
+
lines = [head] + [" " * _repr_indent + line for line in body]
|
128 |
+
return "\n".join(lines)
|
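A tiny worked example (not from the commit) of the assignment step: with more predictions than ground-truth masks, `linear_sum_assignment` returns min(num_queries, num_gt) matched pairs, and the remaining predictions stay unmatched, which the criterion then treats as no-object.

import numpy as np
from scipy.optimize import linear_sum_assignment

# similarity (higher is better) between 4 predictions and 2 ground-truth masks
similarity = np.array([[0.9, 0.1],
                       [0.4, 0.8],
                       [0.2, 0.7],
                       [0.6, 0.3]])
row_ind, col_ind = linear_sum_assignment(-similarity)  # negate: maximize total similarity
print(row_ind, col_ind)  # [0 1] [0 1] -> prediction 0 <-> gt 0, prediction 1 <-> gt 1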
kmax_deeplab/modeling/meta_arch/__init__.py
ADDED
File without changes
|
kmax_deeplab/modeling/meta_arch/kmax_deeplab_head.py
ADDED
@@ -0,0 +1,88 @@
|
1 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/meta_arch/mask_former_head.py
|
2 |
+
# Modified by Qihang Yu
|
3 |
+
|
4 |
+
from typing import Dict
|
5 |
+
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from detectron2.config import configurable
|
10 |
+
from detectron2.layers import ShapeSpec
|
11 |
+
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
|
12 |
+
|
13 |
+
from ..transformer_decoder.kmax_transformer_decoder import build_transformer_decoder
|
14 |
+
|
15 |
+
|
16 |
+
def build_pixel_decoder(cfg, input_shape):
|
17 |
+
"""
|
18 |
+
Build a pixel decoder from `cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.NAME`.
|
19 |
+
"""
|
20 |
+
name = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.NAME
|
21 |
+
model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
|
22 |
+
forward_features = getattr(model, "forward_features", None)
|
23 |
+
if not callable(forward_features):
|
24 |
+
raise ValueError(
|
25 |
+
"Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
|
26 |
+
f"Please implement forward_features for {name} to only return mask features."
|
27 |
+
)
|
28 |
+
return model
|
29 |
+
|
30 |
+
|
31 |
+
@SEM_SEG_HEADS_REGISTRY.register()
|
32 |
+
class kMaXDeepLabHead(nn.Module):
|
33 |
+
|
34 |
+
@configurable
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
input_shape: Dict[str, ShapeSpec],
|
38 |
+
*,
|
39 |
+
num_classes: int,
|
40 |
+
pixel_decoder: nn.Module,
|
41 |
+
loss_weight: float = 1.0,
|
42 |
+
ignore_value: int = -1,
|
43 |
+
transformer_predictor: nn.Module,
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
NOTE: this interface is experimental.
|
47 |
+
Args:
|
48 |
+
input_shape: shapes (channels and stride) of the input features
|
49 |
+
num_classes: number of classes to predict
|
50 |
+
pixel_decoder: the pixel decoder module
|
51 |
+
loss_weight: loss weight
|
52 |
+
ignore_value: category id to be ignored during training.
|
53 |
+
transformer_predictor: the transformer decoder that makes prediction
|
54 |
+
transformer_in_feature: input feature name to the transformer_predictor
|
55 |
+
"""
|
56 |
+
super().__init__()
|
57 |
+
input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
|
58 |
+
self.in_features = [k for k, v in input_shape]
|
59 |
+
|
60 |
+
self.ignore_value = ignore_value
|
61 |
+
self.common_stride = 4
|
62 |
+
self.loss_weight = loss_weight
|
63 |
+
|
64 |
+
self.pixel_decoder = pixel_decoder
|
65 |
+
self.predictor = transformer_predictor
|
66 |
+
|
67 |
+
self.num_classes = num_classes
|
68 |
+
|
69 |
+
@classmethod
|
70 |
+
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
|
71 |
+
return {
|
72 |
+
"input_shape": {
|
73 |
+
k: v for k, v in input_shape.items() if k in cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.IN_FEATURES
|
74 |
+
},
|
75 |
+
"ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
|
76 |
+
"num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
|
77 |
+
"pixel_decoder": build_pixel_decoder(cfg, input_shape),
|
78 |
+
"loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
|
79 |
+
"transformer_predictor": build_transformer_decoder(cfg, input_shape),
|
80 |
+
}
|
81 |
+
|
82 |
+
def forward(self, features):
|
83 |
+
return self.layers(features)
|
84 |
+
|
85 |
+
def layers(self, features):
|
86 |
+
panoptic_features, semantic_features, multi_scale_features = self.pixel_decoder.forward_features(features)
|
87 |
+
predictions = self.predictor(multi_scale_features, panoptic_features, semantic_features)
|
88 |
+
return predictions
|
kmax_deeplab/modeling/pixel_decoder/__init__.py
ADDED
File without changes
|
kmax_deeplab/modeling/pixel_decoder/kmax_pixel_decoder.py
ADDED
@@ -0,0 +1,370 @@
|
1 |
+
# Reference: https://github.com/google-research/deeplab2/blob/main/model/pixel_decoder/kmax.py
|
2 |
+
# Modified by Qihang Yu
|
3 |
+
|
4 |
+
from typing import Dict, List
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
from timm.models.layers import DropPath
|
11 |
+
from timm.models.layers import trunc_normal_tf_ as trunc_normal_
|
12 |
+
|
13 |
+
from detectron2.config import configurable
|
14 |
+
from detectron2.layers import ShapeSpec
|
15 |
+
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
|
16 |
+
from torch.cuda.amp import autocast
|
17 |
+
|
18 |
+
from ..backbone.convnext import LayerNorm
|
19 |
+
|
20 |
+
import math
|
21 |
+
|
22 |
+
|
23 |
+
def get_activation(name):
|
24 |
+
if name is None or name.lower() == 'none':
|
25 |
+
return nn.Identity()
|
26 |
+
if name == 'relu':
|
27 |
+
return nn.ReLU()
|
28 |
+
elif name == 'gelu':
|
29 |
+
return nn.GELU()
|
30 |
+
|
31 |
+
|
32 |
+
def get_norm(name, channels):
|
33 |
+
if name is None or name.lower() == 'none':
|
34 |
+
return nn.Identity()
|
35 |
+
|
36 |
+
if name.lower() == 'syncbn':
|
37 |
+
return nn.SyncBatchNorm(channels, eps=1e-3, momentum=0.01)
|
38 |
+
|
39 |
+
|
40 |
+
class ConvBN(nn.Module):
|
41 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, norm=None, act=None,
|
42 |
+
conv_type='2d', conv_init='he_normal', norm_init=1.0):
|
43 |
+
super().__init__()
|
44 |
+
|
45 |
+
if conv_type == '2d':
|
46 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
|
47 |
+
elif conv_type == '1d':
|
48 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
|
49 |
+
|
50 |
+
self.norm = get_norm(norm, out_channels)
|
51 |
+
self.act = get_activation(act)
|
52 |
+
|
53 |
+
if conv_init == 'normal':
|
54 |
+
nn.init.normal_(self.conv.weight, std=.02)
|
55 |
+
elif conv_init == 'trunc_normal':
|
56 |
+
trunc_normal_(self.conv.weight, std=.02)
|
57 |
+
elif conv_init == 'he_normal':
|
58 |
+
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/HeNormal
|
59 |
+
trunc_normal_(self.conv.weight, std=math.sqrt(2.0 / in_channels))
|
60 |
+
elif conv_init == 'xavier_uniform':
|
61 |
+
nn.init.xavier_uniform_(self.conv.weight)
|
62 |
+
if bias:
|
63 |
+
nn.init.zeros_(self.conv.bias)
|
64 |
+
|
65 |
+
if norm is not None:
|
66 |
+
nn.init.constant_(self.norm.weight, norm_init)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
return self.act(self.norm(self.conv(x)))
|
70 |
+
|
71 |
+
|
72 |
+
MAX_SPAN = 255
|
73 |
+
def _compute_relative_distance_matrix(query_length, key_length):
|
74 |
+
if (key_length - query_length) % 2:
|
75 |
+
raise ValueError('Key_length should be query_length + 2 * memory_flange.')
|
76 |
+
key_index = torch.arange(key_length)
|
77 |
+
query_index = torch.arange(query_length) + (key_length - query_length) // 2
|
78 |
+
distance_matrix = key_index[None, :] - query_index[:, None]
|
79 |
+
# Shift the distance_matrix so that it is >= 0. Each entry of the
|
80 |
+
# distance_matrix distance will index a relative positional embedding.
|
81 |
+
distance_matrix = distance_matrix + MAX_SPAN - 1
|
82 |
+
return distance_matrix
|
83 |
+
|
84 |
+
|
85 |
+
class RelativePositionalEncoding(nn.Module):
|
86 |
+
def __init__(self, query_length, key_length, depth):
|
87 |
+
super().__init__()
|
88 |
+
self._embeddings = nn.Embedding(MAX_SPAN * 2 - 1, depth)
|
89 |
+
trunc_normal_(self._embeddings.weight, std=1.0)
|
90 |
+
self._relative_distance_matrix = _compute_relative_distance_matrix(query_length, key_length)
|
91 |
+
self.query_length = query_length
|
92 |
+
self.key_length = key_length
|
93 |
+
self.depth = depth
|
94 |
+
|
95 |
+
def forward(self):
|
96 |
+
return self._embeddings.weight[self._relative_distance_matrix.reshape(-1)].reshape(self.query_length, self.key_length, self.depth)
|
97 |
+
|
98 |
+
|
99 |
+
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_layers.py#L36
|
100 |
+
class AxialAttention(nn.Module):
|
101 |
+
def __init__(self, in_planes, query_shape=56, total_key_depth=512, total_value_depth=1024, num_heads=8):
|
102 |
+
assert (total_key_depth % num_heads == 0) and (total_value_depth % num_heads == 0)
|
103 |
+
super().__init__()
|
104 |
+
self._in_planes = in_planes
|
105 |
+
self._query_shape = query_shape
|
106 |
+
self._total_key_depth = total_key_depth
|
107 |
+
self._total_value_depth = total_value_depth
|
108 |
+
self._num_heads = num_heads
|
109 |
+
self._key_depth_per_head = total_key_depth // num_heads
|
110 |
+
|
111 |
+
self.qkv_transform = ConvBN(in_planes, self._total_key_depth * 2 + self._total_value_depth, kernel_size=1, stride=1,
|
112 |
+
padding=0, bias=False, norm=None, act=None, conv_type='1d')
|
113 |
+
trunc_normal_(self.qkv_transform.conv.weight, std=in_planes ** -0.5)
|
114 |
+
|
115 |
+
self._query_rpe = RelativePositionalEncoding(query_shape, query_shape, self._key_depth_per_head)
|
116 |
+
self._key_rpe = RelativePositionalEncoding(query_shape, query_shape, self._key_depth_per_head)
|
117 |
+
self._value_rpe = RelativePositionalEncoding(query_shape, query_shape, total_value_depth // num_heads)
|
118 |
+
|
119 |
+
self._batch_norm_qkv = get_norm('syncbn', self._total_key_depth * 2 + self._total_value_depth)
|
120 |
+
self._batch_norm_similarity = get_norm('syncbn', num_heads * 3)
|
121 |
+
self._batch_norm_retrieved_output = get_norm('syncbn', self._total_value_depth * 2)
|
122 |
+
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
N, C, L = x.shape
|
126 |
+
qkv = self._batch_norm_qkv(self.qkv_transform(x))
|
127 |
+
q, k, v = torch.split(qkv, [self._total_key_depth, self._total_key_depth, self._total_value_depth], dim=1)
|
128 |
+
q = q.reshape(N, self._num_heads, self._total_key_depth // self._num_heads, L)
|
129 |
+
k = k.reshape(N, self._num_heads, self._total_key_depth // self._num_heads, L)
|
130 |
+
v = v.reshape(N, self._num_heads, self._total_value_depth // self._num_heads, L)
|
131 |
+
|
132 |
+
similarity_logits = []
|
133 |
+
content_similarity = torch.einsum('bhdl,bhdm->bhlm', q, k)
|
134 |
+
query_rpe = self._query_rpe()
|
135 |
+
query_rpe_similarity = torch.einsum('bhdl,lmd->bhlm', q, query_rpe)
|
136 |
+
key_rpe = self._key_rpe()
|
137 |
+
key_rpe_similarity = torch.einsum('bhdm,lmd->bhlm', k, key_rpe)
|
138 |
+
similarity_logits = torch.cat([content_similarity, query_rpe_similarity, key_rpe_similarity], dim=1)
|
139 |
+
similarity_logits = self._batch_norm_similarity(similarity_logits).reshape(N, 3, self._num_heads, L, L).sum(dim=1)
|
140 |
+
|
141 |
+
with autocast(enabled=False):
|
142 |
+
weights = F.softmax(similarity_logits.float(), dim=-1)
|
143 |
+
|
144 |
+
retrieved_content = torch.einsum('bhlm,bhdm->bhdl', weights, v)
|
145 |
+
value_rpe = self._value_rpe()
|
146 |
+
retrieved_rpe = torch.einsum('bhlm,lmd->bhdl', weights, value_rpe)
|
147 |
+
|
148 |
+
retrieved_output = torch.cat([retrieved_content, retrieved_rpe], dim=1).reshape(N, 2*self._total_value_depth, L)
|
149 |
+
retrieved_output = self._batch_norm_retrieved_output(retrieved_output).reshape(N, 2, self._total_value_depth, L).sum(1)
|
150 |
+
|
151 |
+
return retrieved_output
|
152 |
+
|
153 |
+
|
154 |
+
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_layers.py#L316
|
155 |
+
class AxialAttention2D(nn.Module):
|
156 |
+
def __init__(self, in_planes, query_shape=[56, 56], filters=512, key_expansion=1, value_expansion=2, num_heads=8):
|
157 |
+
super().__init__()
|
158 |
+
total_key_depth = int(round(filters * key_expansion))
|
159 |
+
total_value_depth = int(round(filters * value_expansion))
|
160 |
+
self._total_key_depth = total_key_depth
|
161 |
+
self._total_value_depth = total_value_depth
|
162 |
+
self._height_axis = AxialAttention(
|
163 |
+
in_planes=in_planes,
|
164 |
+
query_shape=query_shape[0],
|
165 |
+
total_key_depth=total_key_depth,
|
166 |
+
total_value_depth=total_value_depth,
|
167 |
+
num_heads=num_heads)
|
168 |
+
self._width_axis = AxialAttention(
|
169 |
+
in_planes=total_value_depth,
|
170 |
+
query_shape=query_shape[1],
|
171 |
+
total_key_depth=total_key_depth,
|
172 |
+
total_value_depth=total_value_depth,
|
173 |
+
num_heads=num_heads)
|
174 |
+
|
175 |
+
def forward(self, x):
|
176 |
+
# N C H W -> N W C H
|
177 |
+
N, C, H, W = x.shape
|
178 |
+
x = x.permute(0, 3, 1, 2).contiguous()
|
179 |
+
x = x.reshape(N*W, C, H)
|
180 |
+
x = self._height_axis(x)
|
181 |
+
# N W C H -> N H C W
|
182 |
+
x = x.reshape(N, W, self._total_value_depth, H).permute(0, 3, 2, 1).contiguous()
|
183 |
+
x = x.reshape(N*H, self._total_value_depth, W)
|
184 |
+
x = self._width_axis(x)
|
185 |
+
x = x.reshape(N, H, self._total_value_depth, W).permute(0, 2, 1, 3).contiguous()
|
186 |
+
x = x.reshape(N, self._total_value_depth, H, W)
|
187 |
+
return x
|
188 |
+
|
189 |
+
|
190 |
+
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_blocks.py#L36
|
191 |
+
class SingleBlock(nn.Module):
|
192 |
+
|
193 |
+
def __init__(self, inplanes, filter_list, block_type, query_shape=[56, 56], key_expansion=1, value_expansion=2, num_heads=8, drop_path_prob=0.0):
|
194 |
+
super(SingleBlock, self).__init__()
|
195 |
+
self._block_type = block_type.lower()
|
196 |
+
self._filter_list = filter_list
|
197 |
+
self._conv1_bn_act = ConvBN(inplanes, self._filter_list[0], kernel_size=1, bias=False, norm='syncbn', act='gelu')
|
198 |
+
if self._block_type == 'axial':
|
199 |
+
self._attention = AxialAttention2D(in_planes=self._filter_list[0], query_shape=query_shape, filters=self._filter_list[1],
|
200 |
+
key_expansion=key_expansion, value_expansion=value_expansion, num_heads=num_heads)
|
201 |
+
output_channel = filter_list[1] * value_expansion
|
202 |
+
elif self._block_type == 'bottleneck':
|
203 |
+
self._conv2_bn_act = ConvBN(self._filter_list[0], self._filter_list[1], kernel_size=3, padding=1, bias=False, norm='syncbn', act='gelu')
|
204 |
+
output_channel = filter_list[1]
|
205 |
+
self._conv3_bn = ConvBN(output_channel, self._filter_list[2], kernel_size=1, bias=False, norm='syncbn', act=None, norm_init=0.0)
|
206 |
+
|
207 |
+
self._shortcut = None
|
208 |
+
if inplanes != self._filter_list[-1]:
|
209 |
+
self._shortcut = ConvBN(inplanes, self._filter_list[-1], kernel_size=1, bias=False, norm='syncbn', act=None)
|
210 |
+
self.drop_path = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()
|
211 |
+
|
212 |
+
def forward(self, x):
|
213 |
+
x = F.gelu(x)
|
214 |
+
|
215 |
+
shortcut = x
|
216 |
+
if self._shortcut is not None:
|
217 |
+
shortcut = self._shortcut(shortcut)
|
218 |
+
|
219 |
+
x = self._conv1_bn_act(x)
|
220 |
+
if self._block_type == 'axial':
|
221 |
+
x = self._attention(x)
|
222 |
+
x = F.gelu(x)
|
223 |
+
elif self._block_type == 'bottleneck':
|
224 |
+
x = self._conv2_bn_act(x)
|
225 |
+
x = self._conv3_bn(x)
|
226 |
+
|
227 |
+
x = self.drop_path(x) + shortcut
|
228 |
+
|
229 |
+
return x
|
230 |
+
|
231 |
+
|
232 |
+
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L42
|
233 |
+
class BlockGroup(nn.Module):
|
234 |
+
def __init__(self, inplanes, base_filter, num_blocks, block_type, **kwargs):
|
235 |
+
super().__init__()
|
236 |
+
self._num_blocks = num_blocks
|
237 |
+
block_type = block_type.lower()
|
238 |
+
if block_type == 'axial':
|
239 |
+
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L247
|
240 |
+
filter_list = [base_filter * 2, base_filter, base_filter * 4]
|
241 |
+
elif block_type == 'bottleneck':
|
242 |
+
# https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L250
|
243 |
+
filter_list = [base_filter, base_filter, base_filter * 4]
|
244 |
+
|
245 |
+
self._blocks = nn.ModuleList()
|
246 |
+
for i in range(num_blocks):
|
247 |
+
self._blocks.append(SingleBlock(inplanes=inplanes, filter_list=filter_list, block_type=block_type, **kwargs))
|
248 |
+
inplanes = filter_list[-1]
|
249 |
+
|
250 |
+
def forward(self, x):
|
251 |
+
for i in range(self._num_blocks):
|
252 |
+
x = self._blocks[i](x)
|
253 |
+
return x
|
254 |
+
|
255 |
+
|
256 |
+
# https://github.com/google-research/deeplab2/blob/7a01a7165e97b3325ad7ea9b6bcc02d67fecd07a/model/layers/resized_fuse.py#L31
|
257 |
+
class ResizedFuse(nn.Module):
|
258 |
+
def __init__(self, low_in_channels, high_in_channels, out_channels):
|
259 |
+
super().__init__()
|
260 |
+
self.low_in_channels = low_in_channels
|
261 |
+
self.high_in_channels = high_in_channels
|
262 |
+
self.out_channels = out_channels
|
263 |
+
if low_in_channels != out_channels:
|
264 |
+
self._conv_bn_low = ConvBN(low_in_channels, out_channels, kernel_size=1, bias=False, norm='syncbn', act=None)
|
265 |
+
if high_in_channels != out_channels:
|
266 |
+
self._conv_bn_high = ConvBN(high_in_channels, out_channels, kernel_size=1, bias=False, norm='syncbn', act=None)
|
267 |
+
|
268 |
+
def forward(self, lowres_x, highres_x):
|
269 |
+
|
270 |
+
align_corners = (lowres_x.shape[-1] % 2 == 1)
|
271 |
+
if self.low_in_channels != self.out_channels:
|
272 |
+
lowres_x = F.gelu(lowres_x)
|
273 |
+
lowres_x = self._conv_bn_low(lowres_x)
|
274 |
+
lowres_x = F.interpolate(lowres_x, size=highres_x.shape[2:], mode='bilinear', align_corners=align_corners)
|
275 |
+
else:
|
276 |
+
lowres_x = F.interpolate(lowres_x, size=highres_x.shape[2:], mode='bilinear', align_corners=align_corners)
|
277 |
+
|
278 |
+
if self.high_in_channels != self.out_channels:
|
279 |
+
highres_x = F.gelu(highres_x)
|
280 |
+
highres_x = self._conv_bn_high(highres_x)
|
281 |
+
|
282 |
+
return lowres_x + highres_x
|
283 |
+
|
284 |
+
|
285 |
+
@SEM_SEG_HEADS_REGISTRY.register()
|
286 |
+
class kMaXPixelDecoder(nn.Module):
|
287 |
+
@configurable
|
288 |
+
def __init__(
|
289 |
+
self,
|
290 |
+
input_shape: Dict[str, ShapeSpec],
|
291 |
+
*,
|
292 |
+
dec_layers: List[int],
|
293 |
+
dec_channels: List[int],
|
294 |
+
layer_types: List[str],
|
295 |
+
drop_path_prob: float,
|
296 |
+
spatial_shape: List[int],
|
297 |
+
):
|
298 |
+
"""
|
299 |
+
NOTE: this interface is experimental.
|
300 |
+
Args:
|
301 |
+
"""
|
302 |
+
super().__init__()
|
303 |
+
self.num_stages = len(input_shape)
|
304 |
+
assert self.num_stages == len(dec_layers) and self.num_stages == len(dec_channels) and self.num_stages == len(layer_types)
|
305 |
+
# For now, we hard code all hyper-parameters.
|
306 |
+
block_types = ['axial', 'axial', 'bottleneck', 'bottleneck']
|
307 |
+
input_shape = sorted(input_shape.items(), key=lambda x: -x[1].stride)
|
308 |
+
self.in_features = [k for k, v in input_shape] # starting from "res5" to "res2"
|
309 |
+
in_channels = [v.channels for k, v in input_shape]
|
310 |
+
|
311 |
+
add_one = (spatial_shape[0] % 2, spatial_shape[1] % 2)
|
312 |
+
query_shape = [
|
313 |
+
(spatial_shape[0]//32+add_one[0], spatial_shape[1]//32+add_one[1]),
|
314 |
+
(spatial_shape[0]//16+add_one[0], spatial_shape[1]//16+add_one[1]),
|
315 |
+
(spatial_shape[0]//8+add_one[0], spatial_shape[1]//8+add_one[1]),
|
316 |
+
(spatial_shape[0]//4+add_one[0], spatial_shape[1]//4+add_one[1])]
|
317 |
+
|
318 |
+
self._in_norms = nn.ModuleList()
|
319 |
+
self._stages = nn.ModuleList()
|
320 |
+
self._resized_fuses = nn.ModuleList()
|
321 |
+
|
322 |
+
for i in range(self.num_stages):
|
323 |
+
self._in_norms.append(LayerNorm(in_channels[i], data_format="channels_first"))
|
324 |
+
inplanes = in_channels[i] if i == 0 else dec_channels[i]
|
325 |
+
self._stages.append(BlockGroup(inplanes=inplanes,
|
326 |
+
base_filter=dec_channels[i], num_blocks=dec_layers[i], block_type=block_types[i],
|
327 |
+
query_shape=query_shape[i], key_expansion=1, value_expansion=2, num_heads=8, drop_path_prob=0.0))
|
328 |
+
|
329 |
+
if i > 0:
|
330 |
+
self._resized_fuses.append(ResizedFuse(
|
331 |
+
low_in_channels=dec_channels[i-1] * 4,
|
332 |
+
high_in_channels=in_channels[i],
|
333 |
+
out_channels=dec_channels[i]))
|
334 |
+
|
335 |
+
|
336 |
+
@classmethod
|
337 |
+
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
|
338 |
+
ret = {}
|
339 |
+
ret["input_shape"] = {
|
340 |
+
k: v for k, v in input_shape.items() if k in cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.IN_FEATURES
|
341 |
+
}
|
342 |
+
ret["dec_layers"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_LAYERS
|
343 |
+
ret["dec_channels"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_CHANNELS
|
344 |
+
ret["layer_types"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.LAYER_TYPES
|
345 |
+
ret["drop_path_prob"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DROP_PATH_PROB
|
346 |
+
ret["spatial_shape"] = cfg.INPUT.IMAGE_SIZE # We expect the height == width
|
347 |
+
return ret
|
348 |
+
|
349 |
+
|
350 |
+
def forward_features(self, features):
|
351 |
+
out = []
|
352 |
+
multi_scale_features = []
|
353 |
+
|
354 |
+
x = self._in_norms[0](features[self.in_features[0]])
|
355 |
+
|
356 |
+
for idx in range(self.num_stages - 1):
|
357 |
+
x = self._stages[idx](x)
|
358 |
+
out.append(x)
|
359 |
+
x = self._resized_fuses[idx](
|
360 |
+
lowres_x=x,
|
361 |
+
highres_x=self._in_norms[idx+1](features[self.in_features[idx+1]]))
|
362 |
+
|
363 |
+
x = self._stages[-1](x)
|
364 |
+
out.append(x)
|
365 |
+
multi_scale_features = out[:3] # OS32, 16, 8, they are used for kmax_transformer_decoder.
|
366 |
+
panoptic_features = out[-1] # OS4, it is used for final mask prediction.
|
367 |
+
# OS 32, 8, 4
|
368 |
+
semantic_features = [features[self.in_features[0]], features[self.in_features[2]], features[self.in_features[3]]]
|
369 |
+
return panoptic_features, semantic_features, multi_scale_features
|
370 |
+
|
kmax_deeplab/modeling/transformer_decoder/__init__.py
ADDED
@@ -0,0 +1 @@
|
1 |
+
from .kmax_transformer_decoder import kMaXTransformerDecoder
|
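A small worked example (not from the commit) of the relative-position indexing used by `RelativePositionalEncoding` in the pixel decoder above: every (query, key) pair is mapped to the offset `MAX_SPAN - 1 + (key - query)` into a shared embedding table of size `2 * MAX_SPAN - 1`, so equal positions always hit index 254 and immediate neighbours hit 253 / 255.

import torch

MAX_SPAN = 255                      # mirrors the constant in kmax_pixel_decoder.py
query_length = key_length = 4       # toy lengths for illustration
key_index = torch.arange(key_length)
query_index = torch.arange(query_length) + (key_length - query_length) // 2
distance_matrix = key_index[None, :] - query_index[:, None] + MAX_SPAN - 1
print(distance_matrix)
# tensor([[254, 255, 256, 257],
#         [253, 254, 255, 256],
#         [252, 253, 254, 255],
#         [251, 252, 253, 254]])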
kmax_deeplab/modeling/transformer_decoder/kmax_transformer_decoder.py
ADDED
@@ -0,0 +1,453 @@
|
# Reference: https://github.com/google-research/deeplab2/blob/main/model/transformer_decoder/kmax.py
# Modified by Qihang Yu

from typing import List
import torch
from torch import nn
from torch.nn import functional as F
from torch.cuda.amp import autocast

from timm.models.layers import DropPath
from timm.models.layers import trunc_normal_tf_ as trunc_normal_

from detectron2.config import configurable
from detectron2.utils.registry import Registry

from ..pixel_decoder.kmax_pixel_decoder import get_norm, ConvBN

import math


TRANSFORMER_DECODER_REGISTRY = Registry("TRANSFORMER_MODULE")
TRANSFORMER_DECODER_REGISTRY.__doc__ = """
Registry for transformer module.
"""


def build_transformer_decoder(cfg, input_shape_from_backbone):
    """
    Build an instance embedding branch from `cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NAME`.
    """
    name = cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NAME
    return TRANSFORMER_DECODER_REGISTRY.get(name)(cfg, input_shape_from_backbone)


# https://github.com/google-research/deeplab2/blob/7a01a7165e97b3325ad7ea9b6bcc02d67fecd07a/model/decoder/max_deeplab.py#L60
def add_bias_towards_void(query_class_logits, void_prior_prob=0.9):
    class_logits_shape = query_class_logits.shape
    init_bias = [0.0] * class_logits_shape[-1]
    init_bias[-1] = math.log(
        (class_logits_shape[-1] - 1) * void_prior_prob / (1 - void_prior_prob))
    return query_class_logits + torch.tensor(init_bias, dtype=query_class_logits.dtype).to(query_class_logits)


# https://github.com/google-research/deeplab2/blob/7a01a7165e97b3325ad7ea9b6bcc02d67fecd07a/model/layers/dual_path_transformer.py#L41
class AttentionOperation(nn.Module):
    def __init__(self, channels_v, num_heads):
        super().__init__()
        self._batch_norm_similarity = get_norm('syncbn', num_heads)
        self._batch_norm_retrieved_value = get_norm('syncbn', channels_v)

    def forward(self, query, key, value):
        N, _, _, L = query.shape
        _, num_heads, C, _ = value.shape
        similarity_logits = torch.einsum('bhdl,bhdm->bhlm', query, key)
        similarity_logits = self._batch_norm_similarity(similarity_logits)

        with autocast(enabled=False):
            attention_weights = F.softmax(similarity_logits.float(), dim=-1)
        retrieved_value = torch.einsum(
            'bhlm,bhdm->bhdl', attention_weights, value)
        retrieved_value = retrieved_value.reshape(N, num_heads * C, L)
        retrieved_value = self._batch_norm_retrieved_value(
            retrieved_value)
        retrieved_value = F.gelu(retrieved_value)
        return retrieved_value


# https://github.com/google-research/deeplab2/blob/main/model/kmax_deeplab.py#L32
class kMaXPredictor(nn.Module):
    def __init__(self, in_channel_pixel, in_channel_query, num_classes=133+1):
        super().__init__()
        self._pixel_space_head_conv0bnact = ConvBN(in_channel_pixel, in_channel_pixel, kernel_size=5, groups=in_channel_pixel, padding=2, bias=False,
                                                   norm='syncbn', act='gelu', conv_init='xavier_uniform')
        self._pixel_space_head_conv1bnact = ConvBN(in_channel_pixel, 256, kernel_size=1, bias=False, norm='syncbn', act='gelu')
        self._pixel_space_head_last_convbn = ConvBN(256, 128, kernel_size=1, bias=True, norm='syncbn', act=None)
        trunc_normal_(self._pixel_space_head_last_convbn.conv.weight, std=0.01)

        self._transformer_mask_head = ConvBN(256, 128, kernel_size=1, bias=False, norm='syncbn', act=None, conv_type='1d')
        self._transformer_class_head = ConvBN(256, num_classes, kernel_size=1, norm=None, act=None, conv_type='1d')
        trunc_normal_(self._transformer_class_head.conv.weight, std=0.01)

        self._pixel_space_mask_batch_norm = get_norm('syncbn', channels=1)
        nn.init.constant_(self._pixel_space_mask_batch_norm.weight, 0.1)

    def forward(self, mask_embeddings, class_embeddings, pixel_feature):
        # mask_embeddings/class_embeddings: B x C x N
        # pixel_feature: B x C x H x W
        pixel_space_feature = self._pixel_space_head_conv0bnact(pixel_feature)
        pixel_space_feature = self._pixel_space_head_conv1bnact(pixel_space_feature)
        pixel_space_feature = self._pixel_space_head_last_convbn(pixel_space_feature)
        pixel_space_normalized_feature = F.normalize(pixel_space_feature, p=2, dim=1)

        cluster_class_logits = self._transformer_class_head(class_embeddings).permute(0, 2, 1).contiguous()
        cluster_class_logits = add_bias_towards_void(cluster_class_logits)
        cluster_mask_kernel = self._transformer_mask_head(mask_embeddings)
        mask_logits = torch.einsum('bchw,bcn->bnhw',
                                   pixel_space_normalized_feature, cluster_mask_kernel)

        mask_logits = self._pixel_space_mask_batch_norm(mask_logits.unsqueeze(dim=1)).squeeze(dim=1)

        return {
            'class_logits': cluster_class_logits,
            'mask_logits': mask_logits,
            'pixel_feature': pixel_space_normalized_feature}


# https://github.com/google-research/deeplab2/blob/7a01a7165e97b3325ad7ea9b6bcc02d67fecd07a/model/layers/dual_path_transformer.py#L107
class kMaXTransformerLayer(nn.Module):
    def __init__(
        self,
        num_classes=133,
        in_channel_pixel=2048,
        in_channel_query=256,
        base_filters=128,
        num_heads=8,
        bottleneck_expansion=2,
        key_expansion=1,
        value_expansion=2,
        drop_path_prob=0.0,
    ):
        super().__init__()

        self._num_classes = num_classes
        self._num_heads = num_heads
        self._bottleneck_channels = int(round(base_filters * bottleneck_expansion))
        self._total_key_depth = int(round(base_filters * key_expansion))
        self._total_value_depth = int(round(base_filters * value_expansion))

        # Per the TF2 implementation, the same drop path prob is applied to:
        # 1. k-means update for object query
        # 2. self/cross-attention for object query
        # 3. ffn for object query
        self.drop_path_kmeans = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()
        self.drop_path_attn = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()
        self.drop_path_ffn = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()

        initialization_std = self._bottleneck_channels ** -0.5
        self._query_conv1_bn_act = ConvBN(in_channel_query, self._bottleneck_channels, kernel_size=1, bias=False,
                                          norm='syncbn', act='gelu', conv_type='1d')

        self._pixel_conv1_bn_act = ConvBN(in_channel_pixel, self._bottleneck_channels, kernel_size=1, bias=False,
                                          norm='syncbn', act='gelu')

        self._query_qkv_conv_bn = ConvBN(self._bottleneck_channels, self._total_key_depth * 2 + self._total_value_depth, kernel_size=1, bias=False,
                                         norm='syncbn', act=None, conv_type='1d')
        trunc_normal_(self._query_qkv_conv_bn.conv.weight, std=initialization_std)

        self._pixel_v_conv_bn = ConvBN(self._bottleneck_channels, self._total_value_depth, kernel_size=1, bias=False,
                                       norm='syncbn', act=None)
        trunc_normal_(self._pixel_v_conv_bn.conv.weight, std=initialization_std)

        self._query_self_attention = AttentionOperation(channels_v=self._total_value_depth, num_heads=num_heads)

        self._query_conv3_bn = ConvBN(self._total_value_depth, in_channel_query, kernel_size=1, bias=False,
                                      norm='syncbn', act=None, conv_type='1d', norm_init=0.0)

        self._query_ffn_conv1_bn_act = ConvBN(in_channel_query, 2048, kernel_size=1, bias=False,
                                              norm='syncbn', act='gelu', conv_type='1d')
        self._query_ffn_conv2_bn = ConvBN(2048, in_channel_query, kernel_size=1, bias=False,
                                          norm='syncbn', act=None, conv_type='1d', norm_init=0.0)

        self._predcitor = kMaXPredictor(in_channel_pixel=self._bottleneck_channels,
                                        in_channel_query=self._bottleneck_channels, num_classes=num_classes)
        self._kmeans_query_batch_norm_retrieved_value = get_norm('syncbn', self._total_value_depth)
        self._kmeans_query_conv3_bn = ConvBN(self._total_value_depth, in_channel_query, kernel_size=1, bias=False,
                                             norm='syncbn', act=None, conv_type='1d', norm_init=0.0)

    def forward(self, pixel_feature, query_feature):
        N, C, H, W = pixel_feature.shape
        _, D, L = query_feature.shape
        pixel_space = self._pixel_conv1_bn_act(F.gelu(pixel_feature))  # N C H W
        query_space = self._query_conv1_bn_act(query_feature)  # N x C x L

        # k-means cross-attention.
        pixel_value = self._pixel_v_conv_bn(pixel_space)  # N C H W
        pixel_value = pixel_value.reshape(N, self._total_value_depth, H*W)
        # k-means assignment.
        prediction_result = self._predcitor(
            mask_embeddings=query_space, class_embeddings=query_space, pixel_feature=pixel_space)
        clustering_result = prediction_result['mask_logits'].flatten(2).detach()  # N L HW

        with torch.no_grad():
            clustering_result = prediction_result['mask_logits'].flatten(2).detach()  # N L HW
            index = clustering_result.max(1, keepdim=True)[1]
            clustering_result = torch.zeros_like(clustering_result, memory_format=torch.legacy_contiguous_format).scatter_(1, index, 1.0)

        with autocast(enabled=False):
            # k-means update.
            kmeans_update = torch.einsum('blm,bdm->bdl', clustering_result.float(), pixel_value.float())  # N x C x L

        kmeans_update = self._kmeans_query_batch_norm_retrieved_value(kmeans_update)
        kmeans_update = self._kmeans_query_conv3_bn(kmeans_update)
        query_feature = query_feature + self.drop_path_kmeans(kmeans_update)

        # query self-attention.
        query_qkv = self._query_qkv_conv_bn(query_space)
        query_q, query_k, query_v = torch.split(query_qkv,
            [self._total_key_depth, self._total_key_depth, self._total_value_depth], dim=1)
        query_q = query_q.reshape(N, self._num_heads, self._total_key_depth//self._num_heads, L)
        query_k = query_k.reshape(N, self._num_heads, self._total_key_depth//self._num_heads, L)
        query_v = query_v.reshape(N, self._num_heads, self._total_value_depth//self._num_heads, L)
        self_attn_update = self._query_self_attention(query_q, query_k, query_v)
        self_attn_update = self._query_conv3_bn(self_attn_update)
        query_feature = query_feature + self.drop_path_attn(self_attn_update)
        query_feature = F.gelu(query_feature)

        # FFN.
        ffn_update = self._query_ffn_conv1_bn_act(query_feature)
        ffn_update = self._query_ffn_conv2_bn(ffn_update)
        query_feature = query_feature + self.drop_path_ffn(ffn_update)
        query_feature = F.gelu(query_feature)

        return query_feature, prediction_result


class ASPP(nn.Module):
    def __init__(self, in_channels, output_channels, atrous_rates):
        super().__init__()

        self._aspp_conv0 = ConvBN(in_channels, output_channels, kernel_size=1, bias=False,
                                  norm='syncbn', act='gelu')

        rate1, rate2, rate3 = atrous_rates
        self._aspp_conv1 = ConvBN(in_channels, output_channels, kernel_size=3, dilation=rate1, padding=rate1, bias=False,
                                  norm='syncbn', act='gelu')

        self._aspp_conv2 = ConvBN(in_channels, output_channels, kernel_size=3, dilation=rate2, padding=rate2, bias=False,
                                  norm='syncbn', act='gelu')

        self._aspp_conv3 = ConvBN(in_channels, output_channels, kernel_size=3, dilation=rate3, padding=rate3, bias=False,
                                  norm='syncbn', act='gelu')

        self._avg_pool = nn.AdaptiveAvgPool2d(1)
        self._aspp_pool = ConvBN(in_channels, output_channels, kernel_size=1, bias=False,
                                 norm='syncbn', act='gelu')

        self._proj_conv_bn_act = ConvBN(output_channels * 5, output_channels, kernel_size=1, bias=False,
                                        norm='syncbn', act='gelu')
        # https://github.com/google-research/deeplab2/blob/main/model/decoder/aspp.py#L249
        self._proj_drop = nn.Dropout(p=0.1)

    def forward(self, x):
        results = []
        results.append(self._aspp_conv0(x))
        results.append(self._aspp_conv1(x))
        results.append(self._aspp_conv2(x))
        results.append(self._aspp_conv3(x))
        align_corners = (x.shape[-1] % 2 == 1)
        results.append(F.interpolate(self._aspp_pool(self._avg_pool(x)), size=x.shape[-2:], mode='bilinear', align_corners=align_corners))

        x = torch.cat(results, dim=1)
        x = self._proj_conv_bn_act(x)
        x = self._proj_drop(x)

        return x


class SemanticPredictor(nn.Module):
    def __init__(self, in_channels, os8_channels, os4_channels, num_classes):
        super().__init__()

        # Below is PanopticDeepLabSingleDecoder.
        self._aspp = ASPP(
            in_channels=in_channels,
            # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_r50_os32.textproto#L35
            output_channels=256,
            # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_r50_os32.textproto#L36
            atrous_rates=[6, 12, 18])

        self._low_level_projection_os8 = ConvBN(os8_channels, 64, kernel_size=1, bias=False,
                                                norm='syncbn', act='gelu')

        self._low_level_fusion_os8_conv0_bn_act = ConvBN(256 + 64, 256 + 64, groups=256 + 64, kernel_size=5, padding=2, bias=False,
                                                         norm='syncbn', act='gelu', conv_init='xavier_uniform')
        self._low_level_fusion_os8_conv1_bn_act = ConvBN(256 + 64, 256, kernel_size=1, bias=False,
                                                         norm='syncbn', act='gelu')

        self._low_level_projection_os4 = ConvBN(os4_channels, 32, kernel_size=1, bias=False,
                                                norm='syncbn', act='gelu')

        self._low_level_fusion_os4_conv0_bn_act = ConvBN(256 + 32, 256 + 32, groups=256 + 32, kernel_size=5, padding=2, bias=False,
                                                         norm='syncbn', act='gelu', conv_init='xavier_uniform')
        self._low_level_fusion_os4_conv1_bn_act = ConvBN(256 + 32, 256, kernel_size=1, bias=False,
                                                         norm='syncbn', act='gelu')

        # Below is PanopticDeepLabSingleHead.
        self.conv_block_0 = ConvBN(256, 256, groups=256, kernel_size=5, padding=2, bias=False,
                                   norm='syncbn', act='gelu', conv_init='xavier_uniform')
        self.conv_block_1 = ConvBN(256, 256, kernel_size=1, bias=False,
                                   norm='syncbn', act='gelu')
        self.final_conv = ConvBN(256, num_classes, kernel_size=1, norm=None, act=None)
        trunc_normal_(self.final_conv.conv.weight, std=0.01)

    def forward(self, x, low_features_os8, low_features_os4):
        x = self._aspp(x)
        align_corners = (x.shape[-1] % 2 == 1)
        low_features_os8 = self._low_level_projection_os8(low_features_os8)
        x = F.interpolate(x, size=low_features_os8.shape[-2:], mode='bilinear', align_corners=align_corners)
        x = torch.concat([x, low_features_os8], dim=1)
        x = self._low_level_fusion_os8_conv0_bn_act(x)
        x = self._low_level_fusion_os8_conv1_bn_act(x)

        low_features_os4 = self._low_level_projection_os4(low_features_os4)
        x = F.interpolate(x, size=low_features_os4.shape[-2:], mode='bilinear', align_corners=align_corners)
        x = torch.concat([x, low_features_os4], dim=1)
        x = self._low_level_fusion_os4_conv0_bn_act(x)
        x = self._low_level_fusion_os4_conv1_bn_act(x)

        x = self.conv_block_0(x)
        x = self.conv_block_1(x)
        x = self.final_conv(x)
        return x


@TRANSFORMER_DECODER_REGISTRY.register()
class kMaXTransformerDecoder(nn.Module):

    @configurable
    def __init__(
        self,
        *,
        dec_layers: List[int],
        in_channels: List[int],
        num_classes: int,
        num_queries: int,
        drop_path_prob: float,
        add_aux_semantic_pred: bool,
        input_shape_from_backbone,
    ):
        """
        NOTE: this interface is experimental.
        Args:
        """
        super().__init__()

        # define Transformer decoder here
        self._kmax_transformer_layers = nn.ModuleList()
        self._num_blocks = dec_layers
        os2channels = {32: in_channels[0], 16: in_channels[1], 8: in_channels[2]}

        for index, output_stride in enumerate([32, 16, 8]):
            for _ in range(self._num_blocks[index]):
                self._kmax_transformer_layers.append(
                    kMaXTransformerLayer(num_classes=num_classes+1,
                                         in_channel_pixel=os2channels[output_stride],
                                         in_channel_query=256,
                                         base_filters=128,
                                         num_heads=8,
                                         bottleneck_expansion=2,
                                         key_expansion=1,
                                         value_expansion=2,
                                         drop_path_prob=drop_path_prob)
                )

        self._num_queries = num_queries
        # learnable query features
        self._cluster_centers = nn.Embedding(256, num_queries)
        trunc_normal_(self._cluster_centers.weight, std=1.0)

        self._class_embedding_projection = ConvBN(256, 256, kernel_size=1, bias=False, norm='syncbn', act='gelu',
                                                  conv_type='1d')

        self._mask_embedding_projection = ConvBN(256, 256, kernel_size=1, bias=False, norm='syncbn', act='gelu',
                                                 conv_type='1d')

        self._predcitor = kMaXPredictor(in_channel_pixel=256,
                                        in_channel_query=256, num_classes=num_classes+1)

        self._add_aux_semantic_pred = add_aux_semantic_pred
        if add_aux_semantic_pred:
            self._auxiliary_semantic_predictor = SemanticPredictor(
                in_channels=input_shape_from_backbone['res5'].channels,
                os8_channels=input_shape_from_backbone['res3'].channels,
                os4_channels=input_shape_from_backbone['res2'].channels,
                # +1 for void.
                num_classes=num_classes+1)

    @classmethod
    def from_config(cls, cfg, input_shape_from_backbone):
        ret = {}
        ret["dec_layers"] = cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.DEC_LAYERS
        ret["in_channels"] = cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.IN_CHANNELS
        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        ret["num_queries"] = cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NUM_OBJECT_QUERIES
        ret["drop_path_prob"] = cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.DROP_PATH_PROB
        ret["add_aux_semantic_pred"] = (cfg.MODEL.KMAX_DEEPLAB.AUX_SEMANTIC_WEIGHT > 0)
        ret["input_shape_from_backbone"] = input_shape_from_backbone
        return ret

    def forward(self, x, panoptic_features, semantic_features):
        B = x[0].shape[0]
        cluster_centers = self._cluster_centers.weight.unsqueeze(0).repeat(B, 1, 1)  # B x C x L

        current_transformer_idx = 0

        predictions_class = []
        predictions_mask = []
        predictions_pixel_feature = []

        for i, feat in enumerate(x):
            for _ in range(self._num_blocks[i]):
                cluster_centers, prediction_result = self._kmax_transformer_layers[current_transformer_idx](
                    pixel_feature=feat, query_feature=cluster_centers
                )
                predictions_class.append(prediction_result['class_logits'])
                predictions_mask.append(prediction_result['mask_logits'])
                predictions_pixel_feature.append(prediction_result['pixel_feature'])
                current_transformer_idx += 1

        class_embeddings = self._class_embedding_projection(cluster_centers)
        mask_embeddings = self._mask_embedding_projection(cluster_centers)

        # Final predictions.
        prediction_result = self._predcitor(
            class_embeddings=class_embeddings,
            mask_embeddings=mask_embeddings,
            pixel_feature=panoptic_features,
        )
        predictions_class.append(prediction_result['class_logits'])
        predictions_mask.append(prediction_result['mask_logits'])
        predictions_pixel_feature.append(prediction_result['pixel_feature'])

        out = {
            'pred_logits': predictions_class[-1],
            'pred_masks': predictions_mask[-1],
            'pixel_feature': predictions_pixel_feature[-1],
            'aux_outputs': self._set_aux_loss(
                predictions_class, predictions_mask, predictions_pixel_feature
            ),
        }

        if self._add_aux_semantic_pred and self.training:
            semantic_features, low_features_os8, low_features_os4 = semantic_features
            aux_semantic_prediction = self._auxiliary_semantic_predictor(
                x=semantic_features, low_features_os8=low_features_os8, low_features_os4=low_features_os4)
            out.update({'aux_semantic_pred': aux_semantic_prediction, })
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_pixel_feature):
        target_size = outputs_seg_masks[-1].shape[-2:]
        align_corners = (target_size[0] % 2 == 1)
        return [
            {"pred_logits": a, "pred_masks": F.interpolate(b, size=target_size, mode="bilinear", align_corners=align_corners),
             "pixel_feature": F.interpolate(c, size=target_size, mode="bilinear", align_corners=align_corners), }
            for a, b, c in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_pixel_feature[:-1])
        ]
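Note: the k-means cross-attention in kMaXTransformerLayer.forward replaces softmax attention with a hard cluster assignment: each pixel is assigned, by an argmax over the mask logits, to a single cluster center, and each center is updated by aggregating the values of its assigned pixels. Below is a minimal standalone sketch of just that update with random tensors and made-up sizes (B=2 images, L=128 cluster centers, D=128 value channels, HW=64*64 pixels); it mirrors the einsum/scatter lines above but is not the repository's code.

import torch

B, L, D, HW = 2, 128, 128, 64 * 64           # hypothetical sizes
mask_logits = torch.randn(B, L, HW)           # cluster-wise mask logits, N x L x HW
pixel_value = torch.randn(B, D, HW)           # pixel value features, N x D x HW

with torch.no_grad():
    # Hard assignment: one-hot over the cluster axis for every pixel.
    index = mask_logits.max(1, keepdim=True)[1]
    assignment = torch.zeros_like(mask_logits).scatter_(1, index, 1.0)

# Each cluster center aggregates the values of the pixels assigned to it.
kmeans_update = torch.einsum('blm,bdm->bdl', assignment, pixel_value)  # N x D x L
print(kmeans_update.shape)  # torch.Size([2, 128, 128])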
pakages.txt
ADDED
@@ -0,0 +1,4 @@
libtinfo5
libsm6
libxext6
python3-opencv
requirements.txt
ADDED
@@ -0,0 +1,34 @@
pyyaml==5.1
torch==1.9.0
torchvision==0.10.0

docutils==0.16
# https://github.com/sphinx-doc/sphinx/commit/7acd3ada3f38076af7b2b5c9f3b60bb9c2587a3d
sphinx==3.2.0
recommonmark==0.6.0
sphinx_rtd_theme
# Dependencies here are only those required by import
termcolor
numpy
tqdm
matplotlib
termcolor
yacs
tabulate
cloudpickle
Pillow
future
fvcore
omegaconf>=2.1.0.dev24
hydra-core>=1.1.0.dev5

opencv-python-headless

cython
scipy
shapely
timm
h5py
submitit
scikit-image
train_net.py
ADDED
@@ -0,0 +1,266 @@
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/train_net.py
# Modified by Qihang Yu

try:
    # ignore ShapelyDeprecationWarning from fvcore
    from shapely.errors import ShapelyDeprecationWarning
    import warnings
    warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)
except:
    pass

import copy
import itertools
import os

from typing import Any, Dict, List, Set

import torch

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, build_detection_train_loader, build_detection_test_loader
from detectron2.engine import (
    DefaultTrainer,
    default_argument_parser,
    default_setup,
    launch,
)
from detectron2.evaluation import (
    COCOEvaluator,
    DatasetEvaluators,
    SemSegEvaluator,
    verify_results,
)
from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.utils.logger import setup_logger

# MaskFormer
from kmax_deeplab import (
    COCOPanoptickMaXDeepLabDatasetMapper,
    add_kmax_deeplab_config,
)

from detectron2.data import MetadataCatalog

import train_net_utils


class Trainer(DefaultTrainer):
    """
    Extension of the Trainer class adapted to MaskFormer.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create evaluator(s) for a given dataset.
        This uses the special metadata "evaluator_type" associated with each
        builtin dataset. For your own dataset, you can simply create an
        evaluator manually in your script and do not have to worry about the
        hacky if-else logic here.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        # panoptic segmentation
        if evaluator_type in [
            "coco_panoptic_seg",
        ]:
            if cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON:
                evaluator_list.append(train_net_utils.COCOPanopticEvaluatorwithVis(dataset_name, output_folder, save_vis_num=cfg.MODEL.KMAX_DEEPLAB.SAVE_VIS_NUM))
        # COCO
        if evaluator_type == "coco_panoptic_seg" and cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON:
            evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
        if evaluator_type == "coco_panoptic_seg" and cfg.MODEL.KMAX_DEEPLAB.TEST.SEMANTIC_ON:
            evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
        elif len(evaluator_list) == 1:
            return evaluator_list[0]
        return DatasetEvaluators(evaluator_list)

    @classmethod
    def build_train_loader(cls, cfg):
        # Semantic segmentation dataset mapper
        if cfg.INPUT.DATASET_MAPPER_NAME == "coco_panoptic_lsj":
            mapper = COCOPanoptickMaXDeepLabDatasetMapper(cfg, True)
            return build_detection_train_loader(cfg, mapper=mapper)
        else:
            mapper = None
            return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        name = cfg.SOLVER.LR_SCHEDULER_NAME
        if name == "TF2WarmupPolyLR":
            return train_net_utils.TF2WarmupPolyLR(
                optimizer,
                cfg.SOLVER.MAX_ITER,
                warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                warmup_iters=cfg.SOLVER.WARMUP_ITERS,
                warmup_method=cfg.SOLVER.WARMUP_METHOD,
                power=cfg.SOLVER.POLY_LR_POWER,
                constant_ending=cfg.SOLVER.POLY_LR_CONSTANT_ENDING,
            )
        else:
            return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_optimizer(cls, cfg, model):
        weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
        weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED

        defaults = {}
        defaults["lr"] = cfg.SOLVER.BASE_LR
        defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY

        from kmax_deeplab.modeling.backbone.convnext import LayerNorm

        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            # NaiveSyncBatchNorm inherits from BatchNorm2d
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
            LayerNorm,
        )

        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        for module_name, module in model.named_modules():
            for module_param_name, value in module.named_parameters(recurse=False):
                if not value.requires_grad:
                    continue
                # Avoid duplicating parameters
                if value in memo:
                    continue
                memo.add(value)

                hyperparams = copy.copy(defaults)
                hyperparams["name"] = (module_name, module_param_name)
                if "backbone" in module_name:
                    hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
                if (
                    "relative_position_bias_table" in module_param_name
                    or "absolute_pos_embed" in module_param_name
                ):
                    print(module_param_name)
                    hyperparams["weight_decay"] = 0.0
                if isinstance(module, norm_module_types):
                    hyperparams["weight_decay"] = weight_decay_norm
                if isinstance(module, torch.nn.Embedding):
                    hyperparams["weight_decay"] = weight_decay_embed
                # Rules for kMaX.
                if "_rpe" in module_name:
                    # relative positional embedding in axial attention.
                    hyperparams["weight_decay"] = 0.0
                if "_cluster_centers" in module_name:
                    # cluster center embeddings.
                    hyperparams["weight_decay"] = 0.0
                if "bias" in module_param_name:
                    # any bias terms.
                    hyperparams["weight_decay"] = 0.0
                if "gamma" in module_param_name:
                    # gamma term in convnext
                    hyperparams["weight_decay"] = 0.0

                params.append({"params": [value], **hyperparams})
        for param_ in params:
            print(param_["name"], param_["lr"], param_["weight_decay"])

        def maybe_add_full_model_gradient_clipping(optim):
            # detectron2 doesn't have full model gradient clipping now
            clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
            enable = (
                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
                and clip_norm_val > 0.0
            )

            class FullModelGradientClippingOptimizer(optim):
                def step(self, closure=None):
                    all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                    super().step(closure=closure)

            return FullModelGradientClippingOptimizer if enable else optim

        optimizer_type = cfg.SOLVER.OPTIMIZER
        if optimizer_type == "SGD":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
            )
        elif optimizer_type == "ADAMW":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                params, cfg.SOLVER.BASE_LR
            )
        elif optimizer_type == "ADAM":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.Adam)(
                params, cfg.SOLVER.BASE_LR
            )
        else:
            raise NotImplementedError(f"no optimizer type {optimizer_type}")
        if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
            optimizer = maybe_add_gradient_clipping(cfg, optimizer)
        return optimizer


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    # for poly lr schedule
    add_deeplab_config(cfg)
    add_kmax_deeplab_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="kmax_deeplab")
    return cfg


def main(args):
    cfg = setup(args)

    torch.backends.cudnn.enabled = True
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
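Note: build_optimizer above builds one parameter group per parameter and zeroes weight decay for norm layers, embeddings, biases, the ConvNeXt gamma terms, relative positional embeddings, and the cluster centers. Below is a small self-contained sketch of the same grouping idea applied to a toy model; the toy model, the 1e-4 default decay, and the reduced rule set are made up for illustration and are not the repository's configuration.

import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 8))

params = []
for module_name, module in model.named_modules():
    for param_name, value in module.named_parameters(recurse=False):
        wd = 1e-4  # default weight decay (hypothetical value)
        if isinstance(module, (nn.BatchNorm2d, nn.LayerNorm)) or "bias" in param_name:
            wd = 0.0  # no decay for norm layers and biases, mirroring the rules in build_optimizer
        params.append({"params": [value], "weight_decay": wd})

optimizer = torch.optim.AdamW(params, lr=1e-3)
for group in optimizer.param_groups:
    print(group["weight_decay"])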
train_net_utils.py
ADDED
@@ -0,0 +1,225 @@
import itertools
import os

from typing import List, Optional

import torch
import numpy as np
import tempfile
from collections import OrderedDict
from PIL import Image
from tabulate import tabulate
import json
import contextlib

import detectron2.utils.comm as comm
from detectron2.utils.file_io import PathManager
from detectron2.data import MetadataCatalog
from detectron2.evaluation import COCOPanopticEvaluator

from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2.data import MetadataCatalog
import io
import math
from PIL import Image

from detectron2.solver.lr_scheduler import _get_warmup_factor_at_iter


import logging

logger = logging.getLogger(__name__)


class TF2WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler):
    """
    Poly learning rate schedule used in TF DeepLab2.
    Reference: https://github.com/google-research/deeplab2/blob/main/trainer/trainer_utils.py#L23
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        max_iters: int,
        warmup_factor: float = 0.001,
        warmup_iters: int = 1000,
        warmup_method: str = "linear",
        last_epoch: int = -1,
        power: float = 0.9,
        constant_ending: float = 0.0,
    ):
        self.max_iters = max_iters
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        self.power = power
        self.constant_ending = constant_ending
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        if self.constant_ending > 0 and warmup_factor == 1.0:
            # Constant ending lr.
            if (
                math.pow((1.0 - self.last_epoch / self.max_iters), self.power)
                < self.constant_ending
            ):
                return [base_lr * self.constant_ending for base_lr in self.base_lrs]
        if self.last_epoch < self.warmup_iters:
            return [
                base_lr * warmup_factor
                for base_lr in self.base_lrs
            ]
        else:
            return [
                base_lr * math.pow((1.0 - self.last_epoch / self.max_iters), self.power)
                for base_lr in self.base_lrs
            ]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()


class COCOPanopticEvaluatorwithVis(COCOPanopticEvaluator):
    """
    COCO Panoptic Evaluator that supports saving visualizations.
    TODO(qihangyu): Note that the original implementation also writes all predictions to a tmp folder
    and then runs the official evaluation script; we may also check how to copy from the tmp folder for visualization.
    """

    def __init__(self, dataset_name: str, output_dir: Optional[str] = None, save_vis_num=0):
        super().__init__(dataset_name=dataset_name, output_dir=output_dir)
        self.metadata = MetadataCatalog.get("coco_2017_val_panoptic_with_sem_seg")
        self.output_dir = output_dir
        self.save_vis_num = save_vis_num

    def process(self, inputs, outputs):
        from panopticapi.utils import id2rgb

        cur_save_num = 0
        for input, output in zip(inputs, outputs):
            panoptic_img, segments_info = output["panoptic_seg"]
            panoptic_seg = panoptic_img.cpu()
            panoptic_img = panoptic_seg.numpy()

            file_name = os.path.basename(input["file_name"])
            file_name_png = os.path.splitext(file_name)[0] + ".png"
            if cur_save_num < self.save_vis_num:
                image = output["original_image"]
                image = image.permute(1, 2, 0).cpu().numpy()  # [:, :, ::-1]
                visualizer = Visualizer(image, self.metadata, instance_mode=ColorMode.IMAGE)
                vis_output = visualizer.draw_panoptic_seg_predictions(
                    panoptic_seg, segments_info
                )
                if not os.path.exists(os.path.join(self.output_dir, 'vis')):
                    os.makedirs(os.path.join(self.output_dir, 'vis'))
                out_filename = os.path.join(self.output_dir, 'vis', file_name_png)
                vis_output.save(out_filename)
                cur_save_num += 1

            if segments_info is None:
                # If "segments_info" is None, we assume "panoptic_img" is a
                # H*W int32 image storing the panoptic_id in the format of
                # category_id * label_divisor + instance_id. We reserve -1 for
                # VOID label, and add 1 to panoptic_img since the official
                # evaluation script uses 0 for VOID label.
                label_divisor = self._metadata.label_divisor
                segments_info = []
                for panoptic_label in np.unique(panoptic_img):
                    if panoptic_label == -1:
                        # VOID region.
                        continue
                    pred_class = panoptic_label // label_divisor
                    isthing = (
                        pred_class in self._metadata.thing_dataset_id_to_contiguous_id.values()
                    )
                    segments_info.append(
                        {
                            "id": int(panoptic_label) + 1,
                            "category_id": int(pred_class),
                            "isthing": bool(isthing),
                        }
                    )
                # Official evaluation script uses 0 for VOID label.
                panoptic_img += 1

            with io.BytesIO() as out:
                Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
                segments_info = [self._convert_category_id(x) for x in segments_info]
                self._predictions.append(
                    {
                        "image_id": input["image_id"],
                        "file_name": file_name_png,
                        "png_string": out.getvalue(),
                        "segments_info": segments_info,
                    }
                )

    def evaluate(self):
        comm.synchronize()

        self._predictions = comm.gather(self._predictions)
        self._predictions = list(itertools.chain(*self._predictions))
        if not comm.is_main_process():
            return

        # PanopticApi requires local files
        gt_json = PathManager.get_local_path(self._metadata.panoptic_json)
        gt_folder = PathManager.get_local_path(self._metadata.panoptic_root)

        with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
            logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
            for p in self._predictions:
                with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
                    f.write(p.pop("png_string"))

            with open(gt_json, "r") as f:
                json_data = json.load(f)
            json_data["annotations"] = self._predictions

            output_dir = self._output_dir or pred_dir
            predictions_json = os.path.join(output_dir, "predictions.json")
            with PathManager.open(predictions_json, "w") as f:
                f.write(json.dumps(json_data))

            from kmax_deeplab.evaluation.panoptic_evaluation import pq_compute

            with contextlib.redirect_stdout(io.StringIO()):
                pq_res = pq_compute(
                    gt_json,
                    PathManager.get_local_path(predictions_json),
                    gt_folder=gt_folder,
                    pred_folder=pred_dir,
                )

        res = {}
        res["PQ"] = 100 * pq_res["All"]["pq"]
        res["SQ"] = 100 * pq_res["All"]["sq"]
        res["RQ"] = 100 * pq_res["All"]["rq"]
        res["PQ_th"] = 100 * pq_res["Things"]["pq"]
        res["SQ_th"] = 100 * pq_res["Things"]["sq"]
        res["RQ_th"] = 100 * pq_res["Things"]["rq"]
        res["PQ_st"] = 100 * pq_res["Stuff"]["pq"]
        res["SQ_st"] = 100 * pq_res["Stuff"]["sq"]
        res["RQ_st"] = 100 * pq_res["Stuff"]["rq"]

        results = OrderedDict({"panoptic_seg": res})
        _print_panoptic_results(pq_res)

        return results


def _print_panoptic_results(pq_res):
    headers = ["", "PQ", "SQ", "RQ", "#categories"]
    data = []
    for name in ["All", "Things", "Stuff"]:
        row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]]
        data.append(row)
    table = tabulate(
        data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
    )
    logger.info("Panoptic Evaluation Results:\n" + table)
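Note: TF2WarmupPolyLR combines a linear warmup with a polynomial decay of the learning rate to zero over max_iters. Below is a standalone sketch that reimplements just the multiplier it applies to the base learning rate, with made-up hyperparameters; it is for illustration only and is not the class defined above.

import math

def poly_lr_multiplier(step, max_iters=10000, warmup_iters=1000, warmup_factor=0.001, power=0.9):
    # Linear warmup from warmup_factor to 1, then (1 - step / max_iters) ** power decay.
    if step < warmup_iters:
        alpha = step / warmup_iters
        return warmup_factor * (1 - alpha) + alpha
    return math.pow(1.0 - step / max_iters, power)

for step in [0, 500, 1000, 5000, 9999]:
    print(step, round(poly_lr_multiplier(step), 4))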