atwang committed on
Commit
01664b3
1 Parent(s): 20c01c5

[NOT TESTED] initial implementation of app

Files changed (50)
  1. .gitignore +3 -0
  2. app.py +131 -4
  3. configs/coco/instance-segmentation/Base-COCO-InstanceSegmentation.yaml +47 -0
  4. configs/coco/instance-segmentation/maskformer2_R50_bs16_50ep.yaml +44 -0
  5. configs/coco/instance-segmentation/swin/opd_base.yaml +50 -0
  6. configs/coco/instance-segmentation/swin/opd_v1_real.yaml +7 -0
  7. dev-requirements.txt +3 -0
  8. examples/59-4860.png +0 -0
  9. examples/59-4860_d.png +0 -0
  10. inference.py +836 -0
  11. mask2former/__init__.py +11 -0
  12. mask2former/config.py +125 -0
  13. mask2former/maskformer_model.py +820 -0
  14. mask2former/modeling/__init__.py +6 -0
  15. mask2former/modeling/backbone/__init__.py +1 -0
  16. mask2former/modeling/backbone/swin.py +770 -0
  17. mask2former/modeling/criterion.py +547 -0
  18. mask2former/modeling/matcher.py +192 -0
  19. mask2former/modeling/meta_arch/__init__.py +1 -0
  20. mask2former/modeling/meta_arch/mask_former_head.py +133 -0
  21. mask2former/modeling/meta_arch/per_pixel_baseline.py +243 -0
  22. mask2former/modeling/pixel_decoder/__init__.py +1 -0
  23. mask2former/modeling/pixel_decoder/fpn.py +312 -0
  24. mask2former/modeling/pixel_decoder/msdeformattn.py +358 -0
  25. mask2former/modeling/pixel_decoder/ops/functions/__init__.py +13 -0
  26. mask2former/modeling/pixel_decoder/ops/functions/ms_deform_attn_func.py +72 -0
  27. mask2former/modeling/pixel_decoder/ops/make.sh +13 -0
  28. mask2former/modeling/pixel_decoder/ops/modules/__init__.py +12 -0
  29. mask2former/modeling/pixel_decoder/ops/modules/ms_deform_attn.py +125 -0
  30. mask2former/modeling/pixel_decoder/ops/setup.py +78 -0
  31. mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.cpp +46 -0
  32. mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.h +38 -0
  33. mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.cu +158 -0
  34. mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.h +35 -0
  35. mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_im2col_cuda.cuh +1332 -0
  36. mask2former/modeling/pixel_decoder/ops/src/ms_deform_attn.h +67 -0
  37. mask2former/modeling/pixel_decoder/ops/src/vision.cpp +21 -0
  38. mask2former/modeling/pixel_decoder/ops/test.py +92 -0
  39. mask2former/modeling/transformer_decoder/__init__.py +4 -0
  40. mask2former/modeling/transformer_decoder/mask2former_transformer_decoder.py +461 -0
  41. mask2former/modeling/transformer_decoder/maskformer_transformer_decoder.py +188 -0
  42. mask2former/modeling/transformer_decoder/opd_transformer_decoder.py +520 -0
  43. mask2former/modeling/transformer_decoder/position_encoding.py +64 -0
  44. mask2former/modeling/transformer_decoder/transformer.py +369 -0
  45. mask2former/utils/__init__.py +2 -0
  46. mask2former/utils/misc.py +111 -0
  47. mask2former/utils/motion_visualizer.py +676 -0
  48. mask2former/utils/tranform.py +169 -0
  49. pre-requirements.txt +6 -0
  50. requirements.txt +11 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ build/
+ venv/
+ __pycache__/
app.py CHANGED
@@ -1,7 +1,134 @@
+ import os
+ import re
+ from types import SimpleNamespace
+ from typing import Any
+
  import gradio as gr
- def greet(name):
-     return "Hello " + name + "!!"
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import numpy as np
+ from detectron2 import engine
+
+ from inference import main, setup_cfg
+
+ # internal settings
+ NUM_PROCESSES = 1
+ CROP = False
+ SCORE_THRESHOLD = 0.8
+ MAX_PARTS = 5
+ ARGS = SimpleNamespace(
+     config_file="configs/coco/instance-segmentation/swin/opd_v1_real.yaml",
+     model="...",
+     input_format="RGB",
+     output=".output",
+     cpu=True,
+ )
+
+
+ def predict(rgb_image: str, depth_image: str, intrinsics: np.ndarray, num_samples: int) -> list[Any]:
+     def find_gifs(path: str) -> list[str]:
+         """Scrape folders for all generated gif files."""
+         for file in os.listdir(path):
+             sub_path = os.path.join(path, file)
+             if os.path.isdir(sub_path):
+                 for image_file in os.listdir(sub_path):
+                     if re.match(r".*\.gif$", image_file):
+                         yield os.path.join(sub_path, image_file)
+
+     cfg = setup_cfg(ARGS)
+
+     engine.launch(
+         main,
+         NUM_PROCESSES,
+         args=(
+             cfg,
+             rgb_image,
+             depth_image,
+             intrinsics,
+             num_samples,
+             CROP,
+             SCORE_THRESHOLD,
+         ),
+     )
+
+     # process output
+     # TODO: may want to select these in decreasing order of score
+     pre_outputs = list(find_gifs(ARGS.output))
+
+     outputs = []
+     for idx in range(MAX_PARTS):  # hide unused components
+         if idx < len(pre_outputs):
+             outputs.append(gr.update(value=pre_outputs[idx], visible=True))
+         else:
+             outputs.append(gr.update(visible=False))
+     return outputs
+
+
+ def variable_outputs(idx):
+     idx = int(idx)
+
+
+ with gr.Blocks() as app:
+     gr.Markdown(
+         """
+         # OPDMulti Demo
+         Upload an image to see its range of motion.
+         """
+     )
+
+     # TODO: add gr.Examples
+
+     with gr.Row():
+         rgb_image = gr.Image(
+             image_mode="RGB", source="upload", type="filepath", label="RGB Image", show_label=True, interactive=True
+         )
+         depth_image = gr.Image(
+             image_mode="L", source="upload", type="filepath", label="Depth Image", show_label=True, interactive=True
+         )
+
+     intrinsics = gr.Dataframe(
+         value=[
+             [
+                 214.85935872395834,
+                 0.0,
+                 0.0,
+             ],
+             [
+                 0.0,
+                 214.85935872395834,
+                 0.0,
+             ],
+             [
+                 125.90160319010417,
+                 95.13726399739583,
+                 1.0,
+             ],
+         ],
+         row_count=(3, "fixed"),
+         col_count=(3, "fixed"),
+         datatype="number",
+         type="numpy",
+         label="Intrinsics matrix",
+         show_label=True,
+         interactive=True,
+     )
+     num_samples = gr.Number(
+         value=10,
+         label="Number of samples",
+         show_label=True,
+         interactive=True,
+         precision=0,
+         minimum=3,
+         maximum=20,
+     )
+
+     submit_btn = gr.Button("Run model")
+
+     # TODO: do we want to set a maximum limit on how many parts we render? We could also show the number of components
+     # identified.
+     outputs = [gr.Image(type="filepath", label=f"Part {idx + 1}", visible=False) for idx in range(MAX_PARTS)]
+
+     # TODO: maybe need to use a queue here so we don't overload the instance
+     submit_btn.click(
+         fn=predict, inputs=[rgb_image, depth_image, intrinsics, num_samples], outputs=outputs, api_name="run_model"
+     )
+
+ app.launch()
configs/coco/instance-segmentation/Base-COCO-InstanceSegmentation.yaml ADDED
@@ -0,0 +1,47 @@
+ MODEL:
+   BACKBONE:
+     FREEZE_AT: 0
+     NAME: "build_resnet_backbone"
+   WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
+   PIXEL_MEAN: [123.675, 116.280, 103.530]
+   PIXEL_STD: [58.395, 57.120, 57.375]
+   RESNETS:
+     DEPTH: 50
+     STEM_TYPE: "basic"  # not used
+     STEM_OUT_CHANNELS: 64
+     STRIDE_IN_1X1: False
+     OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+     # NORM: "SyncBN"
+     RES5_MULTI_GRID: [1, 1, 1]  # not used
+ DATASETS:
+   TRAIN: ("coco_2017_train",)
+   TEST: ("coco_2017_val",)
+ SOLVER:
+   IMS_PER_BATCH: 16
+   BASE_LR: 0.0001
+   STEPS: (327778, 355092)
+   MAX_ITER: 368750
+   WARMUP_FACTOR: 1.0
+   WARMUP_ITERS: 10
+   WEIGHT_DECAY: 0.05
+   OPTIMIZER: "ADAMW"
+   BACKBONE_MULTIPLIER: 0.1
+   CLIP_GRADIENTS:
+     ENABLED: True
+     CLIP_TYPE: "full_model"
+     CLIP_VALUE: 0.01
+     NORM_TYPE: 2.0
+   AMP:
+     ENABLED: True
+ INPUT:
+   IMAGE_SIZE: 1024
+   MIN_SCALE: 0.1
+   MAX_SCALE: 2.0
+   FORMAT: "RGB"
+   DATASET_MAPPER_NAME: "coco_instance_lsj"
+ TEST:
+   EVAL_PERIOD: 5000
+ DATALOADER:
+   FILTER_EMPTY_ANNOTATIONS: True
+   NUM_WORKERS: 4
+ VERSION: 2
configs/coco/instance-segmentation/maskformer2_R50_bs16_50ep.yaml ADDED
@@ -0,0 +1,44 @@
+ _BASE_: Base-COCO-InstanceSegmentation.yaml
+ MODEL:
+   META_ARCHITECTURE: "MaskFormer"
+   SEM_SEG_HEAD:
+     NAME: "MaskFormerHead"
+     IGNORE_VALUE: 255
+     NUM_CLASSES: 80
+     LOSS_WEIGHT: 1.0
+     CONVS_DIM: 256
+     MASK_DIM: 256
+     NORM: "GN"
+     # pixel decoder
+     PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
+     IN_FEATURES: ["res2", "res3", "res4", "res5"]
+     DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+     COMMON_STRIDE: 4
+     TRANSFORMER_ENC_LAYERS: 6
+   MASK_FORMER:
+     TRANSFORMER_DECODER_NAME: "MultiScaleMaskedTransformerDecoder"
+     TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
+     DEEP_SUPERVISION: True
+     NO_OBJECT_WEIGHT: 0.1
+     CLASS_WEIGHT: 2.0
+     MASK_WEIGHT: 5.0
+     DICE_WEIGHT: 5.0
+     HIDDEN_DIM: 256
+     NUM_OBJECT_QUERIES: 100
+     NHEADS: 8
+     DROPOUT: 0.0
+     DIM_FEEDFORWARD: 2048
+     ENC_LAYERS: 0
+     PRE_NORM: False
+     ENFORCE_INPUT_PROJ: False
+     SIZE_DIVISIBILITY: 32
+     DEC_LAYERS: 10  # 9 decoder layers, add one for the loss on learnable query
+     TRAIN_NUM_POINTS: 12544
+     OVERSAMPLE_RATIO: 3.0
+     IMPORTANCE_SAMPLE_RATIO: 0.75
+     TEST:
+       SEMANTIC_ON: False
+       INSTANCE_ON: True
+       PANOPTIC_ON: False
+       OVERLAP_THRESHOLD: 0.8
+       OBJECT_MASK_THRESHOLD: 0.8
configs/coco/instance-segmentation/swin/opd_base.yaml ADDED
@@ -0,0 +1,50 @@
+ _BASE_: ../maskformer2_R50_bs16_50ep.yaml
+
+ INPUT:
+   FORMAT: "RGB"
+   IMAGE_SIZE: 256
+   MAX_SIZE_TEST: 256
+   MAX_SIZE_TRAIN: 256
+   MIN_SIZE_TEST: 256
+   MIN_SIZE_TRAIN:
+     - 256
+   # DATASET_MAPPER_NAME: "motion_instance"
+
+ DATALOADER:
+   NUM_WORKERS: 4
+
+ DATASETS:
+   TRAIN: ("MotionNet_train",)
+   TEST: ("MotionNet_valid",)
+
+ MODEL:
+   MOTIONNET:
+     TYPE: BMOC_V0
+   SEM_SEG_HEAD:
+     NUM_CLASSES: 3
+   MASK_ON: True  # Useful for our MotionEvaluator, because it's from an older version of detectron2
+   MASK_FORMER:
+     TRANSFORMER_DECODER_NAME: OPDMultiScaleMaskedTransformerDecoder
+     CLASS_WEIGHT: 2.0
+     MASK_WEIGHT: 5.0
+     DICE_WEIGHT: 5.0
+     MTYPE_WEIGHT: 2.0
+     MORIGIN_WEIGHT: 16.0
+     MAXIS_WEIGHT: 16.0
+     MSTATE_WEIGHT: 16.0
+     MSTATEMAX_WEIGHT: 16.0
+     EXTRINSIC_WEIGHT: 30.0
+
+ SOLVER:
+   IMS_PER_BATCH: 16
+   BASE_LR: 0.0001
+   STEPS: (36000, 48000)
+   MAX_ITER: 60000
+   CHECKPOINT_PERIOD: 10000
+
+ TEST:
+   AUG:
+     ENABLED: false
+     FLIP: false
+   EVAL_PERIOD: 10000
+ SEED: 42
configs/coco/instance-segmentation/swin/opd_v1_real.yaml ADDED
@@ -0,0 +1,7 @@
+ _BASE_: ./opd_base.yaml
+
+ MODEL:
+   MOTIONNET:
+     TYPE: BMOC_V1
+   PIXEL_MEAN: [142.60756197911175, 128.59507321750323, 110.82755928042158, 1267.231689453125]  # RGBD mean from MotionDataset_real train
+   PIXEL_STD: [24.008765143841437, 24.132018526763215, 27.228518892160068, 599.8106079101562]  # RGBD stddev from MotionDataset_real train
dev-requirements.txt ADDED
@@ -0,0 +1,3 @@
+ black==23.9.1
+ gradio==3.44.3
+ huggingface-hub==0.17.2
examples/59-4860.png ADDED
examples/59-4860_d.png ADDED
inference.py ADDED
@@ -0,0 +1,836 @@
1
+ """
2
+ inference.py
3
+ ------------
4
+ Provides functionality to run the OPDMulti model on an input image, independent of dataset and ground truth, and
5
+ visualize the output. Large portions of the code originate from get_prediction.py, rgbd_to_pcd_vis.py,
6
+ evaluate_on_log.py, and other related files. The primary goal was to create a more standalone script which could be
7
+ converted more easily into a public demo, thus the goal was to sever most dependencies on existing ground truth or
8
+ datasets.
9
+
10
+ Example usage:
11
+ python inference.py \
12
+ --rgb path/to/59-4860.png \
13
+ --depth path/to/59-4860_d.png \
14
+ --model path/to/model.pth \
15
+ --output path/to/output_dir
16
+ """
17
+
18
+ import argparse
19
+ import logging
20
+ import os
21
+ import time
22
+ from copy import deepcopy
23
+ from typing import Any
24
+
25
+ import imageio
26
+ import open3d as o3d
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+ from detectron2 import engine, evaluation
31
+ from detectron2.modeling import build_model
32
+ from detectron2.config import get_cfg, CfgNode
33
+ from detectron2.projects.deeplab import add_deeplab_config
34
+ from detectron2.structures import instances
35
+ from detectron2.utils import comm
36
+ from detectron2.utils.logger import setup_logger
37
+ from PIL import Image, ImageChops
38
+
39
+ from mask2former import (
40
+ add_maskformer2_config,
41
+ add_motionnet_config,
42
+ )
43
+
44
+ # import based on torch version. Required for model loading. Code is taken from fvcore.common.checkpoint.py, in order to
45
+ # replicate model loading without the overhead of setting up an OPDTrainer
46
+
47
+ TORCH_VERSION: tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
48
+ if TORCH_VERSION >= (1, 11):
49
+ from torch.ao import quantization
50
+ from torch.ao.quantization import FakeQuantizeBase, ObserverBase
51
+ elif (
52
+ TORCH_VERSION >= (1, 8)
53
+ and hasattr(torch.quantization, "FakeQuantizeBase")
54
+ and hasattr(torch.quantization, "ObserverBase")
55
+ ):
56
+ from torch import quantization
57
+ from torch.quantization import FakeQuantizeBase, ObserverBase
58
+
59
+ # TODO: find a global place for this instead of in many places in code
60
+ TYPE_CLASSIFICATION = {
61
+ 0: "rotation",
62
+ 1: "translation",
63
+ }
64
+
65
+ POINT_COLOR = [1, 0, 0] # red for demonstration
66
+ IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
67
+
68
+
69
+ def get_parser() -> argparse.ArgumentParser:
70
+ """
71
+ Specify command-line arguments.
72
+
73
+ The primary inputs to the script should be the image paths (RGBD) and camera intrinsics. Other arguments are
74
+ provided to facilitate script testing and model changes. Run file with -h/--help to see all arguments.
75
+
76
+ :return: parser for extracting command-line arguments
77
+ """
78
+ parser = argparse.ArgumentParser(description="Inference for OPDMulti")
79
+ # The main arguments which should be specified by the user
80
+ parser.add_argument(
81
+ "--rgb",
82
+ dest="rgb_image",
83
+ metavar="FILE",
84
+ help="path to RGB image file on which to run model",
85
+ )
86
+ parser.add_argument(
87
+ "--depth",
88
+ dest="depth_image",
89
+ metavar="FILE",
90
+ help="path to depth image file on which to run model",
91
+ )
92
+ parser.add_argument( # FIXME: might make more sense to make this a path
93
+ "-i",
94
+ "--intrinsics",
95
+ nargs=9,
96
+ default=[
97
+ 214.85935872395834,
98
+ 0.0,
99
+ 0.0,
100
+ 0.0,
101
+ 214.85935872395834,
102
+ 0.0,
103
+ 125.90160319010417,
104
+ 95.13726399739583,
105
+ 1.0,
106
+ ],
107
+ dest="intrinsics",
108
+ help="camera intrinsics matrix, as a list of values",
109
+ )
110
+
111
+ # optional parameters for user to specify
112
+ parser.add_argument(
113
+ "-n",
114
+ "--num-samples",
115
+ default=10,
116
+ dest="num_samples",
117
+ metavar="NUM",
118
+ help="number of sample states to generate in visualization",
119
+ )
120
+ parser.add_argument(
121
+ "--crop",
122
+ action="store_true",
123
+ dest="crop",
124
+ help="crop whitespace out of images for visualization",
125
+ )
126
+
127
+ # local script development arguments
128
+ parser.add_argument(
129
+ "-m",
130
+ "--model",
131
+ default="path/to/model/file", # FIXME: set a good default path
132
+ dest="model",
133
+ metavar="FILE",
134
+ help="path to model file to run",
135
+ )
136
+ parser.add_argument(
137
+ "-c",
138
+ "--config",
139
+ default="configs/coco/instance-segmentation/swin/opd_v1_real.yaml",
140
+ metavar="FILE",
141
+ dest="config_file",
142
+ help="path to config file",
143
+ )
144
+ parser.add_argument(
145
+ "-o",
146
+ "--output",
147
+ default="output", # FIXME: set a good default path
148
+ dest="output",
149
+ help="path to output directory in which to save results",
150
+ )
151
+ parser.add_argument(
152
+ "--num-processes",
153
+ default=1,
154
+ dest="num_processes",
155
+ help="number of processes per machine. When using GPUs, this should be the number of GPUs.",
156
+ )
157
+ parser.add_argument(
158
+ "-s",
159
+ "--score-threshold",
160
+ default=0.8,
161
+ type=float,
162
+ dest="score_threshold",
163
+ help="threshold between 0.0 and 1.0 by which to filter out bad predictions",
164
+ )
165
+ parser.add_argument(
166
+ "--input-format",
167
+ default="RGB",
168
+ dest="input_format",
169
+ help="input format of image. Must be one of RGB, RGBD, or depth",
170
+ )
171
+ parser.add_argument(
172
+ "--cpu",
173
+ action="store_true",
174
+ help="flag to require code to use CPU only",
175
+ )
176
+
177
+ return parser
178
+
179
+
180
+ def setup_cfg(args: argparse.Namespace) -> CfgNode:
181
+ """
182
+ Create configs and perform basic setups.
183
+ """
184
+ cfg = get_cfg()
185
+ # add model configurations
186
+ add_deeplab_config(cfg)
187
+ add_maskformer2_config(cfg)
188
+ add_motionnet_config(cfg)
189
+ cfg.merge_from_file(args.config_file)
190
+
191
+ # set additional config parameters
192
+ cfg.MODEL.WEIGHTS = args.model
193
+ cfg.OBJ_DETECT = False # TODO: figure out if this is needed, and parameterize it
194
+ cfg.MODEL.MOTIONNET.VOTING = "none"
195
+ # Output directory
196
+ cfg.OUTPUT_DIR = args.output
197
+ cfg.MODEL.DEVICE = "cpu" if args.cpu else "cuda"
198
+
199
+ cfg.MODEL.MODELATTRPATH = None
200
+
201
+ # Input format
202
+ cfg.INPUT.FORMAT = args.input_format
203
+ if args.input_format == "RGB":
204
+ cfg.MODEL.PIXEL_MEAN = cfg.MODEL.PIXEL_MEAN[0:3]
205
+ cfg.MODEL.PIXEL_STD = cfg.MODEL.PIXEL_STD[0:3]
206
+ elif args.input_format == "depth":
207
+ cfg.MODEL.PIXEL_MEAN = cfg.MODEL.PIXEL_MEAN[3:4]
208
+ cfg.MODEL.PIXEL_STD = cfg.MODEL.PIXEL_STD[3:4]
209
+ elif args.input_format == "RGBD":
210
+ pass
211
+ else:
212
+ raise ValueError("Invalid input format")
213
+
214
+ cfg.freeze()
215
+ engine.default_setup(cfg, args)
216
+
217
+ # Setup logger for "mask_former" module
218
+ setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="opdformer")
219
+ return cfg
220
+
221
+
222
+ def format_input(rgb_path: str) -> list[dict[str, Any]]:
223
+ """
224
+ Read and format input image into detectron2 form so that it can be passed to the model.
225
+
226
+ :param rgb_path: path to RGB image file
227
+ :return: list of dictionaries per image, where each dictionary is of the form
228
+ {
229
+ "file_name": path to RGB image,
230
+ "image": torch.Tensor of dimensions [channel, height, width] representing the image
231
+ }
232
+ """
233
+ image = imageio.imread(rgb_path).astype(np.float32)
234
+ image_tensor = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) # dim: [channel, height, width]
235
+ return [{"file_name": rgb_path, "image": image_tensor}]
236
+
237
+
238
+ def load_model(model: nn.Module, checkpoint: Any) -> None:
239
+ """
240
+ Load weights from a checkpoint.
241
+
242
+ The majority of the function definition is taken from the DetectionCheckpointer implementation provided in
243
+ detectron2. While not all of this code is necessarily needed for model loading, it was ported with the intention
244
+ of keeping the implementation and output as close to the original as possible, and reusing the checkpoint class here
245
+ in isolation was determined to be infeasible.
246
+
247
+ :param model: model for which to load weights
248
+ :param checkpoint: checkpoint containing the weights.
249
+ """
250
+
251
+ def _strip_prefix_if_present(state_dict: dict[str, Any], prefix: str) -> None:
252
+ """If prefix is found on all keys in state dict, remove prefix."""
253
+ keys = sorted(state_dict.keys())
254
+ if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
255
+ return
256
+
257
+ for key in keys:
258
+ newkey = key[len(prefix) :]
259
+ state_dict[newkey] = state_dict.pop(key)
260
+
261
+ checkpoint_state_dict = checkpoint.pop("model")
262
+
263
+ # convert from numpy to tensor
264
+ for k, v in checkpoint_state_dict.items():
265
+ if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
266
+ raise ValueError("Unsupported type found in checkpoint! {}: {}".format(k, type(v)))
267
+ if not isinstance(v, torch.Tensor):
268
+ checkpoint_state_dict[k] = torch.from_numpy(v)
269
+
270
+ # if the state_dict comes from a model that was wrapped in a
271
+ # DataParallel or DistributedDataParallel during serialization,
272
+ # remove the "module" prefix before performing the matching.
273
+ _strip_prefix_if_present(checkpoint_state_dict, "module.")
274
+
275
+ # workaround https://github.com/pytorch/pytorch/issues/24139
276
+ model_state_dict = model.state_dict()
277
+ incorrect_shapes = []
278
+ for k in list(checkpoint_state_dict.keys()): # state dict is modified in loop, so list op is necessary
279
+ if k in model_state_dict:
280
+ model_param = model_state_dict[k]
281
+ # Allow mismatch for uninitialized parameters
282
+ if TORCH_VERSION >= (1, 8) and isinstance(model_param, nn.parameter.UninitializedParameter):
283
+ continue
284
+ shape_model = tuple(model_param.shape)
285
+ shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
286
+ if shape_model != shape_checkpoint:
287
+ has_observer_base_classes = (
288
+ TORCH_VERSION >= (1, 8)
289
+ and hasattr(quantization, "ObserverBase")
290
+ and hasattr(quantization, "FakeQuantizeBase")
291
+ )
292
+ if has_observer_base_classes:
293
+ # Handle the special case of quantization per channel observers,
294
+ # where buffer shape mismatches are expected.
295
+ def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
296
+ # foo.bar.param_or_buffer_name -> [foo, bar]
297
+ key_parts = key.split(".")[:-1]
298
+ cur_module = model
299
+ for key_part in key_parts:
300
+ cur_module = getattr(cur_module, key_part)
301
+ return cur_module
302
+
303
+ cls_to_skip = (
304
+ ObserverBase,
305
+ FakeQuantizeBase,
306
+ )
307
+ target_module = _get_module_for_key(model, k)
308
+ if isinstance(target_module, cls_to_skip):
309
+ # Do not remove modules with expected shape mismatches
310
+ # them from the state_dict loading. They have special logic
311
+ # in _load_from_state_dict to handle the mismatches.
312
+ continue
313
+
314
+ incorrect_shapes.append((k, shape_checkpoint, shape_model))
315
+ checkpoint_state_dict.pop(k)
316
+
317
+ model.load_state_dict(checkpoint_state_dict, strict=False)
318
+
319
+
320
+ def predict(model: nn.Module, inp: list[dict[str, Any]]) -> list[dict[str, instances.Instances]]:
321
+ """
322
+ Compute model predictions.
323
+
324
+ :param model: model to run on input
325
+ :param inp: input, in the form
326
+ {
327
+ "image_file": path to image,
328
+ "image": float32 torch.tensor of dimensions [channel, height, width] as RGB/RGBD/depth image
329
+ }
330
+ :return: list of detected instances and predicted openable parameters
331
+ """
332
+ with torch.no_grad(), evaluation.inference_context(model):
333
+ out = model(inp)
334
+ return out
335
+
336
+
337
+ def generate_rotation_visualization(
338
+ pcd: o3d.geometry.PointCloud,
339
+ axis_arrow: o3d.geometry.TriangleMesh,
340
+ mask: np.ndarray,
341
+ axis_vector: np.ndarray,
342
+ origin: np.ndarray,
343
+ range_min: float,
344
+ range_max: float,
345
+ num_samples: int,
346
+ output_dir: str,
347
+ ) -> None:
348
+ """
349
+ Generate visualization files for a rotation motion of a part.
350
+
351
+ :param pcd: point cloud object representing 2D image input (RGBD) as a point cloud
352
+ :param axis_arrow: mesh object representing axis arrow of rotation to be rendered in visualization
353
+ :param mask: mask np.array of dimensions (height, width) representing the part to be rotated in the image
354
+ :param axis_vector: np.array of dimensions (3, ) representing the vector of the axis of rotation
355
+ :param origin: np.array of dimensions (3, ) representing the origin point of the axis of rotation
356
+ :param range_min: float representing the minimum range of motion in radians
357
+ :param range_max: float representing the maximum range of motion in radians
358
+ :param num_samples: number of sample states to visualize in between range_min and range_max of motion
359
+ :param output_dir: string path to directory in which to save visualization output
360
+ """
361
+ angle_in_radians = np.linspace(range_min, range_max, num_samples)
362
+ angles_in_degrees = angle_in_radians * 180 / np.pi
363
+
364
+ for idx, angle_in_degrees in enumerate(angles_in_degrees):
365
+ # Make a copy of your original point cloud and arrow for each rotation
366
+ rotated_pcd = deepcopy(pcd)
367
+ rotated_arrow = deepcopy(axis_arrow)
368
+
369
+ angle_rad = np.radians(angle_in_degrees)
370
+ rotated_pcd = rotate_part(rotated_pcd, mask, axis_vector, origin, angle_rad)
371
+
372
+ # Create a Visualizer object for each rotation
373
+ vis = o3d.visualization.Visualizer()
374
+ vis.create_window()
375
+
376
+ # Add the rotated geometries
377
+ vis.add_geometry(rotated_pcd)
378
+ vis.add_geometry(rotated_arrow)
379
+
380
+ # Apply the additional rotation around x-axis if desired
381
+ angle_x = np.pi * 5.5 / 5 # 198 degrees
382
+ rotation_matrix = o3d.geometry.get_rotation_matrix_from_axis_angle(np.asarray([1, 0, 0]) * angle_x)
383
+ rotated_pcd.rotate(rotation_matrix, center=rotated_pcd.get_center())
384
+ rotated_arrow.rotate(rotation_matrix, center=rotated_pcd.get_center())
385
+
386
+ # Capture and save the image
387
+ output_filename = f"{output_dir}/{idx}.png"
388
+ vis.capture_screen_image(output_filename, do_render=True)
389
+ vis.destroy_window()
390
+
391
+
392
+ def generate_translation_visualization(
393
+ pcd: o3d.geometry.PointCloud,
394
+ axis_arrow: o3d.geometry.TriangleMesh,
395
+ mask: np.ndarray,
396
+ end: np.ndarray,
397
+ range_min: float,
398
+ range_max: float,
399
+ num_samples: int,
400
+ output_dir: str,
401
+ ) -> None:
402
+ """
403
+ Generate visualization files for a translation motion of a part.
404
+
405
+ :param pcd: point cloud object representing 2D image input (RGBD) as a point cloud
406
+ :param axis_arrow: mesh object representing axis arrow of translation to be rendered in visualization
407
+ :param mask: mask np.array of dimensions (height, width) representing the part to be translated in the image
408
+ :param end: np.array of dimensions (3, ) representing the direction vector of the translation axis
410
+ :param range_min: float representing the minimum range of motion
411
+ :param range_max: float representing the maximum range of motion
412
+ :param num_samples: number of sample states to visualize in between range_min and range_max of motion
413
+ :param output_dir: string path to directory in which to save visualization output
414
+ """
415
+ translate_distances = np.linspace(range_min, range_max, num_samples)
416
+ for idx, translate_distance in enumerate(translate_distances):
417
+ translated_pcd = deepcopy(pcd)
418
+ translated_arrow = deepcopy(axis_arrow)
419
+
420
+ translated_pcd = translate_part(translated_pcd, mask, end, translate_distance.item())
421
+
422
+ # Create a Visualizer object for each rotation
423
+ vis = o3d.visualization.Visualizer()
424
+ vis.create_window()
425
+
426
+ # Add the translated geometries
427
+ vis.add_geometry(translated_pcd)
428
+ vis.add_geometry(translated_arrow)
429
+
430
+ # Apply the additional rotation around x-axis if desired
431
+ # TODO: not sure why we need this rotation for the translation, and when it would be desired
432
+ angle_x = np.pi * 5.5 / 5 # 198 degrees
433
+ R = o3d.geometry.get_rotation_matrix_from_axis_angle(np.asarray([1, 0, 0]) * angle_x)
434
+ translated_pcd.rotate(R, center=translated_pcd.get_center())
435
+ translated_arrow.rotate(R, center=translated_pcd.get_center())
436
+
437
+ # Capture and save the image
438
+ output_filename = f"{output_dir}/{idx}.png"
439
+ vis.capture_screen_image(output_filename, do_render=True)
440
+ vis.destroy_window()
441
+
442
+
443
+ def get_rotation_matrix_from_vectors(vec1: np.ndarray, vec2: np.ndarray) -> np.ndarray:
444
+ """
445
+ Find the rotation matrix that aligns vec1 to vec2
446
+
447
+ :param vec1: A 3d "source" vector
448
+ :param vec2: A 3d "destination" vector
449
+ :return: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
450
+ """
451
+ a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), (vec2 / np.linalg.norm(vec2)).reshape(3)
452
+ v = np.cross(a, b)
453
+ c = np.dot(a, b)
454
+ s = np.linalg.norm(v)
455
+ kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
456
+ rotation_matrix = np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2))
457
+ return rotation_matrix
458
+
459
+
460
+ def draw_line(start_point: np.ndarray, end_point: np.ndarray) -> o3d.geometry.TriangleMesh:
461
+ """
462
+ Generate 3D mesh representing axis from start_point to end_point.
463
+
464
+ :param start_point: np.ndarray of dimensions (3, ) representing the start point of the axis
465
+ :param end_point: np.ndarray of dimensions (3, ) representing the end point of the axis
466
+ :return: mesh object representing axis from start to end
467
+ """
468
+ # Compute direction vector and normalize it
469
+ direction_vector = end_point - start_point
470
+ normalized_vector = direction_vector / np.linalg.norm(direction_vector)
471
+
472
+ # Compute the rotation matrix to align the Z-axis with the desired direction
473
+ target_vector = np.array([0, 0, 1])
474
+ rot_mat = get_rotation_matrix_from_vectors(target_vector, normalized_vector)
475
+
476
+ # Create the cylinder (shaft of the arrow)
477
+ cylinder_length = 0.9 # 90% of the total arrow length, you can adjust as needed
478
+ cylinder_radius = 0.01 # Adjust the thickness of the arrow shaft
479
+ cylinder = o3d.geometry.TriangleMesh.create_cylinder(radius=cylinder_radius, height=cylinder_length)
480
+
481
+ # Move base of cylinder to origin, rotate, then translate to start_point
482
+ cylinder.translate([0, 0, 0])
483
+ cylinder.rotate(rot_mat, center=[0, 0, 0])
484
+ cylinder.translate(start_point)
485
+
486
+ # Create the cone (head of the arrow)
487
+ cone_height = 0.1 # 10% of the total arrow length, adjust as needed
488
+ cone_radius = 0.03 # Adjust the size of the arrowhead
489
+ cone = o3d.geometry.TriangleMesh.create_cone(radius=cone_radius, height=cone_height)
490
+
491
+ # Move base of cone to origin, rotate, then translate to end of cylinder
492
+ cone.translate([-0, 0, 0])
493
+ cone.rotate(rot_mat, center=[0, 0, 0])
494
+ cone.translate(start_point + normalized_vector * 0.4)
495
+
496
+ arrow = cylinder + cone
497
+ return arrow
498
+
499
+
500
+ def rotate_part(
501
+ pcd: o3d.geometry.PointCloud, mask: np.ndarray, axis_vector: np.ndarray, origin: np.ndarray, angle_rad: float
502
+ ) -> o3d.geometry.PointCloud:
503
+ """
504
+ Generate rotated point cloud of mask based on provided angle around axis.
505
+
506
+ :param pcd: point cloud object representing points of image
507
+ :param mask: mask np.array of dimensions (height, width) representing the part to be rotated in the image
508
+ :param axis_vector: np.array of dimensions (3, ) representing the vector of the axis of rotation
509
+ :param origin: np.array of dimensions (3, ) representing the origin point of the axis of rotation
510
+ :param angle_rad: angle in radians to rotate mask part
511
+ :return: point cloud object after rotation of masked part
512
+ """
513
+ # Get the coordinates of the point cloud as a numpy array
514
+ points_np = np.asarray(pcd.points)
515
+
516
+ # Convert point cloud colors to numpy array for easier manipulation
517
+ colors_np = np.asarray(pcd.colors)
518
+
519
+ # Create skew-symmetric matrix from end
520
+ K = np.array(
521
+ [
522
+ [0, -axis_vector[2], axis_vector[1]],
523
+ [axis_vector[2], 0, -axis_vector[0]],
524
+ [-axis_vector[1], axis_vector[0], 0],
525
+ ]
526
+ )
527
+
528
+ # Compute rotation matrix using Rodrigues' formula
529
+ R = np.eye(3) + np.sin(angle_rad) * K + (1 - np.cos(angle_rad)) * np.dot(K, K)
530
+
531
+ # Iterate over the mask and rotate the points corresponding to the object pixels
532
+ for i in range(mask.shape[0]):
533
+ for j in range(mask.shape[1]):
534
+ if mask[i, j] > 0: # This condition checks if the pixel belongs to the object
535
+ point_index = i * mask.shape[1] + j
536
+
537
+ # Translate the point such that the rotation origin is at the world origin
538
+ translated_point = points_np[point_index] - origin
539
+
540
+ # Rotate the translated point
541
+ rotated_point = np.dot(R, translated_point)
542
+
543
+ # Translate the point back
544
+ points_np[point_index] = rotated_point + origin
545
+
546
+ colors_np[point_index] = POINT_COLOR
547
+
548
+ # Update the point cloud's coordinates
549
+ pcd.points = o3d.utility.Vector3dVector(points_np)
550
+
551
+ # Update point cloud colors
552
+ pcd.colors = o3d.utility.Vector3dVector(colors_np)
553
+
554
+ return pcd
555
+
556
+
557
+ def translate_part(pcd, mask, axis_vector, distance):
558
+ """
559
+ Generate translated point cloud of mask based on provided angle around axis.
560
+
561
+ :param pcd: point cloud object representing points of image
562
+ :param mask: mask np.array of dimensions (height, width) representing the part to be translated in the image
563
+ :param axis_vector: np.array of dimensions (3, ) representing the vector of the axis of translation
564
+ :param distance: distance within coordinate system to translate mask part
565
+ :return: point cloud object after translation of masked part
566
+ """
567
+ normalized_vector = axis_vector / np.linalg.norm(axis_vector)
568
+ translation_vector = normalized_vector * distance
569
+
570
+ # Convert point cloud colors to numpy array for easier manipulation
571
+ colors_np = np.asarray(pcd.colors)
572
+
573
+ # Get the coordinates of the point cloud as a numpy array
574
+ points_np = np.asarray(pcd.points)
575
+
576
+ # Iterate over the mask and assign the color to the points corresponding to the object pixels
577
+ for i in range(mask.shape[0]):
578
+ for j in range(mask.shape[1]):
579
+ if mask[i, j] > 0: # This condition checks if the pixel belongs to the object
580
+ point_index = i * mask.shape[1] + j
581
+ colors_np[point_index] = POINT_COLOR
582
+ points_np[point_index] += translation_vector
583
+
584
+ # Update point cloud colors
585
+ pcd.colors = o3d.utility.Vector3dVector(colors_np)
586
+
587
+ # Update the point cloud's coordinates
588
+ pcd.points = o3d.utility.Vector3dVector(points_np)
589
+
590
+ return pcd
591
+
592
+
593
+ def batch_trim(images_path: str, save_path: str, identical: bool = False) -> None:
594
+ """
595
+ Trim white spaces from all images in the given path and save new images to folder.
596
+
597
+ :param images_path: local path to folder containing all images. Images must have the extension ".png", ".jpg", or
598
+ ".jpeg".
599
+ :param save_path: local path to folder in which to save trimmed images
600
+ :param identical: if True, will apply same crop to all images, else each image will have its whitespace trimmed
601
+ independently. Note that in the latter case, each image may have a slightly different size.
602
+ """
603
+
604
+ def get_trim(im):
605
+ """Trim whitespace from an image and return the cropped image."""
606
+ bg = Image.new(im.mode, im.size, im.getpixel((0, 0)))
607
+ diff = ImageChops.difference(im, bg)
608
+ diff = ImageChops.add(diff, diff, 2.0, -100)
609
+ bbox = diff.getbbox()
610
+ return bbox
611
+
612
+ if identical:
613
+ images = []
614
+ optimal_box = None
615
+
616
+ # load all images
617
+ for image_file in os.listdir(images_path):
618
+ if image_file.endswith(IMAGE_EXTENSIONS):
619
+ image_path = os.path.join(images_path, image_file)
620
+ images.append(Image.open(image_path))
621
+
622
+ # find optimal box size
623
+ for im in images:
624
+ bbox = get_trim(im)
625
+ if bbox is None:
626
+ bbox = (0, 0, im.size[0], im.size[1]) # bound entire image
627
+
628
+ if optimal_box is None:
629
+ optimal_box = bbox
630
+ else:
631
+ optimal_box = (
632
+ min(optimal_box[0], bbox[0]),
633
+ min(optimal_box[1], bbox[1]),
634
+ max(optimal_box[2], bbox[2]),
635
+ max(optimal_box[3], bbox[3]),
636
+ )
637
+
638
+ # apply cropping, if optimal box was found
639
+ if optimal_box:
640
+ for im in images:
641
+ im.crop(optimal_box)
642
+ im.close()
643
+
644
+ else: # trim each image separately
645
+ for image_file in os.listdir(images_path):
646
+ if image_file.endswith(IMAGE_EXTENSIONS):
647
+ image_path = os.path.join(images_path, image_file)
648
+ with Image.open(image_path) as im:
649
+ bbox = get_trim(im)
650
+ trimmed = im.crop(bbox) if bbox else im
651
+ trimmed.save(os.path.join(save_path, image_file))
652
+
653
+
654
+ def create_gif(image_folder_path: str, num_samples: int, gif_filename: str = "output.gif") -> None:
655
+ """
656
+ Create gif out of folder of images and save to file.
657
+
658
+ :param image_folder_path: path to folder containing images (non-recursive). Assumes images are named as {i}.png for
659
+ each of i from 0 to num_samples.
660
+ :param num_samples: number of sampled images to compile into gif.
661
+ :param gif_filename: filename for gif, defaults to "output.gif"
662
+ """
663
+ # Generate a list of image filenames (assuming the images are saved as 0.png, 1.png, etc.)
664
+ image_files = [f"{image_folder_path}/{i}.png" for i in range(num_samples)]
665
+
666
+ # Read the images using imageio
667
+ images = [imageio.imread(image_file) for image_file in image_files]
668
+
669
+ # Save images as a gif
670
+ gif_output_path = f"{image_folder_path}/{gif_filename}"
671
+ imageio.mimsave(gif_output_path, images, duration=0.1)
672
+
673
+ return
674
+
675
+
676
+ def main(
677
+ cfg: CfgNode,
678
+ rgb_image: str,
679
+ depth_image: str,
680
+ intrinsics: list[float],
681
+ num_samples: int,
682
+ crop: bool,
683
+ score_threshold: float,
684
+ ) -> None:
685
+ """
686
+ Main inference method.
687
+
688
+ :param cfg: configuration object
689
+ :param rgb_image: local path to RGB image
690
+ :param depth_image: local path to depth image
691
+ :param intrinsics: camera intrinsics matrix as a list of 9 values
692
+ :param num_samples: number of sample visualization states to generate
693
+ :param crop: if True, images will be cropped to remove whitespace before visualization
694
+ :param score_threshold: float between 0 and 1 representing threshold at which to filter instances based on score
695
+ """
696
+ logger = logging.getLogger("detectron2")
697
+
698
+ # setup data
699
+ logger.info("Loading image.")
700
+ inp = format_input(rgb_image)
701
+
702
+ # setup model
703
+ logger.info("Loading model.")
704
+ model = build_model(cfg)
705
+ weights = torch.load(cfg.MODEL.WEIGHTS, map_location=torch.device("cpu"))
706
+ if "model" not in weights:
707
+ weights = {"model": weights}
708
+ load_model(model, weights)
709
+
710
+ # run model on data
711
+ logger.info("Running model.")
712
+ prediction = predict(model, inp)[0] # index 0 since there is only one image
713
+
714
+ # select best prediction to visualize
715
+ pred_instances = prediction["instances"]
716
+ score_ranking = np.argsort([-1 * pred_instances[i].scores.item() for i in range(len(pred_instances))])
717
+ score_ranking = [idx for idx in score_ranking if pred_instances[int(idx)].scores.item() > score_threshold]
718
+ if len(score_ranking) == 0:
719
+ logging.warning("The model did not predict any moving parts above the score threshold.")
720
+ return
721
+
722
+ for idx in score_ranking: # iterate through all best predictions, by score threshold
723
+ pred = pred_instances[int(idx)] # take highest predicted one
724
+ logger.info("Rendering prediction for instance %d", int(idx))
725
+ output_dir = os.path.join(cfg.OUTPUT_DIR, str(idx))
726
+ os.makedirs(output_dir, exist_ok=True)
727
+
728
+ # extract predicted values for visualization
729
+ mask = np.squeeze(pred.pred_masks.cpu().numpy()) # dim: [height, width]
730
+ origin = pred.morigin.cpu().numpy().flatten() # dim: [3, ]
731
+ axis_vector = pred.maxis.cpu().numpy().flatten() # dim: [3, ]
732
+ pred_type = TYPE_CLASSIFICATION.get(pred.mtype.item())
733
+ range_min = 0 - pred.mstate.cpu().numpy()
734
+ range_max = pred.mstatemax.cpu().numpy() - pred.mstate.cpu().numpy()
735
+
736
+ # process visualization
737
+ color = o3d.io.read_image(rgb_image)
738
+ depth = o3d.io.read_image(depth_image)
739
+ rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(color, depth, convert_rgb_to_intensity=False)
740
+ color_np = np.asarray(color)
741
+ height, width = color_np.shape[:2]
742
+
743
+ # generate intrinsics
744
+ intrinsic_matrix = np.reshape(intrinsics, (3, 3), order="F")
745
+ intrinsic_obj = o3d.camera.PinholeCameraIntrinsic(
746
+ width,
747
+ height,
748
+ intrinsic_matrix[0, 0],
749
+ intrinsic_matrix[1, 1],
750
+ intrinsic_matrix[0, 2],
751
+ intrinsic_matrix[1, 2],
752
+ )
753
+
754
+ # Convert the RGBD image to a point cloud
755
+ pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic_obj)
756
+
757
+ # Create a LineSet to visualize the direction vector
758
+ axis_arrow = draw_line(origin, axis_vector + origin)
759
+ axis_arrow.paint_uniform_color([0, 1, 0])
760
+
761
+ # if USE_GT:
762
+ # anno_path = f"/localhome/atw7/projects/opdmulti/data/data_demo_dev/59-4860.json"
763
+ # part_id = 32
764
+
765
+ # # get annotation for the frame
766
+ # import json
767
+
768
+ # with open(anno_path, "r") as f:
769
+ # anno = json.load(f)
770
+
771
+ # articulations = anno["articulation"]
772
+ # for articulation in articulations:
773
+ # if articulation["partId"] == part_id:
774
+ # range_min = articulation["rangeMin"] - articulation["state"]
775
+ # range_max = articulation["rangeMax"] - articulation["state"]
776
+ # break
777
+
778
+ if pred_type == "rotation":
779
+ generate_rotation_visualization(
780
+ pcd,
781
+ axis_arrow,
782
+ mask,
783
+ axis_vector,
784
+ origin,
785
+ range_min,
786
+ range_max,
787
+ num_samples,
788
+ output_dir,
789
+ )
790
+ elif pred_type == "translation":
791
+ generate_translation_visualization(
792
+ pcd,
793
+ axis_arrow,
794
+ mask,
795
+ axis_vector,
796
+ range_min,
797
+ range_max,
798
+ num_samples,
799
+ output_dir,
800
+ )
801
+ else:
802
+ raise ValueError(f"Invalid motion prediction type: {pred_type}")
803
+
804
+ if pred_type:
805
+ if crop: # crop images to remove shared extraneous whitespace
806
+ output_dir_cropped = f"{output_dir}_cropped"
807
+ if not os.path.isdir(output_dir_cropped):
808
+ os.makedirs(output_dir_cropped)
809
+ batch_trim(output_dir, output_dir_cropped, identical=True)
810
+ create_gif(output_dir_cropped, num_samples)
811
+ else: # leave original dimensions of image as-is
812
+ create_gif(output_dir, num_samples)
813
+
814
+
815
+ if __name__ == "__main__":
816
+ # parse arguments
817
+ start_time = time.time()
818
+ args = get_parser().parse_args()
819
+ cfg = setup_cfg(args)
820
+
821
+ # run main code
822
+ engine.launch(
823
+ main,
824
+ args.num_processes,
825
+ args=(
826
+ cfg,
827
+ args.rgb_image,
828
+ args.depth_image,
829
+ args.intrinsics,
830
+ args.num_samples,
831
+ args.crop,
832
+ args.score_threshold,
833
+ ),
834
+ )
835
+ end_time = time.time()
836
+ print(f"Inference time: {end_time - start_time:.2f} seconds")
mask2former/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Facebook, Inc. and its affiliates
+ from . import modeling
+
+ # config
+ from .config import add_maskformer2_config, add_motionnet_config
+
+ __all__ = [
+     "modeling",
+     "add_maskformer2_config",
+     "add_motionnet_config",
+ ]
mask2former/config.py ADDED
@@ -0,0 +1,125 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from detectron2.config import CfgNode as CN
+
+
+ def add_motionnet_config(cfg: CN):
+     _C = cfg
+     _C.MODEL.MOTIONNET = CN()
+     _C.MODEL.MOTIONNET.TYPE = "BMOC_V0"
+     cfg.MODEL.MASK_FORMER.MTYPE_WEIGHT = 2.0
+     cfg.MODEL.MASK_FORMER.MORIGIN_WEIGHT = 16.0
+     cfg.MODEL.MASK_FORMER.MAXIS_WEIGHT = 16.0
+     cfg.MODEL.MASK_FORMER.MSTATE_WEIGHT = 16.0
+     cfg.MODEL.MASK_FORMER.MSTATEMAX_WEIGHT = 16.0
+     cfg.MODEL.MASK_FORMER.EXTRINSIC_WEIGHT = 30.0
+
+ def add_maskformer2_config(cfg):
+     """
+     Add config for MASK_FORMER.
+     """
+     # NOTE: configs from original maskformer
+     # data config
+     # select the dataset mapper
+     cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
+     # Color augmentation
+     cfg.INPUT.COLOR_AUG_SSD = False
+     # We retry random cropping until no single category in semantic segmentation GT occupies more
+     # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
+     cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
+     # Pad image and segmentation GT in dataset mapper.
+     cfg.INPUT.SIZE_DIVISIBILITY = -1
+
+     # solver config
+     # weight decay on embedding
+     cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
+     # optimizer
+     cfg.SOLVER.OPTIMIZER = "ADAMW"
+     cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
+
+     # mask_former model config
+     cfg.MODEL.MASK_FORMER = CN()
+
+     # loss
+     cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
+     cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1
+     cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = 1.0
+     cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0
+     cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0
+
+     # transformer config
+     cfg.MODEL.MASK_FORMER.NHEADS = 8
+     cfg.MODEL.MASK_FORMER.DROPOUT = 0.1
+     cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
+     cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0
+     cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6
+     cfg.MODEL.MASK_FORMER.PRE_NORM = False
+
+     cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
+     cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100
+
+     cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5"
+     cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False
+
+     # mask_former inference config
+     cfg.MODEL.MASK_FORMER.TEST = CN()
+     cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True
+     cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = False
+     cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False
+     cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0
+     cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0
+     cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
+
+     # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
+     # you can use this config to override
+     cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
+
+     # pixel decoder config
+     cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
+     # adding transformer in pixel decoder
+     cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
+     # pixel decoder
+     cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
+
+     # swin transformer backbone
+     cfg.MODEL.SWIN = CN()
+     cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
+     cfg.MODEL.SWIN.PATCH_SIZE = 4
+     cfg.MODEL.SWIN.EMBED_DIM = 96
+     cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+     cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+     cfg.MODEL.SWIN.WINDOW_SIZE = 7
+     cfg.MODEL.SWIN.MLP_RATIO = 4.0
+     cfg.MODEL.SWIN.QKV_BIAS = True
+     cfg.MODEL.SWIN.QK_SCALE = None
+     cfg.MODEL.SWIN.DROP_RATE = 0.0
+     cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
+     cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
+     cfg.MODEL.SWIN.APE = False
+     cfg.MODEL.SWIN.PATCH_NORM = True
+     cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
+     cfg.MODEL.SWIN.USE_CHECKPOINT = False
+
+     # NOTE: maskformer2 extra configs
+     # transformer module
+     cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME = "MultiScaleMaskedTransformerDecoder"
+
+     # LSJ aug
+     cfg.INPUT.IMAGE_SIZE = 1024
+     cfg.INPUT.MIN_SCALE = 0.1
+     cfg.INPUT.MAX_SCALE = 2.0
+
+     # MSDeformAttn encoder configs
+     cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"]
+     cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4
+     cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8
+
+     # point loss configs
+     # Number of points sampled during training for a mask point head.
+     cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = 112 * 112
+     # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
+     # original paper.
+     cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0
+     # Importance sampling parameter for PointRend point sampling during training. Parameter `beta` in
+     # the original paper.
+     cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
mask2former/maskformer_model.py ADDED
@@ -0,0 +1,820 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import pdb
3
+ from typing import Tuple
4
+ from copy import deepcopy
5
+
6
+ import torch
7
+ from torch import device, nn
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.data import MetadataCatalog
12
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
13
+ from detectron2.modeling.backbone import Backbone
14
+ from detectron2.modeling.postprocessing import sem_seg_postprocess
15
+ from detectron2.structures import Boxes, ImageList, Instances, BitMasks
16
+ from detectron2.utils.memory import retry_if_cuda_oom
17
+
18
+ from .modeling.criterion import SetCriterion
19
+ from .modeling.matcher import HungarianMatcher
20
+ from .utils.tranform import matrix_to_quaternion, quaternion_to_matrix, rotation_6d_to_matrix, matrix_to_rotation_6d, geometric_median
21
+ from .modeling.criterion import convert_to_filled_tensor
22
+
23
+ import numpy as np
24
+
25
+ @META_ARCH_REGISTRY.register()
26
+ class MaskFormer(nn.Module):
27
+ """
28
+ Main class for mask classification semantic segmentation architectures.
29
+ """
30
+
31
+ @configurable
32
+ def __init__(
33
+ self,
34
+ *,
35
+ backbone: Backbone,
36
+ sem_seg_head: nn.Module,
37
+ criterion: nn.Module,
38
+ mask2former_backbone: nn.Module,
39
+ mask2former_sem_seg_head: nn.Module,
40
+ num_queries: int,
41
+ object_mask_threshold: float,
42
+ overlap_threshold: float,
43
+ metadata,
44
+ size_divisibility: int,
45
+ sem_seg_postprocess_before_inference: bool,
46
+ pixel_mean: Tuple[float],
47
+ pixel_std: Tuple[float],
48
+ # inference
49
+ semantic_on: bool,
50
+ panoptic_on: bool,
51
+ instance_on: bool,
52
+ test_topk_per_image: int,
53
+ # OPD
54
+ motionnet_type,
55
+ voting,
56
+ gtdet,
57
+ inference_matcher,
58
+ gtextrinsic,
59
+ only_DET,
60
+ obj_method
61
+ ):
62
+ """
63
+ Args:
64
+ backbone: a backbone module, must follow detectron2's backbone interface
65
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
66
+ criterion: a module that defines the loss
67
+ num_queries: int, number of queries
68
+ object_mask_threshold: float, threshold to filter query based on classification score
69
+ for panoptic segmentation inference
70
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
71
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
72
+ segmentation inference
73
+ size_divisibility: Some backbones require the input height and width to be divisible by a
74
+ specific integer. We can use this to override such requirement.
75
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
76
+ to original input size before semantic segmentation inference or after.
77
+ For high-resolution dataset like Mapillary, resizing predictions before
78
+ inference will cause OOM error.
79
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
80
+ the per-channel mean and std to be used to normalize the input image
81
+ semantic_on: bool, whether to output semantic segmentation prediction
82
+ instance_on: bool, whether to output instance segmentation prediction
83
+ panoptic_on: bool, whether to output panoptic segmentation prediction
84
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
85
+ """
86
+ super().__init__()
87
+ self.backbone = backbone
88
+ self.sem_seg_head = sem_seg_head
89
+ self.mask2former_backbone = mask2former_backbone
90
+ self.mask2former_sem_seg_head = mask2former_sem_seg_head
91
+
92
+ self.criterion = criterion
93
+ self.num_queries = num_queries
94
+ self.overlap_threshold = overlap_threshold
95
+ self.object_mask_threshold = object_mask_threshold
96
+ self.metadata = metadata
97
+ if size_divisibility < 0:
98
+ # use backbone size_divisibility if not set
99
+ size_divisibility = self.backbone.size_divisibility
100
+ self.size_divisibility = size_divisibility
101
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
102
+ self.register_buffer("pixel_mean", torch.Tensor(
103
+ pixel_mean).view(-1, 1, 1), False)
104
+ self.register_buffer("pixel_std", torch.Tensor(
105
+ pixel_std).view(-1, 1, 1), False)
106
+
107
+ # additional args
108
+ self.semantic_on = semantic_on
109
+ self.instance_on = instance_on
110
+ self.panoptic_on = panoptic_on
111
+ self.test_topk_per_image = test_topk_per_image
112
+
113
+ if not self.semantic_on:
114
+ assert self.sem_seg_postprocess_before_inference
115
+
116
+ # OPD
117
+ self.motionnet_type = motionnet_type
118
+ self.voting = voting
119
+ self.gtdet = gtdet
120
+ self.inference_matcher = inference_matcher
121
+ self.gtextrinsic = gtextrinsic
122
+ self.only_DET = only_DET
123
+ self.obj_method = obj_method
124
+
125
+ @classmethod
126
+ def from_config(cls, cfg):
127
+ backbone = build_backbone(cfg)
128
+ sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
129
+
130
+ # TODO: add mask2former backbone and semseghead to get object mask
131
+ if cfg.OBJ_DETECT:
132
+ mask2former_backbone = build_backbone(cfg.MASK2FORMER)
133
+ mask2former_sem_seg_head = build_sem_seg_head(
134
+ cfg.MASK2FORMER, backbone.output_shape())
135
+ else:
136
+ mask2former_backbone = None
137
+ mask2former_sem_seg_head = None
138
+
139
+ # Loss parameters:
140
+ deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
141
+ no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
142
+
143
+ # loss weights
144
+ class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
145
+ dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
146
+ mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
147
+ # OPD
148
+ mtype_weight = cfg.MODEL.MASK_FORMER.MTYPE_WEIGHT
149
+ morigin_weight = cfg.MODEL.MASK_FORMER.MORIGIN_WEIGHT
150
+ maxis_weight = cfg.MODEL.MASK_FORMER.MAXIS_WEIGHT
151
+ extrinsic_weight = cfg.MODEL.MASK_FORMER.EXTRINSIC_WEIGHT
152
+ mstate_weight = cfg.MODEL.MASK_FORMER.MSTATE_WEIGHT
153
+ mstatemax_weight = cfg.MODEL.MASK_FORMER.MSTATEMAX_WEIGHT
154
+
155
+ motionnet_type = cfg.MODEL.MOTIONNET.TYPE
156
+
157
+ # building criterion
158
+ matcher = HungarianMatcher(
159
+ cost_class=class_weight,
160
+ cost_mask=mask_weight,
161
+ cost_dice=dice_weight,
162
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
163
+ )
164
+
165
+ if "GTDET" in cfg.MODEL:
166
+ gtdet = cfg.MODEL.GTDET
167
+ else:
168
+ gtdet = False
169
+
170
+ if "GTEXTRINSIC" in cfg.MODEL:
171
+ gtextrinsic = cfg.MODEL.GTEXTRINSIC
172
+ else:
173
+ gtextrinsic = None
174
+
175
+ if gtdet or gtextrinsic:
176
+ # This inference matcher is used for the GT ablation during inference
177
+ inference_matcher = matcher
178
+ else:
179
+ inference_matcher = None
180
+
181
+ if "ONLY_DET" in cfg.MODEL:
182
+ only_DET = cfg.MODEL.ONLY_DET
183
+ else:
184
+ only_DET = False
185
+
186
+ # OPD
187
+ weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight, "loss_mtype": mtype_weight,
188
+ "loss_morigin": morigin_weight, "loss_maxis": maxis_weight, "loss_mstate": mstate_weight, "loss_mstatemax": mstatemax_weight}
189
+ if motionnet_type in ("BMOC_V1", "BMOC_V2", "BMOC_V3", "BMOC_V4", "BMOC_V5", "BMOC_V6"):
190
+ weight_dict["loss_extrinsic"] = extrinsic_weight
191
+
192
+ if deep_supervision:
193
+ dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
194
+ aux_weight_dict = {}
195
+ for i in range(dec_layers - 1):
196
+ aux_weight_dict.update(
197
+ {k + f"_{i}": v for k, v in weight_dict.items()})
198
+ weight_dict.update(aux_weight_dict)
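+ # With deep supervision, every loss key above is duplicated once per intermediate decoder layer (e.g. "loss_ce_0" ... "loss_ce_{dec_layers-2}"), so auxiliary predictions are weighted the same as the final one.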
199
+
200
+ # OPD
201
+ if motionnet_type == "BMOC_V0":
202
+ weight_dict["loss_extrinsic"] = extrinsic_weight
203
+
204
+ # OPD
205
+ losses = ["labels", "masks", "mtypes", "morigins",
206
+ "maxises", "extrinsics", "mstates", "mstatemaxs"]
207
+
208
+ criterion = SetCriterion(
209
+ sem_seg_head.num_classes,
210
+ matcher=matcher,
211
+ weight_dict=weight_dict,
212
+ eos_coef=no_object_weight,
213
+ losses=losses,
214
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
215
+ oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
216
+ importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
217
+ motionnet_type=motionnet_type,
218
+ only_DET=only_DET,
219
+ )
220
+
221
+ # OPD
222
+ if "VOTING" in cfg.MODEL.MOTIONNET:
223
+ voting = cfg.MODEL.MOTIONNET.VOTING
224
+ else:
225
+ voting = None
226
+
227
+ return {
228
+ "backbone": backbone,
229
+ "sem_seg_head": sem_seg_head,
230
+ "mask2former_backbone": mask2former_backbone,
231
+ "mask2former_sem_seg_head": mask2former_sem_seg_head,
232
+ "criterion": criterion,
233
+ "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
234
+ "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
235
+ "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
236
+ "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
237
+ "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
238
+ "sem_seg_postprocess_before_inference": (
239
+ cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
240
+ or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
241
+ or cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
242
+ ),
243
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
244
+ "pixel_std": cfg.MODEL.PIXEL_STD,
245
+ # inference
246
+ "semantic_on": cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON,
247
+ "instance_on": cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON,
248
+ "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
249
+ "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
250
+ # OPD
251
+ "motionnet_type": motionnet_type,
252
+ "voting": voting,
253
+ "gtdet": gtdet,
254
+ "inference_matcher": inference_matcher,
255
+ "gtextrinsic": gtextrinsic,
256
+ "only_DET": only_DET,
257
+ "obj_method": cfg.OBJ_DETECT
258
+ }
259
+
260
+ @property
261
+ def device(self):
262
+ return self.pixel_mean.device
263
+
264
+ def forward(self, batched_inputs):
265
+ """
266
+ Args:
267
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
268
+ Each item in the list contains the inputs for one image.
269
+ For now, each item in the list is a dict that contains:
270
+ * "image": Tensor, image in (C, H, W) format.
271
+ * "instances": per-region ground truth
272
+ * Other information that's included in the original dicts, such as:
273
+ "height", "width" (int): the output resolution of the model (may be different
274
+ from input resolution), used in inference.
275
+ Returns:
276
+ list[dict]:
277
+ each dict has the results for one image. The dict contains the following keys:
278
+
279
+ * "sem_seg":
280
+ A Tensor that represents the
281
+ per-pixel segmentation predicted by the head.
282
+ The prediction has shape KxHxW that represents the logits of
283
+ each class for each pixel.
284
+ * "panoptic_seg":
285
+ A tuple that represents the panoptic output
286
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
287
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
288
+ Each dict contains keys "id", "category_id", "isthing".
289
+ """
290
+ images = [x["image"].to(self.device) for x in batched_inputs]
291
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
292
+ images = ImageList.from_tensors(images, self.size_divisibility)
293
+
294
+ # Load the targets during training or for the ground-truth ablation study
295
+ if self.training or self.gtdet or self.gtextrinsic:
296
+ # get the ground truth
297
+ if "instances" in batched_inputs[0]:
298
+ gt_instances = [x["instances"].to(
299
+ self.device) for x in batched_inputs]
300
+ targets = self.prepare_targets(gt_instances, images)
301
+ else:
302
+ targets = None
303
+
304
+ if not self.obj_method:
305
+ features = self.backbone(images.tensor)
306
+ outputs = self.sem_seg_head(features)
307
+ else:
308
+ # TODO: add a frozen model to extract object masks.
309
+ for para in self.mask2former_backbone.parameters():
310
+ para.requires_grad = False
311
+ for para in self.mask2former_sem_seg_head.parameters():
312
+ para.requires_grad = False
313
+
314
+ obj_feature = self.mask2former_backbone(images.tensor)
315
+ obj_output = self.mask2former_sem_seg_head(obj_feature)
316
+
317
+ pred_obj_masks = obj_output["pred_masks"]
318
+ # prob_masks = torch.sigmoid(pred_obj_masks)
319
+ pred_cls_results = obj_output["pred_logits"]
320
+
321
+ # TODO: use the object predictions to help object pose prediction; find a way to compute the IoU between part and object masks
322
+ for indice, pred_obj_mask in enumerate(pred_obj_masks):
323
+ # binarize each query's mask at half of its own maximum logit
324
+ for idx, mask in enumerate(pred_obj_mask):
325
+ max_score = torch.max(mask)
326
+ pred_obj_mask[idx] = (mask > (max_score*0.5)).float()
327
+
328
+ # replace the pred masks with binary masks
329
+ pred_obj_masks[indice] = pred_obj_mask
330
+
331
+
334
+ features = self.backbone(images.tensor)
335
+ outputs = self.sem_seg_head(features, pred_obj_masks)
336
+
337
+
340
+ if self.training:
341
+ # bipartite matching-based loss
342
+ losses = self.criterion(outputs, targets)
343
+
344
+ for k in list(losses.keys()):
345
+ if k in self.criterion.weight_dict:
346
+ losses[k] *= self.criterion.weight_dict[k]
347
+ else:
348
+ # remove this loss if not specified in `weight_dict`
349
+ print(f"Warning: loss {k} is not in weight_dict and will be dropped")
350
+ losses.pop(k)
351
+ return losses
352
+ else:
353
+ mask_cls_results = outputs["pred_logits"]
354
+ mask_pred_results = outputs["pred_masks"]
355
+ # OPD
356
+ mask_mtype_results = outputs["pred_mtypes"]
357
+ mask_morigin_results = outputs["pred_morigins"]
358
+ mask_maxis_results = outputs["pred_maxises"]
359
+ mask_mstate_results = outputs["pred_mstates"]
360
+ mask_mstatemax_results = outputs["pred_mstatemaxs"]
361
+ if "BMOC" in self.motionnet_type:
362
+ mask_extrinsic_results = outputs["pred_extrinsics"]
363
+
364
+ # upsample masks
365
+ mask_pred_results = F.interpolate(
366
+ mask_pred_results,
367
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
368
+ mode="bilinear",
369
+ align_corners=False,
370
+ )
371
+
372
+ if self.gtdet or self.gtextrinsic:
373
+ if self.gtdet:
374
+ # Make other predictions be bad, so that they will not consider when evaluating
375
+ mask_pred_results[:, :, :, :] = -30
376
+ mask_cls_results[:, :, :3] = 0
377
+ mask_cls_results[:, :, 3] = 15 # weight for softmax
378
+ # Initialize the predicted class and predicted mask to the default value
379
+ if targets[0]["masks"].shape[0] != 0:
380
+ outputs_without_aux = {
381
+ k: v for k, v in outputs.items() if k != "aux_outputs"}
382
+ # Retrieve the matching between the outputs of the last layer and the targets
383
+ indices = self.inference_matcher(
384
+ outputs_without_aux, targets)
385
+
386
+ def _get_src_permutation_idx(indices):
387
+ # permute predictions following indices
388
+ batch_idx = torch.cat(
389
+ [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
390
+ src_idx = torch.cat([src for (src, _) in indices])
391
+ return batch_idx, src_idx
392
+
393
+ def _get_tgt_permutation_idx(indices):
394
+ # permute targets following indices
395
+ batch_idx = torch.cat(
396
+ [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
397
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
398
+ return batch_idx, tgt_idx
399
+
400
+ src_idx = _get_src_permutation_idx(indices)
401
+ tgt_idx = _get_tgt_permutation_idx(indices)
402
+ if self.gtdet:
403
+ mask_pred_results[src_idx] = targets[0]["masks"].unsqueeze(0)[
404
+ tgt_idx].float() * 30
405
+ mask_pred_results[mask_pred_results == 0] = -30
406
+ mask_cls_results[src_idx] = F.one_hot(
407
+ targets[0]["labels"][tgt_idx[1]], num_classes=self.sem_seg_head.num_classes+1).float() * 15
408
+ if self.gtextrinsic:
409
+ if self.motionnet_type == "BMOC_V6":
410
+ gt_extrinsic_raw = targets[0]["gt_extrinsic"][0]
411
+ gt_extrinsic = torch.cat(
412
+ [
413
+ gt_extrinsic_raw[0:3],
414
+ gt_extrinsic_raw[4:7],
415
+ gt_extrinsic_raw[8:11],
416
+ gt_extrinsic_raw[12:15],
417
+ ],
418
+ 0,
419
+ )
420
+ mask_extrinsic_results[0] = gt_extrinsic
421
+ else:
422
+ raise ValueError("Not Implemented")
423
+
424
+ del outputs
425
+
426
+ if "BMOC" in self.motionnet_type:
427
+ processed_results = []
428
+ for mask_cls_result, mask_pred_result, input_per_image, image_size, mask_mtype_result, mask_morigin_result, mask_maxis_result, mask_mstate_result, mask_mstatemax_result, mask_extrinsic_result in zip(
429
+ mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes, mask_mtype_results, mask_morigin_results, mask_maxis_results, mask_mstate_results, mask_mstatemax_results, mask_extrinsic_results
430
+ ):
431
+ height = input_per_image.get("height", image_size[0])
432
+ width = input_per_image.get("width", image_size[1])
433
+ processed_results.append({})
434
+
435
+ if self.sem_seg_postprocess_before_inference:
436
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
437
+ mask_pred_result, image_size, height, width
438
+ )
439
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
440
+ # OPD
441
+ mask_mtype_result = mask_mtype_result.to(
442
+ mask_pred_result)
443
+ mask_morigin_result = mask_morigin_result.to(
444
+ mask_pred_result)
445
+ mask_maxis_result = mask_maxis_result.to(
446
+ mask_pred_result)
447
+ mask_mstate_result = mask_mstate_result.to(
448
+ mask_pred_result)
449
+ mask_mstatemax_result = mask_mstatemax_result.to(
450
+ mask_pred_result)
451
+ mask_extrinsic_result = mask_extrinsic_result.to(
452
+ mask_pred_result)
453
+
454
+ # semantic segmentation inference
455
+ if self.semantic_on:
456
+ r = retry_if_cuda_oom(self.semantic_inference)(
457
+ mask_cls_result, mask_pred_result)
458
+ if not self.sem_seg_postprocess_before_inference:
459
+ r = retry_if_cuda_oom(sem_seg_postprocess)(
460
+ r, image_size, height, width)
461
+ processed_results[-1]["sem_seg"] = r
462
+
463
+ # panoptic segmentation inference
464
+ if self.panoptic_on:
465
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(
466
+ mask_cls_result, mask_pred_result)
467
+ processed_results[-1]["panoptic_seg"] = panoptic_r
468
+
469
+ # instance segmentation inference
470
+ if self.instance_on:
471
+ instance_r = retry_if_cuda_oom(self.instance_inference)(
472
+ mask_cls_result, mask_pred_result, mask_mtype_result, mask_morigin_result, mask_maxis_result, mask_mstate_result, mask_mstatemax_result, mask_extrinsic_result)
473
+ processed_results[-1]["instances"] = instance_r
474
+ else:
475
+ processed_results = []
476
+ for mask_cls_result, mask_pred_result, input_per_image, image_size, mask_mtype_result, mask_morigin_result, mask_maxis_result, mask_mstate_result, mask_mstatemax_result in zip(
477
+ mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes, mask_mtype_results, mask_morigin_results, mask_maxis_results, mask_mstate_results, mask_mstatemax_results
478
+ ):
479
+ height = input_per_image.get("height", image_size[0])
480
+ width = input_per_image.get("width", image_size[1])
481
+ processed_results.append({})
482
+
483
+ if self.sem_seg_postprocess_before_inference:
484
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
485
+ mask_pred_result, image_size, height, width
486
+ )
487
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
488
+ # OPD
489
+ mask_mtype_result = mask_mtype_result.to(
490
+ mask_pred_result)
491
+ mask_morigin_result = mask_morigin_result.to(
492
+ mask_pred_result)
493
+ mask_maxis_result = mask_maxis_result.to(
494
+ mask_pred_result)
495
+ mask_mstate_result = mask_mstate_result.to(
496
+ mask_pred_result)
497
+ mask_mstatemax_result = mask_mstatemax_result.to(
498
+ mask_pred_result)
499
+
500
+ # semantic segmentation inference
501
+ if self.semantic_on:
502
+ r = retry_if_cuda_oom(self.semantic_inference)(
503
+ mask_cls_result, mask_pred_result)
504
+ if not self.sem_seg_postprocess_before_inference:
505
+ r = retry_if_cuda_oom(sem_seg_postprocess)(
506
+ r, image_size, height, width)
507
+ processed_results[-1]["sem_seg"] = r
508
+
509
+ # panoptic segmentation inference
510
+ if self.panoptic_on:
511
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(
512
+ mask_cls_result, mask_pred_result)
513
+ processed_results[-1]["panoptic_seg"] = panoptic_r
514
+
515
+ # instance segmentation inference
516
+ if self.instance_on:
517
+ instance_r = retry_if_cuda_oom(self.instance_inference)(
518
+ mask_cls_result, mask_pred_result, mask_mtype_result, mask_morigin_result, mask_maxis_result, mask_mstate_result, mask_mstatemax_result, None)
519
+ processed_results[-1]["instances"] = instance_r
520
+
521
+ return processed_results
522
+
523
+ def prepare_targets(self, targets, images):
524
+ h_pad, w_pad = images.tensor.shape[-2:]
525
+ new_targets = []
526
+ for targets_per_image in targets:
527
+ if hasattr(targets_per_image, "gt_masks"):
528
+ # pad gt
529
+ gt_masks = targets_per_image.gt_masks
530
+ padded_masks = torch.zeros(
531
+ (gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
532
+ padded_masks[:, : gt_masks.shape[1],
533
+ : gt_masks.shape[2]] = gt_masks
534
+ else:
535
+ padded_masks = torch.tensor([])
536
+ if "BMOC" in self.motionnet_type:
537
+ new_targets.append(
538
+ {
539
+ "labels": targets_per_image.gt_classes,
540
+ "masks": padded_masks,
541
+ # OPD
542
+ "gt_motion_valids": targets_per_image.gt_motion_valids,
543
+ "gt_types": targets_per_image.gt_types,
544
+ "gt_origins": targets_per_image.gt_origins,
545
+ "gt_axises": targets_per_image.gt_axises,
546
+ "gt_states": targets_per_image.gt_states,
547
+ "gt_statemaxs": targets_per_image.gt_statemaxs,
548
+ "gt_extrinsic": targets_per_image.gt_extrinsic,
549
+ "gt_extrinsic_quaternion": targets_per_image.gt_extrinsic_quaternion,
550
+ "gt_extrinsic_6d": targets_per_image.gt_extrinsic_6d,
551
+ }
552
+ )
553
+ else:
554
+ new_targets.append(
555
+ {
556
+ "labels": targets_per_image.gt_classes,
557
+ "masks": padded_masks,
558
+ # OPD
559
+ "gt_motion_valids": targets_per_image.gt_motion_valids,
560
+ "gt_types": targets_per_image.gt_types,
561
+ "gt_origins": targets_per_image.gt_origins,
562
+ "gt_axises": targets_per_image.gt_axises,
563
+ "gt_states": targets_per_image.gt_states,
564
+ "gt_statemaxs": targets_per_image.gt_statemaxs,
565
+ }
566
+ )
567
+ return new_targets
568
+
569
+ def semantic_inference(self, mask_cls, mask_pred):
570
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
571
+ mask_pred = mask_pred.sigmoid()
572
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
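+ # i.e. each pixel's class scores are a sum over queries of (query class probability) x (query mask probability at that pixel)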
573
+ return semseg
574
+
575
+ def panoptic_inference(self, mask_cls, mask_pred):
576
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
577
+ mask_pred = mask_pred.sigmoid()
578
+
579
+ keep = labels.ne(self.sem_seg_head.num_classes) & (
580
+ scores > self.object_mask_threshold)
581
+ cur_scores = scores[keep]
582
+ cur_classes = labels[keep]
583
+ cur_masks = mask_pred[keep]
584
+ cur_mask_cls = mask_cls[keep]
585
+ cur_mask_cls = cur_mask_cls[:, :-1]
586
+
587
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
588
+
589
+ h, w = cur_masks.shape[-2:]
590
+ panoptic_seg = torch.zeros(
591
+ (h, w), dtype=torch.int32, device=cur_masks.device)
592
+ segments_info = []
593
+
594
+ current_segment_id = 0
595
+
596
+ if cur_masks.shape[0] == 0:
597
+ # We didn't detect any mask :(
598
+ return panoptic_seg, segments_info
599
+ else:
600
+ # take argmax
601
+ cur_mask_ids = cur_prob_masks.argmax(0)
602
+ stuff_memory_list = {}
603
+ for k in range(cur_classes.shape[0]):
604
+ pred_class = cur_classes[k].item()
605
+ isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
606
+ mask_area = (cur_mask_ids == k).sum().item()
607
+ original_area = (cur_masks[k] >= 0.5).sum().item()
608
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
609
+
610
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
611
+ if mask_area / original_area < self.overlap_threshold:
612
+ continue
613
+
614
+ # merge stuff regions
615
+ if not isthing:
616
+ if int(pred_class) in stuff_memory_list.keys():
617
+ panoptic_seg[mask] = stuff_memory_list[int(
618
+ pred_class)]
619
+ continue
620
+ else:
621
+ stuff_memory_list[int(
622
+ pred_class)] = current_segment_id + 1
623
+
624
+ current_segment_id += 1
625
+ panoptic_seg[mask] = current_segment_id
626
+
627
+ segments_info.append(
628
+ {
629
+ "id": current_segment_id,
630
+ "isthing": bool(isthing),
631
+ "category_id": int(pred_class),
632
+ }
633
+ )
634
+
635
+ return panoptic_seg, segments_info
636
+
637
+ # Voting algorithms for inference
638
+ def votingProcess(self, x, voting):
639
+ device = x.device
640
+ if voting == "median":
641
+ final = torch.median(x, axis=0)[0]
642
+ elif voting == "mean":
643
+ final = torch.mean(x, axis=0)
644
+ elif voting == "geo-median":
645
+ x = x.detach().cpu().numpy()
646
+ final = geometric_median(x)
647
+ final = torch.from_numpy(final).to(device)
648
+ return final
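+ # e.g. votingProcess(torch.tensor([[0., 0., 1.], [0., 0., 2.], [0., 0., 30.]]), "median") returns tensor([0., 0., 2.]): the per-dimension median is robust to the outlier in the last row.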
649
+
650
+ def convert_to_valid_extrinsic(self, mask_extrinsic, dim=0):
651
+ if dim == 0:
652
+ translation = mask_extrinsic[9:12]
653
+ rotation_mat = quaternion_to_matrix(matrix_to_quaternion(
654
+ torch.transpose(mask_extrinsic[:9].reshape(3, 3), 0, 1)))
655
+ rotation_vector = torch.flatten(rotation_mat.transpose(0, 1))
656
+ final_mask_extrinsic = torch.cat((rotation_vector, translation))
657
+ elif dim == 1:
658
+ translation = mask_extrinsic[:, 9:12]
659
+ rotation_mat = quaternion_to_matrix(matrix_to_quaternion(
660
+ torch.transpose(mask_extrinsic[:, :9].reshape(-1, 3, 3), 1, 2)))
661
+ rotation_vector = torch.flatten(
662
+ rotation_mat.transpose(1, 2), start_dim=1)
663
+ final_mask_extrinsic = torch.cat(
664
+ (rotation_vector, translation), dim=1)
665
+ return final_mask_extrinsic
666
+
667
+ def instance_inference(self, mask_cls, mask_pred, mask_mtype, mask_morigin, mask_maxis, mask_mstate, mask_mstatemax, mask_extrinsic):
668
+ # mask_pred is already processed to have the same shape as original input
669
+ image_size = mask_pred.shape[-2:]
670
+
671
+ # [Q, K]
672
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
673
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(
674
+ 0).repeat(self.num_queries, 1).flatten(0, 1)
675
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
676
+ scores_per_image, topk_indices = scores.flatten(
677
+ 0, 1).topk(self.test_topk_per_image, sorted=False)
678
+ labels_per_image = labels[topk_indices]
679
+
680
+ topk_indices = topk_indices // self.sem_seg_head.num_classes
681
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
682
+ mask_pred = mask_pred[topk_indices]
683
+
684
+ # OPD
685
+ mask_mtype = mask_mtype[topk_indices]
686
+ pred_probs = F.softmax(mask_mtype, dim=1)
687
+ mask_mtype = torch.argmax(pred_probs, 1).float()
688
+
689
+ mask_morigin = mask_morigin[topk_indices]
690
+ mask_maxis = mask_maxis[topk_indices]
691
+ mask_mstate = mask_mstate[topk_indices]
692
+ mask_mstatemax = mask_mstatemax[topk_indices]
693
+
694
+ if self.motionnet_type == "BMOC_V1":
695
+ mask_extrinsic = mask_extrinsic[topk_indices]
696
+ mask_extrinsic = self.convert_to_valid_extrinsic(
697
+ mask_extrinsic, dim=1)
698
+ if self.voting != "none":
699
+ final_translation = torch.median(
700
+ mask_extrinsic[:, 9:12], axis=0)[0]
701
+ quaternions = matrix_to_quaternion(torch.transpose(
702
+ mask_extrinsic[:, :9].reshape(-1, 3, 3), 1, 2))
703
+ final_quaternion = self.votingProcess(quaternions, self.voting)
704
+ final_rotation = quaternion_to_matrix(final_quaternion)
705
+ final_rotation_vector = torch.flatten(
706
+ final_rotation.transpose(0, 1))
707
+ mask_extrinsic = torch.cat(
708
+ (final_rotation_vector, final_translation))
709
+ elif self.motionnet_type == "BMOC_V2":
710
+ mask_extrinsic = mask_extrinsic[topk_indices]
711
+ if self.voting != "none":
712
+ final_translation = torch.median(
713
+ mask_extrinsic[:, 4:7], axis=0)[0]
714
+ final_quaternion = self.votingProcess(
715
+ mask_extrinsic[:, :4], self.voting)
716
+ final_rotation = quaternion_to_matrix(final_quaternion)
717
+ final_rotation_vector = torch.flatten(
718
+ final_rotation.transpose(0, 1))
719
+ mask_extrinsic = torch.cat(
720
+ (final_rotation_vector, final_translation))
721
+ elif self.voting == "none":
722
+ translations = mask_extrinsic[:, 4:7]
723
+ quaternions = mask_extrinsic[:, :4]
724
+ rotation_vector = torch.flatten(
725
+ quaternion_to_matrix(quaternions).transpose(1, 2), 1)
726
+ mask_extrinsic = torch.cat((rotation_vector, translations), 1)
727
+ elif self.motionnet_type == "BMOC_V3":
728
+ mask_extrinsic = mask_extrinsic[topk_indices]
729
+ if self.voting != "none":
730
+ final_translation = torch.median(
731
+ mask_extrinsic[:, 6:9], axis=0)[0]
732
+ final_6d = self.votingProcess(
733
+ mask_extrinsic[:, :6], self.voting)
734
+ final_rotation = rotation_6d_to_matrix(final_6d)
735
+ final_rotation_vector = torch.flatten(
736
+ final_rotation.transpose(0, 1))
737
+ mask_extrinsic = torch.cat(
738
+ (final_rotation_vector, final_translation))
739
+ elif self.voting == "none":
740
+ translations = mask_extrinsic[:, 6:9]
741
+ rotation_6ds = mask_extrinsic[:, :6]
742
+ rotation_vector = torch.flatten(
743
+ rotation_6d_to_matrix(rotation_6ds).transpose(1, 2), 1)
744
+ mask_extrinsic = torch.cat((rotation_vector, translations), 1)
745
+ elif self.motionnet_type == "BMOC_V4" or self.motionnet_type == "BMOC_V5":
746
+ translation = mask_extrinsic[4:7]
747
+ quaternion = mask_extrinsic[:4]
748
+ rotation_vector = torch.flatten(
749
+ quaternion_to_matrix(quaternion).transpose(0, 1))
750
+ mask_extrinsic = torch.cat((rotation_vector, translation))
751
+ elif self.motionnet_type == "BMOC_V0" or self.motionnet_type == "BMOC_V6":
752
+ mask_extrinsic = self.convert_to_valid_extrinsic(
753
+ mask_extrinsic, dim=0)
754
+
755
+ if "BMOC" in self.motionnet_type:
756
+ # Use the predicted extrinsic matrix to convert the predicted morigin and maxis back to camera coordinates
757
+ maxis_end = mask_morigin + mask_maxis
758
+ mextrinsic_c2w = torch.eye(4, device=mask_morigin.device).repeat(
759
+ mask_morigin.shape[0], 1, 1
760
+ )
761
+
762
+ if self.motionnet_type in ("BMOC_V0", "BMOC_V4", "BMOC_V5", "BMOC_V6") or (self.motionnet_type in ("BMOC_V1", "BMOC_V2", "BMOC_V3") and self.voting != "none"):
763
+ mextrinsic_c2w[:, 0:3, 0:4] = torch.transpose(
764
+ mask_extrinsic.reshape(4, 3).repeat(
765
+ mask_morigin.shape[0], 1, 1), 1, 2
766
+ )
767
+ elif self.motionnet_type == "BMOC_V1" or self.motionnet_type == "BMOC_V2" or self.motionnet_type == "BMOC_V3":
768
+ mextrinsic_c2w[:, 0:3, 0:4] = torch.transpose(
769
+ mask_extrinsic.reshape(-1, 4, 3), 1, 2
770
+ )
771
+ mextrinsic_w2c = torch.inverse(mextrinsic_c2w)
772
+ mask_morigin = (
773
+ torch.matmul(
774
+ mextrinsic_w2c[:, :3,
775
+ :3], mask_morigin.unsqueeze(2)
776
+ ).squeeze(2)
777
+ + mextrinsic_w2c[:, :3, 3]
778
+ )
779
+ end_in_cam = (
780
+ torch.matmul(
781
+ mextrinsic_w2c[:, :3, :3], maxis_end.unsqueeze(2)
782
+ ).squeeze(2)
783
+ + mextrinsic_w2c[:, :3, 3]
784
+ )
785
+ mask_maxis = end_in_cam - mask_morigin
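+ # The predicted origin and the axis end point are mapped from world coordinates into camera coordinates by inverting the camera-to-world extrinsic; the axis is then recovered as the difference of the two transformed points.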
786
+
787
+ # if this is panoptic segmentation, we only keep the "thing" classes
788
+ if self.panoptic_on:
789
+ keep = torch.zeros_like(scores_per_image).bool()
790
+ for i, lab in enumerate(labels_per_image):
791
+ keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
792
+
793
+ scores_per_image = scores_per_image[keep]
794
+ labels_per_image = labels_per_image[keep]
795
+ mask_pred = mask_pred[keep]
796
+
797
+ result = Instances(image_size)
798
+ # mask (before sigmoid)
799
+ result.pred_masks = (mask_pred > 0).float()
800
+ # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
801
+ # Uncomment the following to get boxes from masks (this is slow)
802
+ result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
803
+
804
+ # calculate average mask prob
805
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(
806
+ 1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
807
+ result.scores = scores_per_image * mask_scores_per_image
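+ # i.e. final confidence = classification score x mean mask probability inside the binarized mask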
808
+ result.pred_classes = labels_per_image
809
+
810
+ # OPD
811
+ result.mtype = mask_mtype
812
+ result.morigin = mask_morigin
813
+ result.maxis = mask_maxis
814
+ result.mstate = mask_mstate
815
+ result.mstatemax = mask_mstatemax
816
+ if self.motionnet_type in ("BMOC_V0", "BMOC_V4", "BMOC_V5", "BMOC_V6") or (self.motionnet_type in ("BMOC_V1", "BMOC_V2", "BMOC_V3") and self.voting != "none"):
817
+ result.mextrinsic = mask_extrinsic.repeat(mask_morigin.shape[0], 1)
818
+ elif self.motionnet_type == "BMOC_V1" or self.motionnet_type == "BMOC_V2" or self.motionnet_type == "BMOC_V3":
819
+ result.mextrinsic = mask_extrinsic
820
+ return result
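For readers who want to exercise this meta-architecture end to end, below is a minimal, untested sketch of how the registered MaskFormer class could be built and queried through detectron2's standard APIs. The add_maskformer2_config helper, the config path and the checkpoint filename are assumptions (carried over from upstream Mask2Former and this repository's layout), and the OPD-specific keys read in from_config (e.g. OBJ_DETECT, MODEL.MOTIONNET.*) must already be present in the merged config; treat this as an illustrative sketch rather than the application's actual entry point.

import torch
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

from mask2former import add_maskformer2_config  # assumed export, as in upstream Mask2Former

cfg = get_cfg()
add_deeplab_config(cfg)
add_maskformer2_config(cfg)
cfg.merge_from_file("configs/coco/instance-segmentation/swin/opd_v1_real.yaml")  # assumed config path
cfg.MODEL.WEIGHTS = "model_final.pth"  # hypothetical checkpoint

model = build_model(cfg)  # instantiates MaskFormer through META_ARCH_REGISTRY
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
model.eval()

# One dummy (C, H, W) image; real inputs come from a DatasetMapper.
inputs = [{"image": torch.zeros(3, 480, 640), "height": 480, "width": 640}]
with torch.no_grad():
    outputs = model(inputs)
print(outputs[0]["instances"].pred_classes, outputs[0]["instances"].mtype)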
mask2former/modeling/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .backbone.swin import D2SwinTransformer
3
+ from .pixel_decoder.fpn import BasePixelDecoder
4
+ from .pixel_decoder.msdeformattn import MSDeformAttnPixelDecoder
5
+ from .meta_arch.mask_former_head import MaskFormerHead
6
+ from .meta_arch.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead
mask2former/modeling/backbone/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
mask2former/modeling/backbone/swin.py ADDED
@@ -0,0 +1,770 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint as checkpoint
16
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
17
+
18
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
19
+
20
+
21
+ class Mlp(nn.Module):
22
+ """Multilayer perceptron."""
23
+
24
+ def __init__(
25
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
26
+ ):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ def window_partition(x, window_size):
45
+ """
46
+ Args:
47
+ x: (B, H, W, C)
48
+ window_size (int): window size
49
+ Returns:
50
+ windows: (num_windows*B, window_size, window_size, C)
51
+ """
52
+ B, H, W, C = x.shape
53
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
54
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
55
+ return windows
56
+
57
+
58
+ def window_reverse(windows, window_size, H, W):
59
+ """
60
+ Args:
61
+ windows: (num_windows*B, window_size, window_size, C)
62
+ window_size (int): Window size
63
+ H (int): Height of image
64
+ W (int): Width of image
65
+ Returns:
66
+ x: (B, H, W, C)
67
+ """
68
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
69
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
70
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
71
+ return x
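+ # Shape example: window_partition on a (2, 56, 56, 96) tensor with window_size=7 gives (128, 7, 7, 96) windows, and window_reverse maps them back to (2, 56, 56, 96).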
72
+
73
+
74
+ class WindowAttention(nn.Module):
75
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
76
+ It supports both of shifted and non-shifted window.
77
+ Args:
78
+ dim (int): Number of input channels.
79
+ window_size (tuple[int]): The height and width of the window.
80
+ num_heads (int): Number of attention heads.
81
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
82
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
83
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
84
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ dim,
90
+ window_size,
91
+ num_heads,
92
+ qkv_bias=True,
93
+ qk_scale=None,
94
+ attn_drop=0.0,
95
+ proj_drop=0.0,
96
+ ):
97
+
98
+ super().__init__()
99
+ self.dim = dim
100
+ self.window_size = window_size # Wh, Ww
101
+ self.num_heads = num_heads
102
+ head_dim = dim // num_heads
103
+ self.scale = qk_scale or head_dim ** -0.5
104
+
105
+ # define a parameter table of relative position bias
106
+ self.relative_position_bias_table = nn.Parameter(
107
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
108
+ ) # 2*Wh-1 * 2*Ww-1, nH
109
+
110
+ # get pair-wise relative position index for each token inside the window
111
+ coords_h = torch.arange(self.window_size[0])
112
+ coords_w = torch.arange(self.window_size[1])
113
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
114
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
115
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
116
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
117
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
118
+ relative_coords[:, :, 1] += self.window_size[1] - 1
119
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
120
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
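+ # after the shifts each entry lies in [0, (2*Wh-1)*(2*Ww-1) - 1] and indexes into relative_position_bias_table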
121
+ self.register_buffer("relative_position_index", relative_position_index)
122
+
123
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
124
+ self.attn_drop = nn.Dropout(attn_drop)
125
+ self.proj = nn.Linear(dim, dim)
126
+ self.proj_drop = nn.Dropout(proj_drop)
127
+
128
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
129
+ self.softmax = nn.Softmax(dim=-1)
130
+
131
+ def forward(self, x, mask=None):
132
+ """Forward function.
133
+ Args:
134
+ x: input features with shape of (num_windows*B, N, C)
135
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
136
+ """
137
+ B_, N, C = x.shape
138
+ qkv = (
139
+ self.qkv(x)
140
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
141
+ .permute(2, 0, 3, 1, 4)
142
+ )
143
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
144
+
145
+ q = q * self.scale
146
+ attn = q @ k.transpose(-2, -1)
147
+
148
+ relative_position_bias = self.relative_position_bias_table[
149
+ self.relative_position_index.view(-1)
150
+ ].view(
151
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
152
+ ) # Wh*Ww,Wh*Ww,nH
153
+ relative_position_bias = relative_position_bias.permute(
154
+ 2, 0, 1
155
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
156
+ attn = attn + relative_position_bias.unsqueeze(0)
157
+
158
+ if mask is not None:
159
+ nW = mask.shape[0]
160
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
161
+ attn = attn.view(-1, self.num_heads, N, N)
162
+ attn = self.softmax(attn)
163
+ else:
164
+ attn = self.softmax(attn)
165
+
166
+ attn = self.attn_drop(attn)
167
+
168
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
169
+ x = self.proj(x)
170
+ x = self.proj_drop(x)
171
+ return x
172
+
173
+
174
+ class SwinTransformerBlock(nn.Module):
175
+ """Swin Transformer Block.
176
+ Args:
177
+ dim (int): Number of input channels.
178
+ num_heads (int): Number of attention heads.
179
+ window_size (int): Window size.
180
+ shift_size (int): Shift size for SW-MSA.
181
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
182
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
183
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
184
+ drop (float, optional): Dropout rate. Default: 0.0
185
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
186
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
187
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
188
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ dim,
194
+ num_heads,
195
+ window_size=7,
196
+ shift_size=0,
197
+ mlp_ratio=4.0,
198
+ qkv_bias=True,
199
+ qk_scale=None,
200
+ drop=0.0,
201
+ attn_drop=0.0,
202
+ drop_path=0.0,
203
+ act_layer=nn.GELU,
204
+ norm_layer=nn.LayerNorm,
205
+ ):
206
+ super().__init__()
207
+ self.dim = dim
208
+ self.num_heads = num_heads
209
+ self.window_size = window_size
210
+ self.shift_size = shift_size
211
+ self.mlp_ratio = mlp_ratio
212
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
213
+
214
+ self.norm1 = norm_layer(dim)
215
+ self.attn = WindowAttention(
216
+ dim,
217
+ window_size=to_2tuple(self.window_size),
218
+ num_heads=num_heads,
219
+ qkv_bias=qkv_bias,
220
+ qk_scale=qk_scale,
221
+ attn_drop=attn_drop,
222
+ proj_drop=drop,
223
+ )
224
+
225
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
226
+ self.norm2 = norm_layer(dim)
227
+ mlp_hidden_dim = int(dim * mlp_ratio)
228
+ self.mlp = Mlp(
229
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
230
+ )
231
+
232
+ self.H = None
233
+ self.W = None
234
+
235
+ def forward(self, x, mask_matrix):
236
+ """Forward function.
237
+ Args:
238
+ x: Input feature, tensor size (B, H*W, C).
239
+ H, W: Spatial resolution of the input feature.
240
+ mask_matrix: Attention mask for cyclic shift.
241
+ """
242
+ B, L, C = x.shape
243
+ H, W = self.H, self.W
244
+ assert L == H * W, "input feature has wrong size"
245
+
246
+ shortcut = x
247
+ x = self.norm1(x)
248
+ x = x.view(B, H, W, C)
249
+
250
+ # pad feature maps to multiples of window size
251
+ pad_l = pad_t = 0
252
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
253
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
254
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
255
+ _, Hp, Wp, _ = x.shape
256
+
257
+ # cyclic shift
258
+ if self.shift_size > 0:
259
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
260
+ attn_mask = mask_matrix
261
+ else:
262
+ shifted_x = x
263
+ attn_mask = None
264
+
265
+ # partition windows
266
+ x_windows = window_partition(
267
+ shifted_x, self.window_size
268
+ ) # nW*B, window_size, window_size, C
269
+ x_windows = x_windows.view(
270
+ -1, self.window_size * self.window_size, C
271
+ ) # nW*B, window_size*window_size, C
272
+
273
+ # W-MSA/SW-MSA
274
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
275
+
276
+ # merge windows
277
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
278
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
279
+
280
+ # reverse cyclic shift
281
+ if self.shift_size > 0:
282
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
283
+ else:
284
+ x = shifted_x
285
+
286
+ if pad_r > 0 or pad_b > 0:
287
+ x = x[:, :H, :W, :].contiguous()
288
+
289
+ x = x.view(B, H * W, C)
290
+
291
+ # FFN
292
+ x = shortcut + self.drop_path(x)
293
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
294
+
295
+ return x
296
+
297
+
298
+ class PatchMerging(nn.Module):
299
+ """Patch Merging Layer
300
+ Args:
301
+ dim (int): Number of input channels.
302
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
303
+ """
304
+
305
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
306
+ super().__init__()
307
+ self.dim = dim
308
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
309
+ self.norm = norm_layer(4 * dim)
310
+
311
+ def forward(self, x, H, W):
312
+ """Forward function.
313
+ Args:
314
+ x: Input feature, tensor size (B, H*W, C).
315
+ H, W: Spatial resolution of the input feature.
316
+ """
317
+ B, L, C = x.shape
318
+ assert L == H * W, "input feature has wrong size"
319
+
320
+ x = x.view(B, H, W, C)
321
+
322
+ # padding
323
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
324
+ if pad_input:
325
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
326
+
327
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
328
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
329
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
330
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
331
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
332
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
333
+
334
+ x = self.norm(x)
335
+ x = self.reduction(x)
336
+
337
+ return x
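+ # e.g. an input of shape (B, 56*56, 96) comes out as (B, 28*28, 192): spatial resolution is halved, channels are doubled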
338
+
339
+
340
+ class BasicLayer(nn.Module):
341
+ """A basic Swin Transformer layer for one stage.
342
+ Args:
343
+ dim (int): Number of feature channels
344
+ depth (int): Depth of this stage.
345
+ num_heads (int): Number of attention heads.
346
+ window_size (int): Local window size. Default: 7.
347
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
348
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
349
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
350
+ drop (float, optional): Dropout rate. Default: 0.0
351
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
352
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
353
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
354
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
355
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
356
+ """
357
+
358
+ def __init__(
359
+ self,
360
+ dim,
361
+ depth,
362
+ num_heads,
363
+ window_size=7,
364
+ mlp_ratio=4.0,
365
+ qkv_bias=True,
366
+ qk_scale=None,
367
+ drop=0.0,
368
+ attn_drop=0.0,
369
+ drop_path=0.0,
370
+ norm_layer=nn.LayerNorm,
371
+ downsample=None,
372
+ use_checkpoint=False,
373
+ ):
374
+ super().__init__()
375
+ self.window_size = window_size
376
+ self.shift_size = window_size // 2
377
+ self.depth = depth
378
+ self.use_checkpoint = use_checkpoint
379
+
380
+ # build blocks
381
+ self.blocks = nn.ModuleList(
382
+ [
383
+ SwinTransformerBlock(
384
+ dim=dim,
385
+ num_heads=num_heads,
386
+ window_size=window_size,
387
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
388
+ mlp_ratio=mlp_ratio,
389
+ qkv_bias=qkv_bias,
390
+ qk_scale=qk_scale,
391
+ drop=drop,
392
+ attn_drop=attn_drop,
393
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
394
+ norm_layer=norm_layer,
395
+ )
396
+ for i in range(depth)
397
+ ]
398
+ )
399
+
400
+ # patch merging layer
401
+ if downsample is not None:
402
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
403
+ else:
404
+ self.downsample = None
405
+
406
+ def forward(self, x, H, W):
407
+ """Forward function.
408
+ Args:
409
+ x: Input feature, tensor size (B, H*W, C).
410
+ H, W: Spatial resolution of the input feature.
411
+ """
412
+
413
+ # calculate attention mask for SW-MSA
414
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
415
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
416
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
417
+ h_slices = (
418
+ slice(0, -self.window_size),
419
+ slice(-self.window_size, -self.shift_size),
420
+ slice(-self.shift_size, None),
421
+ )
422
+ w_slices = (
423
+ slice(0, -self.window_size),
424
+ slice(-self.window_size, -self.shift_size),
425
+ slice(-self.shift_size, None),
426
+ )
427
+ cnt = 0
428
+ for h in h_slices:
429
+ for w in w_slices:
430
+ img_mask[:, h, w, :] = cnt
431
+ cnt += 1
432
+
433
+ mask_windows = window_partition(
434
+ img_mask, self.window_size
435
+ ) # nW, window_size, window_size, 1
436
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
437
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
438
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
439
+ attn_mask == 0, float(0.0)
440
+ )
441
+
442
+ for blk in self.blocks:
443
+ blk.H, blk.W = H, W
444
+ if self.use_checkpoint:
445
+ x = checkpoint.checkpoint(blk, x, attn_mask)
446
+ else:
447
+ x = blk(x, attn_mask)
448
+ if self.downsample is not None:
449
+ x_down = self.downsample(x, H, W)
450
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
451
+ return x, H, W, x_down, Wh, Ww
452
+ else:
453
+ return x, H, W, x, H, W
454
+
455
+
456
+ class PatchEmbed(nn.Module):
457
+ """Image to Patch Embedding
458
+ Args:
459
+ patch_size (int): Patch token size. Default: 4.
460
+ in_chans (int): Number of input image channels. Default: 3.
461
+ embed_dim (int): Number of linear projection output channels. Default: 96.
462
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
463
+ """
464
+
465
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
466
+ super().__init__()
467
+ patch_size = to_2tuple(patch_size)
468
+ self.patch_size = patch_size
469
+
470
+ self.in_chans = in_chans
471
+ self.embed_dim = embed_dim
472
+
473
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
474
+ if norm_layer is not None:
475
+ self.norm = norm_layer(embed_dim)
476
+ else:
477
+ self.norm = None
478
+
479
+ def forward(self, x):
480
+ """Forward function."""
481
+ # padding
482
+ _, _, H, W = x.size()
483
+ if W % self.patch_size[1] != 0:
484
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
485
+ if H % self.patch_size[0] != 0:
486
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
487
+
488
+ x = self.proj(x) # B C Wh Ww
489
+ if self.norm is not None:
490
+ Wh, Ww = x.size(2), x.size(3)
491
+ x = x.flatten(2).transpose(1, 2)
492
+ x = self.norm(x)
493
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
494
+
495
+ return x
496
+
497
+
498
+ class SwinTransformer(nn.Module):
499
+ """Swin Transformer backbone.
500
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
501
+ https://arxiv.org/pdf/2103.14030
502
+ Args:
503
+ pretrain_img_size (int): Input image size for training the pretrained model,
504
+ used in absolute position embedding. Default 224.
505
+ patch_size (int | tuple(int)): Patch size. Default: 4.
506
+ in_chans (int): Number of input image channels. Default: 3.
507
+ embed_dim (int): Number of linear projection output channels. Default: 96.
508
+ depths (tuple[int]): Depths of each Swin Transformer stage.
509
+ num_heads (tuple[int]): Number of attention heads for each stage.
510
+ window_size (int): Window size. Default: 7.
511
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
512
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
513
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
514
+ drop_rate (float): Dropout rate.
515
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
516
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
517
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
518
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
519
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
520
+ out_indices (Sequence[int]): Output from which stages.
521
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
522
+ -1 means not freezing any parameters.
523
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
524
+ """
525
+
526
+ def __init__(
527
+ self,
528
+ pretrain_img_size=224,
529
+ patch_size=4,
530
+ in_chans=3,
531
+ embed_dim=96,
532
+ depths=[2, 2, 6, 2],
533
+ num_heads=[3, 6, 12, 24],
534
+ window_size=7,
535
+ mlp_ratio=4.0,
536
+ qkv_bias=True,
537
+ qk_scale=None,
538
+ drop_rate=0.0,
539
+ attn_drop_rate=0.0,
540
+ drop_path_rate=0.2,
541
+ norm_layer=nn.LayerNorm,
542
+ ape=False,
543
+ patch_norm=True,
544
+ out_indices=(0, 1, 2, 3),
545
+ frozen_stages=-1,
546
+ use_checkpoint=False,
547
+ ):
548
+ super().__init__()
549
+
550
+ self.pretrain_img_size = pretrain_img_size
551
+ self.num_layers = len(depths)
552
+ self.embed_dim = embed_dim
553
+ self.ape = ape
554
+ self.patch_norm = patch_norm
555
+ self.out_indices = out_indices
556
+ self.frozen_stages = frozen_stages
557
+
558
+ # split image into non-overlapping patches
559
+ self.patch_embed = PatchEmbed(
560
+ patch_size=patch_size,
561
+ in_chans=in_chans,
562
+ embed_dim=embed_dim,
563
+ norm_layer=norm_layer if self.patch_norm else None,
564
+ )
565
+
566
+ # absolute position embedding
567
+ if self.ape:
568
+ pretrain_img_size = to_2tuple(pretrain_img_size)
569
+ patch_size = to_2tuple(patch_size)
570
+ patches_resolution = [
571
+ pretrain_img_size[0] // patch_size[0],
572
+ pretrain_img_size[1] // patch_size[1],
573
+ ]
574
+
575
+ self.absolute_pos_embed = nn.Parameter(
576
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
577
+ )
578
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
579
+
580
+ self.pos_drop = nn.Dropout(p=drop_rate)
581
+
582
+ # stochastic depth
583
+ dpr = [
584
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
585
+ ] # stochastic depth decay rule
586
+
587
+ # build layers
588
+ self.layers = nn.ModuleList()
589
+ for i_layer in range(self.num_layers):
590
+ layer = BasicLayer(
591
+ dim=int(embed_dim * 2 ** i_layer),
592
+ depth=depths[i_layer],
593
+ num_heads=num_heads[i_layer],
594
+ window_size=window_size,
595
+ mlp_ratio=mlp_ratio,
596
+ qkv_bias=qkv_bias,
597
+ qk_scale=qk_scale,
598
+ drop=drop_rate,
599
+ attn_drop=attn_drop_rate,
600
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
601
+ norm_layer=norm_layer,
602
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
603
+ use_checkpoint=use_checkpoint,
604
+ )
605
+ self.layers.append(layer)
606
+
607
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
608
+ self.num_features = num_features
609
+
610
+ # add a norm layer for each output
611
+ for i_layer in out_indices:
612
+ layer = norm_layer(num_features[i_layer])
613
+ layer_name = f"norm{i_layer}"
614
+ self.add_module(layer_name, layer)
615
+
616
+ self._freeze_stages()
617
+
618
+ def _freeze_stages(self):
619
+ if self.frozen_stages >= 0:
620
+ self.patch_embed.eval()
621
+ for param in self.patch_embed.parameters():
622
+ param.requires_grad = False
623
+
624
+ if self.frozen_stages >= 1 and self.ape:
625
+ self.absolute_pos_embed.requires_grad = False
626
+
627
+ if self.frozen_stages >= 2:
628
+ self.pos_drop.eval()
629
+ for i in range(0, self.frozen_stages - 1):
630
+ m = self.layers[i]
631
+ m.eval()
632
+ for param in m.parameters():
633
+ param.requires_grad = False
634
+
635
+ def init_weights(self, pretrained=None):
636
+ """Initialize the weights in backbone.
637
+ Args:
638
+ pretrained (str, optional): Path to pre-trained weights.
639
+ Defaults to None.
640
+ """
641
+
642
+ def _init_weights(m):
643
+ if isinstance(m, nn.Linear):
644
+ trunc_normal_(m.weight, std=0.02)
645
+ if isinstance(m, nn.Linear) and m.bias is not None:
646
+ nn.init.constant_(m.bias, 0)
647
+ elif isinstance(m, nn.LayerNorm):
648
+ nn.init.constant_(m.bias, 0)
649
+ nn.init.constant_(m.weight, 1.0)
650
+
651
+ def forward(self, x):
652
+ """Forward function."""
653
+ x = self.patch_embed(x)
654
+
655
+ Wh, Ww = x.size(2), x.size(3)
656
+ if self.ape:
657
+ # interpolate the position embedding to the corresponding size
658
+ absolute_pos_embed = F.interpolate(
659
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
660
+ )
661
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
662
+ else:
663
+ x = x.flatten(2).transpose(1, 2)
664
+ x = self.pos_drop(x)
665
+
666
+ outs = {}
667
+ for i in range(self.num_layers):
668
+ layer = self.layers[i]
669
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
670
+
671
+ if i in self.out_indices:
672
+ norm_layer = getattr(self, f"norm{i}")
673
+ x_out = norm_layer(x_out)
674
+
675
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
676
+ outs["res{}".format(i + 2)] = out
677
+
678
+ return outs
679
+
680
+ def train(self, mode=True):
681
+ """Convert the model into training mode while keeping layers frozen."""
682
+ super(SwinTransformer, self).train(mode)
683
+ self._freeze_stages()
684
+
685
+
686
+ @BACKBONE_REGISTRY.register()
687
+ class D2SwinTransformer(SwinTransformer, Backbone):
688
+ def __init__(self, cfg, input_shape):
689
+
690
+ pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
691
+ patch_size = cfg.MODEL.SWIN.PATCH_SIZE
692
+ in_chans = 3
693
+ embed_dim = cfg.MODEL.SWIN.EMBED_DIM
694
+ depths = cfg.MODEL.SWIN.DEPTHS
695
+ num_heads = cfg.MODEL.SWIN.NUM_HEADS
696
+ window_size = cfg.MODEL.SWIN.WINDOW_SIZE
697
+ mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
698
+ qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
699
+ qk_scale = cfg.MODEL.SWIN.QK_SCALE
700
+ drop_rate = cfg.MODEL.SWIN.DROP_RATE
701
+ attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
702
+ drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
703
+ norm_layer = nn.LayerNorm
704
+ ape = cfg.MODEL.SWIN.APE
705
+ patch_norm = cfg.MODEL.SWIN.PATCH_NORM
706
+ use_checkpoint = cfg.MODEL.SWIN.USE_CHECKPOINT
707
+
708
+ super().__init__(
709
+ pretrain_img_size,
710
+ patch_size,
711
+ in_chans,
712
+ embed_dim,
713
+ depths,
714
+ num_heads,
715
+ window_size,
716
+ mlp_ratio,
717
+ qkv_bias,
718
+ qk_scale,
719
+ drop_rate,
720
+ attn_drop_rate,
721
+ drop_path_rate,
722
+ norm_layer,
723
+ ape,
724
+ patch_norm,
725
+ use_checkpoint=use_checkpoint,
726
+ )
727
+
728
+ self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
729
+
730
+ self._out_feature_strides = {
731
+ "res2": 4,
732
+ "res3": 8,
733
+ "res4": 16,
734
+ "res5": 32,
735
+ }
736
+ self._out_feature_channels = {
737
+ "res2": self.num_features[0],
738
+ "res3": self.num_features[1],
739
+ "res4": self.num_features[2],
740
+ "res5": self.num_features[3],
741
+ }
742
+
743
+ def forward(self, x):
744
+ """
745
+ Args:
746
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
747
+ Returns:
748
+ dict[str->Tensor]: names and the corresponding features
749
+ """
750
+ assert (
751
+ x.dim() == 4
752
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
753
+ outputs = {}
754
+ y = super().forward(x)
755
+ for k in y.keys():
756
+ if k in self._out_features:
757
+ outputs[k] = y[k]
758
+ return outputs
759
+
760
+ def output_shape(self):
761
+ return {
762
+ name: ShapeSpec(
763
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
764
+ )
765
+ for name in self._out_features
766
+ }
767
+
768
+ @property
769
+ def size_divisibility(self):
770
+ return 32
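The `ape` branch in `SwinTransformer.forward` above resizes the learned absolute position embedding to the current patch grid before adding it to the patch tokens. Below is a minimal sketch of just that resizing-and-flattening step; it assumes only PyTorch, and the batch size, embedding dimension, and grid sizes are made-up placeholders, not values taken from any config in this commit.

```python
# Sketch of the absolute-position-embedding resize used in the `ape` branch above.
# All sizes here are illustrative placeholders.
import torch
import torch.nn.functional as F

B, C = 2, 96                                     # hypothetical batch size and embed dim
absolute_pos_embed = torch.randn(1, C, 56, 56)   # grid the embedding was trained on
Wh, Ww = 64, 48                                  # patch grid of the current input
x = torch.randn(B, C, Wh, Ww)                    # stand-in for the patch_embed output

# Interpolate the embedding to the current grid, add it, then flatten to tokens.
pos = F.interpolate(absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
tokens = (x + pos).flatten(2).transpose(1, 2)    # B x (Wh*Ww) x C
print(tokens.shape)                              # torch.Size([2, 3072, 96])
```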
mask2former/modeling/criterion.py ADDED
@@ -0,0 +1,547 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ """
4
+ MaskFormer criterion.
5
+ """
6
+ import logging
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ from detectron2.utils.comm import get_world_size
13
+ from detectron2.projects.point_rend.point_features import (
14
+ get_uncertain_point_coords_with_randomness,
15
+ point_sample,
16
+ )
17
+
18
+ from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list, _max_by_axis
19
+ from ..utils.tranform import matrix_to_quaternion, quaternion_to_matrix
20
+
21
+ def dice_loss(
22
+ inputs: torch.Tensor,
23
+ targets: torch.Tensor,
24
+ num_masks: float,
25
+ ):
26
+ """
27
+ Compute the DICE loss, similar to generalized IOU for masks
28
+ Args:
29
+ inputs: A float tensor of arbitrary shape.
30
+ The predictions for each example.
31
+ targets: A float tensor with the same shape as inputs. Stores the binary
32
+ classification label for each element in inputs
33
+ (0 for the negative class and 1 for the positive class).
34
+ """
35
+ inputs = inputs.sigmoid()
36
+ inputs = inputs.flatten(1)
37
+ numerator = 2 * (inputs * targets).sum(-1)
38
+ denominator = inputs.sum(-1) + targets.sum(-1)
39
+ loss = 1 - (numerator + 1) / (denominator + 1)
40
+ return loss.sum() / num_masks
41
+
42
+
43
+ dice_loss_jit = torch.jit.script(
44
+ dice_loss
45
+ ) # type: torch.jit.ScriptModule
46
+
47
+
48
+ def sigmoid_ce_loss(
49
+ inputs: torch.Tensor,
50
+ targets: torch.Tensor,
51
+ num_masks: float,
52
+ ):
53
+ """
54
+ Args:
55
+ inputs: A float tensor of arbitrary shape.
56
+ The predictions for each example.
57
+ targets: A float tensor with the same shape as inputs. Stores the binary
58
+ classification label for each element in inputs
59
+ (0 for the negative class and 1 for the positive class).
60
+ Returns:
61
+ Loss tensor
62
+ """
63
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
64
+
65
+ return loss.mean(1).sum() / num_masks
66
+
67
+
68
+ sigmoid_ce_loss_jit = torch.jit.script(
69
+ sigmoid_ce_loss
70
+ ) # type: torch.jit.ScriptModule
71
+
72
+
73
+ def calculate_uncertainty(logits):
74
+ """
75
+ We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
76
+ foreground class in `classes`.
77
+ Args:
78
+ logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
79
+ class-agnostic, where R is the total number of predicted masks in all images and C is
80
+ the number of foreground classes. The values are logits.
81
+ Returns:
82
+ scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
83
+ the most uncertain locations having the highest uncertainty score.
84
+ """
85
+ assert logits.shape[1] == 1
86
+ gt_class_logits = logits.clone()
87
+ return -(torch.abs(gt_class_logits))
88
+
89
+ def convert_to_filled_tensor(tensor_list):
90
+ max_size = _max_by_axis([list(tensor.shape) for tensor in tensor_list])
91
+ batch_shape = [len(tensor_list)] + max_size
92
+ dtype = tensor_list[0].dtype
93
+ device = tensor_list[0].device
94
+ filled_tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
95
+ for old, new in zip(tensor_list, filled_tensor):
96
+ new[:old.shape[0]] = old
97
+ return filled_tensor
98
+
99
+ def smooth_l1_loss(
100
+ input: torch.Tensor, target: torch.Tensor, beta: float, reduction: str = "none"
101
+ ) -> torch.Tensor:
102
+ """
103
+ Smooth L1 loss defined in the Fast R-CNN paper as:
104
+ ::
105
+ | 0.5 * x ** 2 / beta if abs(x) < beta
106
+ smoothl1(x) = |
107
+ | abs(x) - 0.5 * beta otherwise,
108
+
109
+ where x = input - target.
110
+
111
+ Smooth L1 loss is related to Huber loss, which is defined as:
112
+ ::
113
+ | 0.5 * x ** 2 if abs(x) < beta
114
+ huber(x) = |
115
+ | beta * (abs(x) - 0.5 * beta) otherwise
116
+
117
+ Smooth L1 loss is equal to huber(x) / beta. This leads to the following
118
+ differences:
119
+
120
+ - As beta -> 0, Smooth L1 loss converges to L1 loss, while Huber loss
121
+ converges to a constant 0 loss.
122
+ - As beta -> +inf, Smooth L1 converges to a constant 0 loss, while Huber loss
123
+ converges to L2 loss.
124
+ - For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant
125
+ slope of 1. For Huber loss, the slope of the L1 segment is beta.
126
+
127
+ Smooth L1 loss can be seen as exactly L1 loss, but with the abs(x) < beta
128
+ portion replaced with a quadratic function such that at abs(x) = beta, its
129
+ slope is 1. The quadratic segment smooths the L1 loss near x = 0.
130
+
131
+ Args:
132
+ input (Tensor): input tensor of any shape
133
+ target (Tensor): target value tensor with the same shape as input
134
+ beta (float): L1 to L2 change point.
135
+ For beta values < 1e-5, L1 loss is computed.
136
+ reduction: 'none' | 'mean' | 'sum'
137
+ 'none': No reduction will be applied to the output.
138
+ 'mean': The output will be averaged.
139
+ 'sum': The output will be summed.
140
+
141
+ Returns:
142
+ The loss with the reduction option applied.
143
+
144
+ Note:
145
+ PyTorch's builtin "Smooth L1 loss" implementation does not actually
146
+ implement Smooth L1 loss, nor does it implement Huber loss. It implements
147
+ the special case of both in which they are equal (beta=1).
148
+ See: https://pytorch.org/docs/stable/nn.html#torch.nn.SmoothL1Loss.
149
+ """
150
+ if beta < 1e-5:
151
+ # if beta == 0, then torch.where will result in nan gradients when
152
+ # the chain rule is applied due to pytorch implementation details
153
+ # (the False branch "0.5 * n ** 2 / 0" has an incoming gradient of
154
+ # zeros, rather than "no gradient"). To avoid this issue, we define
155
+ # small values of beta to be exactly l1 loss.
156
+ loss = torch.abs(input - target)
157
+ else:
158
+ n = torch.abs(input - target)
159
+ cond = n < beta
160
+ loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
161
+
162
+ if reduction == "mean":
163
+ loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
164
+ elif reduction == "sum":
165
+ loss = loss.sum()
166
+ return loss
167
+
168
+ class SetCriterion(nn.Module):
169
+ """This class computes the loss for DETR.
170
+ The process happens in two steps:
171
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
172
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
173
+ """
174
+
175
+ def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
176
+ num_points, oversample_ratio, importance_sample_ratio, motionnet_type, only_DET):
177
+ """Create the criterion.
178
+ Parameters:
179
+ num_classes: number of object categories, omitting the special no-object category
180
+ matcher: module able to compute a matching between targets and proposals
181
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
182
+ eos_coef: relative classification weight applied to the no-object category
183
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
184
+ """
185
+ super().__init__()
186
+ self.num_classes = num_classes
187
+ self.matcher = matcher
188
+ self.weight_dict = weight_dict
189
+ self.eos_coef = eos_coef
190
+ self.losses = losses
191
+ empty_weight = torch.ones(self.num_classes + 1)
192
+ empty_weight[-1] = self.eos_coef
193
+ self.register_buffer("empty_weight", empty_weight)
194
+
195
+ # pointwise mask loss parameters
196
+ self.num_points = num_points
197
+ self.oversample_ratio = oversample_ratio
198
+ self.importance_sample_ratio = importance_sample_ratio
199
+
200
+ # OPD
201
+ self.motionnet_type = motionnet_type
202
+ self.only_DET = only_DET
203
+
204
+ def loss_labels(self, outputs, targets, indices, num_masks):
205
+ """Classification loss (NLL)
206
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
207
+ """
208
+ assert "pred_logits" in outputs
209
+ src_logits = outputs["pred_logits"].float()
210
+
211
+ idx = self._get_src_permutation_idx(indices)
212
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
213
+ target_classes = torch.full(
214
+ src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
215
+ )
216
+ target_classes[idx] = target_classes_o
217
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
218
+ losses = {"loss_ce": loss_ce}
219
+ return losses
220
+
221
+ # OPD
222
+ def loss_mtypes(self, outputs, targets, indices, num_masks):
223
+ assert "pred_mtypes" in outputs
224
+
225
+ src_idx = self._get_src_permutation_idx(indices)
226
+ tgt_idx = self._get_tgt_permutation_idx(indices)
227
+
228
+ target_motion_valid = convert_to_filled_tensor([t["gt_motion_valids"] for t in targets])[tgt_idx]
229
+ src_mtypes = outputs["pred_mtypes"][src_idx][target_motion_valid]
230
+ target_mtypes = convert_to_filled_tensor([t["gt_types"] for t in targets])[tgt_idx][target_motion_valid]
231
+
232
+ if src_mtypes.shape[0] == 0:
233
+ return {"loss_mtype": 0.0 * src_mtypes.sum()}
234
+
235
+ loss_mtype = F.cross_entropy(src_mtypes, target_mtypes.long(), reduction="sum") / num_masks
236
+ losses = {"loss_mtype": loss_mtype}
237
+ return losses
238
+
239
+ def loss_morigins(self, outputs, targets, indices, num_masks):
240
+ assert "pred_morigins" in outputs
241
+
242
+ src_idx = self._get_src_permutation_idx(indices)
243
+ tgt_idx = self._get_tgt_permutation_idx(indices)
244
+
245
+ target_motion_valid = convert_to_filled_tensor([t["gt_motion_valids"] for t in targets])[tgt_idx]
246
+ # Only calculate origin loss for the rotation axis
247
+ target_mtypes = convert_to_filled_tensor([t["gt_types"] for t in targets])[tgt_idx][target_motion_valid]
248
+ rot_inds = (
249
+ (target_mtypes == 0).nonzero().unbind(1)[0]
250
+ )
251
+ src_morigins = outputs["pred_morigins"][src_idx][target_motion_valid][rot_inds]
252
+ target_morigins = convert_to_filled_tensor([t["gt_origins"] for t in targets])[tgt_idx][target_motion_valid][rot_inds]
253
+
254
+ if src_morigins.shape[0] == 0:
255
+ return {"loss_morigin": 0.0 * src_morigins.sum()}
256
+
257
+ loss_morigin = smooth_l1_loss(src_morigins, target_morigins, 1.0, reduction="sum") / num_masks
258
+ losses = {"loss_morigin": loss_morigin}
259
+ return losses
260
+
261
+ def loss_maxises(self, outputs, targets, indices, num_masks):
262
+ assert "pred_maxises" in outputs
263
+
264
+ src_idx = self._get_src_permutation_idx(indices)
265
+ tgt_idx = self._get_tgt_permutation_idx(indices)
266
+
267
+ target_motion_valid = convert_to_filled_tensor([t["gt_motion_valids"] for t in targets])[tgt_idx]
268
+ src_maxises = outputs["pred_maxises"][src_idx][target_motion_valid]
269
+ target_maxises = convert_to_filled_tensor([t["gt_axises"] for t in targets])[tgt_idx][target_motion_valid]
270
+
271
+ if src_maxises.shape[0] == 0:
272
+ return {"loss_maxis": 0.0 * src_maxises.sum()}
273
+
274
+ loss_maxis = smooth_l1_loss(src_maxises, target_maxises, 1.0, reduction="sum") / num_masks
275
+ losses = {"loss_maxis": loss_maxis}
276
+ return losses
277
+
278
+ #TODO: add loss for motion state
279
+ def loss_mstates(self, outputs, targets, indices, num_masks):
280
+ assert "pred_mstates" in outputs
281
+
282
+ src_idx = self._get_src_permutation_idx(indices)
283
+ tgt_idx = self._get_tgt_permutation_idx(indices)
284
+
285
+ target_motion_valid = convert_to_filled_tensor([t["gt_motion_valids"] for t in targets])[tgt_idx]
286
+ src_mstate = outputs["pred_mstates"][src_idx][target_motion_valid]
287
+ target_mstate = convert_to_filled_tensor([t["gt_states"] for t in targets])[tgt_idx][target_motion_valid]
288
+
289
+ if src_mstate.shape[0] == 0:
290
+ return {"loss_mstate": 0.0 * src_mstate.sum()}
291
+
292
+ loss_mstate = smooth_l1_loss(src_mstate, target_mstate, 1.0, reduction="sum") / num_masks
293
+ losses = {"loss_mstate": loss_mstate}
294
+ return losses
295
+
296
+ def loss_mstatemaxs(self, outputs, targets, indices, num_masks):
297
+ assert "pred_mstatemaxs" in outputs
298
+
299
+ src_idx = self._get_src_permutation_idx(indices)
300
+ tgt_idx = self._get_tgt_permutation_idx(indices)
301
+
302
+ target_motion_valid = convert_to_filled_tensor([t["gt_motion_valids"] for t in targets])[tgt_idx]
303
+ src_mstatemax = outputs["pred_mstatemaxs"][src_idx][target_motion_valid]
304
+ target_mstatemax = convert_to_filled_tensor([t["gt_statemaxs"] for t in targets])[tgt_idx][target_motion_valid]
305
+
306
+ if src_mstatemax.shape[0] == 0:
307
+ return {"loss_mstatemax": 0.0 * src_mstatemax.sum()}
308
+
309
+ loss_mstatemax = smooth_l1_loss(src_mstatemax, target_mstatemax, 1.0, reduction="sum") / num_masks
310
+ losses = {"loss_mstatemax": loss_mstatemax}
311
+ return losses
312
+
313
+ def loss_extrinsics(self, outputs, targets, indices, num_masks):
314
+ assert "pred_extrinsics" in outputs
315
+ if self.motionnet_type == "BMOC_V0" or self.motionnet_type == "BMOC_V6":
316
+ target_motion_valid = torch.tensor([t["gt_motion_valids"][0] for t in targets], device=outputs["pred_extrinsics"].device)
317
+ src_extrinsics = outputs["pred_extrinsics"][target_motion_valid]
318
+ target_extrinsics_full = [t["gt_extrinsic"][0] for t in targets]
319
+ target_extrinsics = convert_to_filled_tensor([torch.cat(
320
+ [
321
+ extrinsic[0:3],
322
+ extrinsic[4:7],
323
+ extrinsic[8:11],
324
+ extrinsic[12:15],
325
+ ],
326
+ 0,
327
+ ) for extrinsic in target_extrinsics_full])[target_motion_valid]
328
+ if src_extrinsics.shape[0] == 0:
329
+ return {"loss_extrinsic": 0.0 * src_extrinsics.sum()}
330
+
331
+ # To make sure each valid image gives the same contribution to the loss,
332
+ # we average over the number of images here
333
+ loss_extrinsic = smooth_l1_loss(src_extrinsics, target_extrinsics, 1.0, reduction="sum") / outputs["pred_extrinsics"].shape[0]
334
+ elif self.motionnet_type == "BMOC_V1":
335
+ src_idx = self._get_src_permutation_idx(indices)
336
+ tgt_idx = self._get_tgt_permutation_idx(indices)
337
+
338
+ target_motion_valid = convert_to_filled_tensor([t["gt_motion_valids"] for t in targets])[tgt_idx]
339
+ src_extrinsics = outputs["pred_extrinsics"][src_idx][target_motion_valid]
340
+ target_extrinsics_full = []
341
+ for t in targets:
342
+ extrinsics = t["gt_extrinsic"]
343
+ target_extrinsics_full.append(torch.cat(
344
+ [
345
+ extrinsics[:, 0:3],
346
+ extrinsics[:, 4:7],
347
+ extrinsics[:, 8:11],
348
+ extrinsics[:, 12:15],
349
+ ],
350
+ 1,
351
+ ))
352
+
353
+ target_extrinsics = convert_to_filled_tensor(target_extrinsics_full)[tgt_idx][target_motion_valid]
354
+ if src_extrinsics.shape[0] == 0:
355
+ return {"loss_extrinsic": 0.0 * src_extrinsics.sum()}
356
+
357
+ # Make sure each valid prediction gives the same contribution to the loss;
358
+ # here the loss is normalized by the number of matched masks
359
+ loss_extrinsic = smooth_l1_loss(src_extrinsics, target_extrinsics, 1.0, reduction="sum") / num_masks
360
+ elif self.motionnet_type == "BMOC_V2":
361
+ src_idx = self._get_src_permutation_idx(indices)
362
+ tgt_idx = self._get_tgt_permutation_idx(indices)
363
+
364
+ target_motion_valid = convert_to_filled_tensor([t["gt_motion_valids"] for t in targets])[tgt_idx]
365
+ src_extrinsics = outputs["pred_extrinsics"][src_idx][target_motion_valid]
366
+ target_extrinsics = convert_to_filled_tensor([t["gt_extrinsic_quaternion"] for t in targets])[tgt_idx][target_motion_valid]
367
+
368
+ if src_extrinsics.shape[0] == 0:
369
+ return {"loss_extrinsic": 0.0 * src_extrinsics.sum()}
370
+
371
+ # Make sure each valid prediction gives the same contribution to the loss;
372
+ # here the loss is normalized by the number of matched masks
373
+ loss_extrinsic = smooth_l1_loss(src_extrinsics, target_extrinsics, 1.0, reduction="sum") / num_masks
374
+ elif self.motionnet_type == "BMOC_V3":
375
+ src_idx = self._get_src_permutation_idx(indices)
376
+ tgt_idx = self._get_tgt_permutation_idx(indices)
377
+
378
+ target_motion_valid = convert_to_filled_tensor([t["gt_motion_valids"] for t in targets])[tgt_idx]
379
+ src_extrinsics = outputs["pred_extrinsics"][src_idx][target_motion_valid]
380
+ target_extrinsics = convert_to_filled_tensor([t["gt_extrinsic_6d"] for t in targets])[tgt_idx][target_motion_valid]
381
+
382
+ if src_extrinsics.shape[0] == 0:
383
+ return {"loss_extrinsic": 0.0 * src_extrinsics.sum()}
384
+
385
+ # Make sure each valid prediction gives the same contribution to the loss;
386
+ # here the loss is normalized by the number of matched masks
387
+ loss_extrinsic = smooth_l1_loss(src_extrinsics, target_extrinsics, 1.0, reduction="sum") / num_masks
388
+ elif self.motionnet_type == "BMOC_V4" or self.motionnet_type == "BMOC_V5":
389
+ target_motion_valid = torch.tensor([t["gt_motion_valids"][0] for t in targets], device=outputs["pred_extrinsics"].device)
390
+ src_extrinsics = outputs["pred_extrinsics"][target_motion_valid]
391
+ target_extrinsics = convert_to_filled_tensor([t["gt_extrinsic_quaternion"][0] for t in targets])[target_motion_valid]
392
+
393
+ if src_extrinsics.shape[0] == 0:
394
+ return {"loss_extrinsic": 0.0 * src_extrinsics.sum()}
395
+
396
+ # To make sure each valid image gives the same contribution to the loss,
397
+ # we average over the number of images here
398
+ loss_extrinsic = smooth_l1_loss(src_extrinsics, target_extrinsics, 1.0, reduction="sum") / outputs["pred_extrinsics"].shape[0]
399
+
400
+ return {"loss_extrinsic": loss_extrinsic}
401
+
402
+
403
+
404
+ def loss_masks(self, outputs, targets, indices, num_masks):
405
+ """Compute the losses related to the masks: the focal loss and the dice loss.
406
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
407
+ """
408
+ assert "pred_masks" in outputs
409
+
410
+ src_idx = self._get_src_permutation_idx(indices)
411
+ tgt_idx = self._get_tgt_permutation_idx(indices)
412
+ src_masks = outputs["pred_masks"]
413
+ src_masks = src_masks[src_idx]
414
+ masks = [t["masks"] for t in targets]
415
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
416
+ target_masks = target_masks.to(src_masks)
417
+ target_masks = target_masks[tgt_idx]
418
+
419
+ # No need to upsample predictions as we are using normalized coordinates :)
420
+ # N x 1 x H x W
421
+ src_masks = src_masks[:, None]
422
+ target_masks = target_masks[:, None]
423
+
424
+ with torch.no_grad():
425
+ # sample point_coords
426
+ point_coords = get_uncertain_point_coords_with_randomness(
427
+ src_masks,
428
+ lambda logits: calculate_uncertainty(logits),
429
+ self.num_points,
430
+ self.oversample_ratio,
431
+ self.importance_sample_ratio,
432
+ )
433
+ # get gt labels
434
+ point_labels = point_sample(
435
+ target_masks,
436
+ point_coords,
437
+ align_corners=False,
438
+ ).squeeze(1)
439
+
440
+ point_logits = point_sample(
441
+ src_masks,
442
+ point_coords,
443
+ align_corners=False,
444
+ ).squeeze(1)
445
+
446
+ losses = {
447
+ "loss_mask": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
448
+ "loss_dice": dice_loss_jit(point_logits, point_labels, num_masks),
449
+ }
450
+
451
+ del src_masks
452
+ del target_masks
453
+ return losses
454
+
455
+ def _get_src_permutation_idx(self, indices):
456
+ # permute predictions following indices
457
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
458
+ src_idx = torch.cat([src for (src, _) in indices])
459
+ return batch_idx, src_idx
460
+
461
+ def _get_tgt_permutation_idx(self, indices):
462
+ # permute targets following indices
463
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
464
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
465
+ return batch_idx, tgt_idx
466
+
467
+ def get_loss(self, loss, outputs, targets, indices, num_masks):
468
+ tmp_device = outputs["pred_logits"].device
469
+ tmp_list = ["mtypes", "morigins", "maxises"]
470
+ loss_map = {
471
+ 'labels': self.loss_labels,
472
+ 'masks': self.loss_masks,
473
+ # OPD
474
+ "mtypes": self.loss_mtypes,
475
+ "morigins": self.loss_morigins,
476
+ "maxises": self.loss_maxises,
477
+ "extrinsics": self.loss_extrinsics,
478
+ "mstates": self.loss_mstates,
479
+ "mstatemaxs": self.loss_mstatemaxs,
480
+ }
481
+ assert loss in loss_map, f"do you really want to compute {loss} loss?"
482
+ tmp_loss = loss_map[loss](outputs, targets, indices, num_masks)
483
+ if self.only_DET and loss in tmp_list:
484
+ tmp_key = list(tmp_loss.keys())[0]
485
+ tmp_loss[tmp_key] = torch.tensor(0.0, device=tmp_device)
486
+ return tmp_loss
487
+ else:
488
+ return tmp_loss
489
+ # return loss_map[loss](outputs, targets, indices, num_masks)
490
+
491
+ def forward(self, outputs, targets):
492
+ """This performs the loss computation.
493
+ Parameters:
494
+ outputs: dict of tensors, see the output specification of the model for the format
495
+ targets: list of dicts, such that len(targets) == batch_size.
496
+ The expected keys in each dict depends on the losses applied, see each loss' doc
497
+ """
498
+ tmp_device = outputs["pred_logits"].device
499
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
500
+
501
+ # Retrieve the matching between the outputs of the last layer and the targets
502
+ indices = self.matcher(outputs_without_aux, targets)
503
+
504
+ # Compute the average number of target boxes across all nodes, for normalization purposes
505
+ num_masks = sum(len(t["labels"]) for t in targets)
506
+ num_masks = torch.as_tensor(
507
+ [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
508
+ )
509
+ if is_dist_avail_and_initialized():
510
+ torch.distributed.all_reduce(num_masks)
511
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
512
+
513
+ # Compute all the requested losses
514
+ losses = {}
515
+ for loss in self.losses:
516
+ if loss == "extrinsics" and self.motionnet_type == "BMCC":
517
+ continue
518
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
519
+
520
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
521
+ if "aux_outputs" in outputs:
522
+ for i, aux_outputs in enumerate(outputs["aux_outputs"]):
523
+ indices = self.matcher(aux_outputs, targets)
524
+ for loss in self.losses:
525
+ if loss == "extrinsics" and (self.motionnet_type == "BMOC_V0" or self.motionnet_type == "BMCC"):
526
+ continue
527
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
528
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
529
+ losses.update(l_dict)
530
+
531
+ return losses
532
+
533
+ def __repr__(self):
534
+ head = "Criterion " + self.__class__.__name__
535
+ body = [
536
+ "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
537
+ "losses: {}".format(self.losses),
538
+ "weight_dict: {}".format(self.weight_dict),
539
+ "num_classes: {}".format(self.num_classes),
540
+ "eos_coef: {}".format(self.eos_coef),
541
+ "num_points: {}".format(self.num_points),
542
+ "oversample_ratio: {}".format(self.oversample_ratio),
543
+ "importance_sample_ratio: {}".format(self.importance_sample_ratio),
544
+ ]
545
+ _repr_indent = 4
546
+ lines = [head] + [" " * _repr_indent + line for line in body]
547
+ return "\n".join(lines)
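The long docstring of `smooth_l1_loss` above gives the piecewise definition and contrasts it with Huber loss. As a quick numeric illustration of that piecewise behavior, here is a minimal sketch that assumes only PyTorch; the `diff` values and `beta` are arbitrary examples, not values used by the criterion.

```python
# Numeric check of the piecewise smooth L1 definition documented above.
import torch

def smooth_l1_pointwise(x: torch.Tensor, beta: float) -> torch.Tensor:
    n = x.abs()
    # quadratic inside |x| < beta, linear with slope 1 outside
    return torch.where(n < beta, 0.5 * n ** 2 / beta, n - 0.5 * beta)

diff = torch.tensor([0.1, 0.5, 2.0])         # example input - target differences
print(smooth_l1_pointwise(diff, beta=1.0))   # tensor([0.0050, 0.1250, 1.5000])
```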
mask2former/modeling/matcher.py ADDED
@@ -0,0 +1,192 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
3
+ """
4
+ Modules to compute the matching cost and solve the corresponding LSAP.
5
+ """
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from scipy.optimize import linear_sum_assignment
9
+ from torch import nn
10
+ from torch.cuda.amp import autocast
11
+
12
+ from detectron2.projects.point_rend.point_features import point_sample
13
+
14
+
15
+ def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
16
+ """
17
+ Compute the DICE loss, similar to generalized IOU for masks
18
+ Args:
19
+ inputs: A float tensor of arbitrary shape.
20
+ The predictions for each example.
21
+ targets: A float tensor with the same shape as inputs. Stores the binary
22
+ classification label for each element in inputs
23
+ (0 for the negative class and 1 for the positive class).
24
+ """
25
+ inputs = inputs.sigmoid()
26
+ inputs = inputs.flatten(1)
27
+ numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
28
+ denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
29
+ loss = 1 - (numerator + 1) / (denominator + 1)
30
+ return loss
31
+
32
+
33
+ batch_dice_loss_jit = torch.jit.script(
34
+ batch_dice_loss
35
+ ) # type: torch.jit.ScriptModule
36
+
37
+
38
+ def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
39
+ """
40
+ Args:
41
+ inputs: A float tensor of arbitrary shape.
42
+ The predictions for each example.
43
+ targets: A float tensor with the same shape as inputs. Stores the binary
44
+ classification label for each element in inputs
45
+ (0 for the negative class and 1 for the positive class).
46
+ Returns:
47
+ Loss tensor
48
+ """
49
+ hw = inputs.shape[1]
50
+
51
+ pos = F.binary_cross_entropy_with_logits(
52
+ inputs, torch.ones_like(inputs), reduction="none"
53
+ )
54
+ neg = F.binary_cross_entropy_with_logits(
55
+ inputs, torch.zeros_like(inputs), reduction="none"
56
+ )
57
+
58
+ loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
59
+ "nc,mc->nm", neg, (1 - targets)
60
+ )
61
+
62
+ return loss / hw
63
+
64
+
65
+ batch_sigmoid_ce_loss_jit = torch.jit.script(
66
+ batch_sigmoid_ce_loss
67
+ ) # type: torch.jit.ScriptModule
68
+
69
+
70
+ class HungarianMatcher(nn.Module):
71
+ """This class computes an assignment between the targets and the predictions of the network
72
+
73
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
74
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
75
+ while the others are un-matched (and thus treated as non-objects).
76
+ """
77
+
78
+ def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0):
79
+ """Creates the matcher
80
+
81
+ Params:
82
+ cost_class: This is the relative weight of the classification error in the matching cost
83
+ cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
84
+ cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
85
+ """
86
+ super().__init__()
87
+ self.cost_class = cost_class
88
+ self.cost_mask = cost_mask
89
+ self.cost_dice = cost_dice
90
+
91
+ assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs can't be 0"
92
+
93
+ self.num_points = num_points
94
+
95
+ @torch.no_grad()
96
+ def memory_efficient_forward(self, outputs, targets):
97
+ """More memory-friendly matching"""
98
+ bs, num_queries = outputs["pred_logits"].shape[:2]
99
+
100
+ indices = []
101
+
102
+ # Iterate through batch size
103
+ for b in range(bs):
104
+
105
+ out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes]
106
+ tgt_ids = targets[b]["labels"]
107
+
108
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
109
+ # but approximate it in 1 - proba[target class].
110
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
111
+ cost_class = -out_prob[:, tgt_ids]
112
+
113
+ out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
114
+ # gt masks are already padded when preparing target
115
+ tgt_mask = targets[b]["masks"].to(out_mask)
116
+
117
+ out_mask = out_mask[:, None]
118
+ tgt_mask = tgt_mask[:, None]
119
+ # all masks share the same set of points for efficient matching!
120
+ point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device)
121
+ # get gt labels
122
+ tgt_mask = point_sample(
123
+ tgt_mask,
124
+ point_coords.repeat(tgt_mask.shape[0], 1, 1),
125
+ align_corners=False,
126
+ ).squeeze(1)
127
+
128
+ out_mask = point_sample(
129
+ out_mask,
130
+ point_coords.repeat(out_mask.shape[0], 1, 1),
131
+ align_corners=False,
132
+ ).squeeze(1)
133
+
134
+ with autocast(enabled=False):
135
+ out_mask = out_mask.float()
136
+ tgt_mask = tgt_mask.float()
137
+ # Compute the focal loss between masks
138
+ if out_mask.shape[0] == 0 or tgt_mask.shape[0] == 0:
139
+ cost_mask = batch_sigmoid_ce_loss(out_mask, tgt_mask)
140
+ # Compute the dice loss betwen masks
141
+ cost_dice = batch_dice_loss(out_mask, tgt_mask)
142
+ else:
143
+ cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
144
+ # Compute the dice loss betwen masks
145
+ cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
146
+ # Final cost matrix
147
+ C = (
148
+ self.cost_mask * cost_mask
149
+ + self.cost_class * cost_class
150
+ + self.cost_dice * cost_dice
151
+ )
152
+ C = C.reshape(num_queries, -1).cpu()
153
+
154
+ indices.append(linear_sum_assignment(C))
155
+
156
+ return [
157
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
158
+ for i, j in indices
159
+ ]
160
+
161
+ @torch.no_grad()
162
+ def forward(self, outputs, targets):
163
+ """Performs the matching
164
+
165
+ Params:
166
+ outputs: This is a dict that contains at least these entries:
167
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
168
+ "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
169
+
170
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
171
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
172
+ objects in the target) containing the class labels
173
+ "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
174
+
175
+ Returns:
176
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
177
+ - index_i is the indices of the selected predictions (in order)
178
+ - index_j is the indices of the corresponding selected targets (in order)
179
+ For each batch element, it holds:
180
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
181
+ """
182
+ return self.memory_efficient_forward(outputs, targets)
183
+
184
+ def __repr__(self, _repr_indent=4):
185
+ head = "Matcher " + self.__class__.__name__
186
+ body = [
187
+ "cost_class: {}".format(self.cost_class),
188
+ "cost_mask: {}".format(self.cost_mask),
189
+ "cost_dice: {}".format(self.cost_dice),
190
+ ]
191
+ lines = [head] + [" " * _repr_indent + line for line in body]
192
+ return "\n".join(lines)
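`HungarianMatcher.memory_efficient_forward` above reduces matching to a per-image cost matrix (classification, mask BCE, and dice terms) that is handed to `scipy.optimize.linear_sum_assignment`. The following toy sketch shows only that final assignment step on a hand-written cost matrix; the matrix values are made up and do not come from the model.

```python
# Toy illustration of the assignment step at the end of memory_efficient_forward.
import torch
from scipy.optimize import linear_sum_assignment

# Hypothetical cost matrix C of shape [num_queries, num_targets].
C = torch.tensor([
    [0.9, 0.1],   # query 0 is cheap to match to target 1
    [0.2, 0.8],   # query 1 is cheap to match to target 0
    [0.7, 0.6],
    [0.5, 0.4],
])

src_idx, tgt_idx = linear_sum_assignment(C.numpy())
# One prediction per target; the remaining queries are treated as "no object".
print(list(zip(src_idx.tolist(), tgt_idx.tolist())))   # [(0, 1), (1, 0)]
```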
mask2former/modeling/meta_arch/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
mask2former/modeling/meta_arch/mask_former_head.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ from copy import deepcopy
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import fvcore.nn.weight_init as weight_init
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
12
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
13
+
14
+ from ..transformer_decoder.maskformer_transformer_decoder import build_transformer_decoder
15
+ from ..pixel_decoder.fpn import build_pixel_decoder
16
+
17
+
18
+ @SEM_SEG_HEADS_REGISTRY.register()
19
+ class MaskFormerHead(nn.Module):
20
+
21
+ _version = 2
22
+
23
+ def _load_from_state_dict(
24
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
25
+ ):
26
+ version = local_metadata.get("version", None)
27
+ if version is None or version < 2:
28
+ # Do not warn if train from scratch
29
+ scratch = True
30
+ logger = logging.getLogger(__name__)
31
+ for k in list(state_dict.keys()):
32
+ newk = k
33
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
34
+ newk = k.replace(prefix, prefix + "pixel_decoder.")
35
+ # logger.debug(f"{k} ==> {newk}")
36
+ if newk != k:
37
+ state_dict[newk] = state_dict[k]
38
+ del state_dict[k]
39
+ scratch = False
40
+
41
+ if not scratch:
42
+ logger.warning(
43
+ f"Weight format of {self.__class__.__name__} have changed! "
44
+ "Please upgrade your models. Applying automatic conversion now ..."
45
+ )
46
+
47
+ @configurable
48
+ def __init__(
49
+ self,
50
+ input_shape: Dict[str, ShapeSpec],
51
+ *,
52
+ num_classes: int,
53
+ pixel_decoder: nn.Module,
54
+ loss_weight: float = 1.0,
55
+ ignore_value: int = -1,
56
+ # extra parameters
57
+ transformer_predictor: nn.Module,
58
+ transformer_in_feature: str,
59
+ ):
60
+ """
61
+ NOTE: this interface is experimental.
62
+ Args:
63
+ input_shape: shapes (channels and stride) of the input features
64
+ num_classes: number of classes to predict
65
+ pixel_decoder: the pixel decoder module
66
+ loss_weight: loss weight
67
+ ignore_value: category id to be ignored during training.
68
+ transformer_predictor: the transformer decoder that makes prediction
69
+ transformer_in_feature: input feature name to the transformer_predictor
70
+ """
71
+ super().__init__()
72
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
73
+ self.in_features = [k for k, v in input_shape]
74
+ feature_strides = [v.stride for k, v in input_shape]
75
+ feature_channels = [v.channels for k, v in input_shape]
76
+
77
+ self.ignore_value = ignore_value
78
+ self.common_stride = 4
79
+ self.loss_weight = loss_weight
80
+
81
+ self.pixel_decoder = pixel_decoder
82
+ self.predictor = transformer_predictor
83
+ self.transformer_in_feature = transformer_in_feature
84
+
85
+ self.num_classes = num_classes
86
+
87
+ @classmethod
88
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
89
+ # figure out in_channels to transformer predictor
90
+ if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder":
91
+ transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
92
+ elif cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "pixel_embedding":
93
+ transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
94
+ elif cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "multi_scale_pixel_decoder": # for maskformer2
95
+ transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
96
+ else:
97
+ transformer_predictor_in_channels = input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels
98
+
99
+ return {
100
+ "input_shape": {
101
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
102
+ },
103
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
104
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
105
+ "pixel_decoder": build_pixel_decoder(cfg, input_shape),
106
+ "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
107
+ "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
108
+ "transformer_predictor": build_transformer_decoder(
109
+ cfg,
110
+ transformer_predictor_in_channels,
111
+ mask_classification=True,
112
+ ),
113
+ }
114
+
115
+ def forward(self, features, mask=None):
116
+ return self.layers(features, mask)
117
+
118
+ def layers(self, features, mask=None):
119
+ mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features)
120
+ if self.transformer_in_feature == "multi_scale_pixel_decoder":
121
+ # TODO: pass object mask prediction to this function
122
+ predictions = self.predictor(multi_scale_features, mask_features, mask)
123
+ else:
124
+ if self.transformer_in_feature == "transformer_encoder":
125
+ assert (
126
+ transformer_encoder_features is not None
127
+ ), "Please use the TransformerEncoderPixelDecoder."
128
+ predictions = self.predictor(transformer_encoder_features, mask_features, mask)
129
+ elif self.transformer_in_feature == "pixel_embedding":
130
+ predictions = self.predictor(mask_features, mask_features, mask)
131
+ else:
132
+ predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask)
133
+ return predictions
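`MaskFormerHead.__init__` above sorts the incoming feature dict by stride, so `self.in_features` always runs from the highest-resolution map (`res2`, stride 4) to the lowest (`res5`, stride 32). A small sketch of that ordering step, using a namedtuple as a stand-in for detectron2's `ShapeSpec` and illustrative channel counts:

```python
# Sketch of the stride-based ordering done in MaskFormerHead.__init__.
# SimpleShape stands in for detectron2's ShapeSpec; the numbers are illustrative.
from collections import namedtuple

SimpleShape = namedtuple("SimpleShape", ["channels", "stride"])

input_shape = {
    "res5": SimpleShape(channels=768, stride=32),
    "res3": SimpleShape(channels=192, stride=8),
    "res2": SimpleShape(channels=96, stride=4),
    "res4": SimpleShape(channels=384, stride=16),
}

ordered = sorted(input_shape.items(), key=lambda x: x[1].stride)
print([k for k, v in ordered])           # ['res2', 'res3', 'res4', 'res5']
print([v.channels for k, v in ordered])  # [96, 192, 384, 768]
```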
mask2former/modeling/meta_arch/per_pixel_baseline.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ from typing import Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import fvcore.nn.weight_init as weight_init
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
11
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
12
+
13
+ from ..transformer_decoder.maskformer_transformer_decoder import StandardTransformerDecoder
14
+ from ..pixel_decoder.fpn import build_pixel_decoder
15
+
16
+
17
+ @SEM_SEG_HEADS_REGISTRY.register()
18
+ class PerPixelBaselineHead(nn.Module):
19
+
20
+ _version = 2
21
+
22
+ def _load_from_state_dict(
23
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
24
+ ):
25
+ version = local_metadata.get("version", None)
26
+ if version is None or version < 2:
27
+ logger = logging.getLogger(__name__)
28
+ # Do not warn if train from scratch
29
+ scratch = True
30
+ logger = logging.getLogger(__name__)
31
+ for k in list(state_dict.keys()):
32
+ newk = k
33
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
34
+ newk = k.replace(prefix, prefix + "pixel_decoder.")
35
+ # logger.warning(f"{k} ==> {newk}")
36
+ if newk != k:
37
+ state_dict[newk] = state_dict[k]
38
+ del state_dict[k]
39
+ scratch = False
40
+
41
+ if not scratch:
42
+ logger.warning(
43
+ f"Weight format of {self.__class__.__name__} have changed! "
44
+ "Please upgrade your models. Applying automatic conversion now ..."
45
+ )
46
+
47
+ @configurable
48
+ def __init__(
49
+ self,
50
+ input_shape: Dict[str, ShapeSpec],
51
+ *,
52
+ num_classes: int,
53
+ pixel_decoder: nn.Module,
54
+ loss_weight: float = 1.0,
55
+ ignore_value: int = -1,
56
+ ):
57
+ """
58
+ NOTE: this interface is experimental.
59
+ Args:
60
+ input_shape: shapes (channels and stride) of the input features
61
+ num_classes: number of classes to predict
62
+ pixel_decoder: the pixel decoder module
63
+ loss_weight: loss weight
64
+ ignore_value: category id to be ignored during training.
65
+ """
66
+ super().__init__()
67
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
68
+ self.in_features = [k for k, v in input_shape]
69
+ feature_strides = [v.stride for k, v in input_shape]
70
+ feature_channels = [v.channels for k, v in input_shape]
71
+
72
+ self.ignore_value = ignore_value
73
+ self.common_stride = 4
74
+ self.loss_weight = loss_weight
75
+
76
+ self.pixel_decoder = pixel_decoder
77
+ self.predictor = Conv2d(
78
+ self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0
79
+ )
80
+ weight_init.c2_msra_fill(self.predictor)
81
+
82
+ @classmethod
83
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
84
+ return {
85
+ "input_shape": {
86
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
87
+ },
88
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
89
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
90
+ "pixel_decoder": build_pixel_decoder(cfg, input_shape),
91
+ "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
92
+ }
93
+
94
+ def forward(self, features, targets=None):
95
+ """
96
+ Returns:
97
+ In training, returns (None, dict of losses)
98
+ In inference, returns (CxHxW logits, {})
99
+ """
100
+ x = self.layers(features)
101
+ if self.training:
102
+ return None, self.losses(x, targets)
103
+ else:
104
+ x = F.interpolate(
105
+ x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
106
+ )
107
+ return x, {}
108
+
109
+ def layers(self, features):
110
+ x, _, _ = self.pixel_decoder.forward_features(features)
111
+ x = self.predictor(x)
112
+ return x
113
+
114
+ def losses(self, predictions, targets):
115
+ predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163
116
+ predictions = F.interpolate(
117
+ predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
118
+ )
119
+ loss = F.cross_entropy(
120
+ predictions, targets, reduction="mean", ignore_index=self.ignore_value
121
+ )
122
+ losses = {"loss_sem_seg": loss * self.loss_weight}
123
+ return losses
124
+
125
+
126
+ @SEM_SEG_HEADS_REGISTRY.register()
127
+ class PerPixelBaselinePlusHead(PerPixelBaselineHead):
128
+ def _load_from_state_dict(
129
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
130
+ ):
131
+ version = local_metadata.get("version", None)
132
+ if version is None or version < 2:
133
+ # Do not warn if train from scratch
134
+ scratch = True
135
+ logger = logging.getLogger(__name__)
136
+ for k in list(state_dict.keys()):
137
+ newk = k
138
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
139
+ newk = k.replace(prefix, prefix + "pixel_decoder.")
140
+ logger.debug(f"{k} ==> {newk}")
141
+ if newk != k:
142
+ state_dict[newk] = state_dict[k]
143
+ del state_dict[k]
144
+ scratch = False
145
+
146
+ if not scratch:
147
+ logger.warning(
148
+ f"Weight format of {self.__class__.__name__} have changed! "
149
+ "Please upgrade your models. Applying automatic conversion now ..."
150
+ )
151
+
152
+ @configurable
153
+ def __init__(
154
+ self,
155
+ input_shape: Dict[str, ShapeSpec],
156
+ *,
157
+ # extra parameters
158
+ transformer_predictor: nn.Module,
159
+ transformer_in_feature: str,
160
+ deep_supervision: bool,
161
+ # inherit parameters
162
+ num_classes: int,
163
+ pixel_decoder: nn.Module,
164
+ loss_weight: float = 1.0,
165
+ ignore_value: int = -1,
166
+ ):
167
+ """
168
+ NOTE: this interface is experimental.
169
+ Args:
170
+ input_shape: shapes (channels and stride) of the input features
171
+ transformer_predictor: the transformer decoder that makes prediction
172
+ transformer_in_feature: input feature name to the transformer_predictor
173
+ deep_supervision: whether or not to add supervision to the output of
174
+ every transformer decoder layer
175
+ num_classes: number of classes to predict
176
+ pixel_decoder: the pixel decoder module
177
+ loss_weight: loss weight
178
+ ignore_value: category id to be ignored during training.
179
+ """
180
+ super().__init__(
181
+ input_shape,
182
+ num_classes=num_classes,
183
+ pixel_decoder=pixel_decoder,
184
+ loss_weight=loss_weight,
185
+ ignore_value=ignore_value,
186
+ )
187
+
188
+ del self.predictor
189
+
190
+ self.predictor = transformer_predictor
191
+ self.transformer_in_feature = transformer_in_feature
192
+ self.deep_supervision = deep_supervision
193
+
194
+ @classmethod
195
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
196
+ ret = super().from_config(cfg, input_shape)
197
+ ret["transformer_in_feature"] = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE
198
+ if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder":
199
+ in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
200
+ else:
201
+ in_channels = input_shape[ret["transformer_in_feature"]].channels
202
+ ret["transformer_predictor"] = StandardTransformerDecoder(
203
+ cfg, in_channels, mask_classification=False
204
+ )
205
+ ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
206
+ return ret
207
+
208
+ def forward(self, features, targets=None):
209
+ """
210
+ Returns:
211
+ In training, returns (None, dict of losses)
212
+ In inference, returns (CxHxW logits, {})
213
+ """
214
+ x, aux_outputs = self.layers(features)
215
+ if self.training:
216
+ if self.deep_supervision:
217
+ losses = self.losses(x, targets)
218
+ for i, aux_output in enumerate(aux_outputs):
219
+ losses["loss_sem_seg" + f"_{i}"] = self.losses(
220
+ aux_output["pred_masks"], targets
221
+ )["loss_sem_seg"]
222
+ return None, losses
223
+ else:
224
+ return None, self.losses(x, targets)
225
+ else:
226
+ x = F.interpolate(
227
+ x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
228
+ )
229
+ return x, {}
230
+
231
+ def layers(self, features):
232
+ mask_features, transformer_encoder_features, _ = self.pixel_decoder.forward_features(features)
233
+ if self.transformer_in_feature == "transformer_encoder":
234
+ assert (
235
+ transformer_encoder_features is not None
236
+ ), "Please use the TransformerEncoderPixelDecoder."
237
+ predictions = self.predictor(transformer_encoder_features, mask_features)
238
+ else:
239
+ predictions = self.predictor(features[self.transformer_in_feature], mask_features)
240
+ if self.deep_supervision:
241
+ return predictions["pred_masks"], predictions["aux_outputs"]
242
+ else:
243
+ return predictions["pred_masks"], None
mask2former/modeling/pixel_decoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
mask2former/modeling/pixel_decoder/fpn.py ADDED
@@ -0,0 +1,312 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import numpy as np
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import fvcore.nn.weight_init as weight_init
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
11
+ from torch.cuda.amp import autocast
12
+
13
+ from detectron2.config import configurable
14
+ from detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm
15
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
16
+
17
+ from ..transformer_decoder.position_encoding import PositionEmbeddingSine
18
+ from ..transformer_decoder.transformer import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn
19
+
20
+
21
+ def build_pixel_decoder(cfg, input_shape):
22
+ """
23
+ Build a pixel decoder from `cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME`.
24
+ """
25
+ name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
26
+ model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
27
+ forward_features = getattr(model, "forward_features", None)
28
+ if not callable(forward_features):
29
+ raise ValueError(
30
+ "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
31
+ f"Please implement forward_features for {name} to only return mask features."
32
+ )
33
+ return model
34
+
35
+
36
+ # This is a modified FPN decoder.
37
+ @SEM_SEG_HEADS_REGISTRY.register()
38
+ class BasePixelDecoder(nn.Module):
39
+ @configurable
40
+ def __init__(
41
+ self,
42
+ input_shape: Dict[str, ShapeSpec],
43
+ *,
44
+ conv_dim: int,
45
+ mask_dim: int,
46
+ norm: Optional[Union[str, Callable]] = None,
47
+ ):
48
+ """
49
+ NOTE: this interface is experimental.
50
+ Args:
51
+ input_shape: shapes (channels and stride) of the input features
52
+ conv_dim: number of output channels for the intermediate conv layers.
53
+ mask_dim: number of output channels for the final conv layer.
54
+ norm (str or callable): normalization for all conv layers
55
+ """
56
+ super().__init__()
57
+
58
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
59
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
60
+ feature_channels = [v.channels for k, v in input_shape]
61
+
62
+ lateral_convs = []
63
+ output_convs = []
64
+
65
+ use_bias = norm == ""
66
+ for idx, in_channels in enumerate(feature_channels):
67
+ if idx == len(self.in_features) - 1:
68
+ output_norm = get_norm(norm, conv_dim)
69
+ output_conv = Conv2d(
70
+ in_channels,
71
+ conv_dim,
72
+ kernel_size=3,
73
+ stride=1,
74
+ padding=1,
75
+ bias=use_bias,
76
+ norm=output_norm,
77
+ activation=F.relu,
78
+ )
79
+ weight_init.c2_xavier_fill(output_conv)
80
+ self.add_module("layer_{}".format(idx + 1), output_conv)
81
+
82
+ lateral_convs.append(None)
83
+ output_convs.append(output_conv)
84
+ else:
85
+ lateral_norm = get_norm(norm, conv_dim)
86
+ output_norm = get_norm(norm, conv_dim)
87
+
88
+ lateral_conv = Conv2d(
89
+ in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
90
+ )
91
+ output_conv = Conv2d(
92
+ conv_dim,
93
+ conv_dim,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1,
97
+ bias=use_bias,
98
+ norm=output_norm,
99
+ activation=F.relu,
100
+ )
101
+ weight_init.c2_xavier_fill(lateral_conv)
102
+ weight_init.c2_xavier_fill(output_conv)
103
+ self.add_module("adapter_{}".format(idx + 1), lateral_conv)
104
+ self.add_module("layer_{}".format(idx + 1), output_conv)
105
+
106
+ lateral_convs.append(lateral_conv)
107
+ output_convs.append(output_conv)
108
+ # Place convs into top-down order (from low to high resolution)
109
+ # to make the top-down computation in forward clearer.
110
+ self.lateral_convs = lateral_convs[::-1]
111
+ self.output_convs = output_convs[::-1]
112
+
113
+ self.mask_dim = mask_dim
114
+ self.mask_features = Conv2d(
115
+ conv_dim,
116
+ mask_dim,
117
+ kernel_size=3,
118
+ stride=1,
119
+ padding=1,
120
+ )
121
+ weight_init.c2_xavier_fill(self.mask_features)
122
+
123
+ self.maskformer_num_feature_levels = 3 # always use 3 scales
124
+
125
+ @classmethod
126
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
127
+ ret = {}
128
+ ret["input_shape"] = {
129
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
130
+ }
131
+ ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
132
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
133
+ ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
134
+ return ret
135
+
136
+ def forward_features(self, features):
137
+ multi_scale_features = []
138
+ num_cur_levels = 0
139
+ # Reverse feature maps into top-down order (from low to high resolution)
140
+ for idx, f in enumerate(self.in_features[::-1]):
141
+ x = features[f]
142
+ lateral_conv = self.lateral_convs[idx]
143
+ output_conv = self.output_convs[idx]
144
+ if lateral_conv is None:
145
+ y = output_conv(x)
146
+ else:
147
+ cur_fpn = lateral_conv(x)
148
+ # Following FPN implementation, we use nearest upsampling here
149
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
150
+ y = output_conv(y)
151
+ if num_cur_levels < self.maskformer_num_feature_levels:
152
+ multi_scale_features.append(y)
153
+ num_cur_levels += 1
154
+ return self.mask_features(y), None, multi_scale_features
155
+
156
+ def forward(self, features, targets=None):
157
+ logger = logging.getLogger(__name__)
158
+ logger.warning("Calling forward() may cause unpredictable behavior of the PixelDecoder module.")
159
+ return self.forward_features(features)
160
+
161
+
162
+ class TransformerEncoderOnly(nn.Module):
163
+ def __init__(
164
+ self,
165
+ d_model=512,
166
+ nhead=8,
167
+ num_encoder_layers=6,
168
+ dim_feedforward=2048,
169
+ dropout=0.1,
170
+ activation="relu",
171
+ normalize_before=False,
172
+ ):
173
+ super().__init__()
174
+
175
+ encoder_layer = TransformerEncoderLayer(
176
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
177
+ )
178
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
179
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
180
+
181
+ self._reset_parameters()
182
+
183
+ self.d_model = d_model
184
+ self.nhead = nhead
185
+
186
+ def _reset_parameters(self):
187
+ for p in self.parameters():
188
+ if p.dim() > 1:
189
+ nn.init.xavier_uniform_(p)
190
+
191
+ def forward(self, src, mask, pos_embed):
192
+ # flatten NxCxHxW to HWxNxC
193
+ bs, c, h, w = src.shape
194
+ src = src.flatten(2).permute(2, 0, 1)
195
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
196
+ if mask is not None:
197
+ mask = mask.flatten(1)
198
+
199
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
200
+ return memory.permute(1, 2, 0).view(bs, c, h, w)
201
+
202
+
203
+ # This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map.
204
+ @SEM_SEG_HEADS_REGISTRY.register()
205
+ class TransformerEncoderPixelDecoder(BasePixelDecoder):
206
+ @configurable
207
+ def __init__(
208
+ self,
209
+ input_shape: Dict[str, ShapeSpec],
210
+ *,
211
+ transformer_dropout: float,
212
+ transformer_nheads: int,
213
+ transformer_dim_feedforward: int,
214
+ transformer_enc_layers: int,
215
+ transformer_pre_norm: bool,
216
+ conv_dim: int,
217
+ mask_dim: int,
218
+ norm: Optional[Union[str, Callable]] = None,
219
+ ):
220
+ """
221
+ NOTE: this interface is experimental.
222
+ Args:
223
+ input_shape: shapes (channels and stride) of the input features
224
+ transformer_dropout: dropout probability in transformer
225
+ transformer_nheads: number of heads in transformer
226
+ transformer_dim_feedforward: dimension of feedforward network
227
+ transformer_enc_layers: number of transformer encoder layers
228
+ transformer_pre_norm: whether to use pre-layernorm or not
229
+ conv_dim: number of output channels for the intermediate conv layers.
230
+ mask_dim: number of output channels for the final conv layer.
231
+ norm (str or callable): normalization for all conv layers
232
+ """
233
+ super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm)
234
+
235
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
236
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
237
+ feature_strides = [v.stride for k, v in input_shape]
238
+ feature_channels = [v.channels for k, v in input_shape]
239
+
240
+ in_channels = feature_channels[len(self.in_features) - 1]
241
+ self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
242
+ weight_init.c2_xavier_fill(self.input_proj)
243
+ self.transformer = TransformerEncoderOnly(
244
+ d_model=conv_dim,
245
+ dropout=transformer_dropout,
246
+ nhead=transformer_nheads,
247
+ dim_feedforward=transformer_dim_feedforward,
248
+ num_encoder_layers=transformer_enc_layers,
249
+ normalize_before=transformer_pre_norm,
250
+ )
251
+ N_steps = conv_dim // 2
252
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
253
+
254
+ # update layer
255
+ use_bias = norm == ""
256
+ output_norm = get_norm(norm, conv_dim)
257
+ output_conv = Conv2d(
258
+ conv_dim,
259
+ conv_dim,
260
+ kernel_size=3,
261
+ stride=1,
262
+ padding=1,
263
+ bias=use_bias,
264
+ norm=output_norm,
265
+ activation=F.relu,
266
+ )
267
+ weight_init.c2_xavier_fill(output_conv)
268
+ delattr(self, "layer_{}".format(len(self.in_features)))
269
+ self.add_module("layer_{}".format(len(self.in_features)), output_conv)
270
+ self.output_convs[0] = output_conv
271
+
272
+ @classmethod
273
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
274
+ ret = super().from_config(cfg, input_shape)
275
+ ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
276
+ ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
277
+ ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
278
+ ret[
279
+ "transformer_enc_layers"
280
+ ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config
281
+ ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
282
+ return ret
283
+
284
+ def forward_features(self, features):
285
+ multi_scale_features = []
286
+ num_cur_levels = 0
287
+ # Reverse feature maps into top-down order (from low to high resolution)
288
+ for idx, f in enumerate(self.in_features[::-1]):
289
+ x = features[f]
290
+ lateral_conv = self.lateral_convs[idx]
291
+ output_conv = self.output_convs[idx]
292
+ if lateral_conv is None:
293
+ transformer = self.input_proj(x)
294
+ pos = self.pe_layer(x)
295
+ transformer = self.transformer(transformer, None, pos)
296
+ y = output_conv(transformer)
297
+ # save intermediate feature as input to Transformer decoder
298
+ transformer_encoder_features = transformer
299
+ else:
300
+ cur_fpn = lateral_conv(x)
301
+ # Following FPN implementation, we use nearest upsampling here
302
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
303
+ y = output_conv(y)
304
+ if num_cur_levels < self.maskformer_num_feature_levels:
305
+ multi_scale_features.append(y)
306
+ num_cur_levels += 1
307
+ return self.mask_features(y), transformer_encoder_features, multi_scale_features
308
+
309
+ def forward(self, features, targets=None):
310
+ logger = logging.getLogger(__name__)
311
+ logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
312
+ return self.forward_features(features)
mask2former/modeling/pixel_decoder/msdeformattn.py ADDED
@@ -0,0 +1,358 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import numpy as np
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import fvcore.nn.weight_init as weight_init
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
11
+ from torch.cuda.amp import autocast
12
+
13
+ from detectron2.config import configurable
14
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
15
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
16
+
17
+ from ..transformer_decoder.position_encoding import PositionEmbeddingSine
18
+ from ..transformer_decoder.transformer import _get_clones, _get_activation_fn
19
+ from .ops.modules import MSDeformAttn
20
+
21
+
22
+ # MSDeformAttn Transformer encoder in deformable detr
23
+ class MSDeformAttnTransformerEncoderOnly(nn.Module):
24
+ def __init__(self, d_model=256, nhead=8,
25
+ num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
26
+ activation="relu",
27
+ num_feature_levels=4, enc_n_points=4,
28
+ ):
29
+ super().__init__()
30
+
31
+ self.d_model = d_model
32
+ self.nhead = nhead
33
+
34
+ encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward,
35
+ dropout, activation,
36
+ num_feature_levels, nhead, enc_n_points)
37
+ self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers)
38
+
39
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
40
+
41
+ self._reset_parameters()
42
+
43
+ def _reset_parameters(self):
44
+ for p in self.parameters():
45
+ if p.dim() > 1:
46
+ nn.init.xavier_uniform_(p)
47
+ for m in self.modules():
48
+ if isinstance(m, MSDeformAttn):
49
+ m._reset_parameters()
50
+ normal_(self.level_embed)
51
+
52
+ def get_valid_ratio(self, mask):
53
+ _, H, W = mask.shape
54
+ valid_H = torch.sum(~mask[:, :, 0], 1)
55
+ valid_W = torch.sum(~mask[:, 0, :], 1)
56
+ valid_ratio_h = valid_H.float() / H
57
+ valid_ratio_w = valid_W.float() / W
58
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
59
+ return valid_ratio
60
+
61
+ def forward(self, srcs, pos_embeds):
62
+ masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]
63
+ # prepare input for encoder
64
+ src_flatten = []
65
+ mask_flatten = []
66
+ lvl_pos_embed_flatten = []
67
+ spatial_shapes = []
68
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
69
+ bs, c, h, w = src.shape
70
+ spatial_shape = (h, w)
71
+ spatial_shapes.append(spatial_shape)
72
+ src = src.flatten(2).transpose(1, 2)
73
+ mask = mask.flatten(1)
74
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
75
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
76
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
77
+ src_flatten.append(src)
78
+ mask_flatten.append(mask)
79
+ src_flatten = torch.cat(src_flatten, 1)
80
+ mask_flatten = torch.cat(mask_flatten, 1)
81
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
82
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
83
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
84
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
85
+
86
+ # encoder
87
+ memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
88
+
89
+ return memory, spatial_shapes, level_start_index
90
+
91
+
92
+ class MSDeformAttnTransformerEncoderLayer(nn.Module):
93
+ def __init__(self,
94
+ d_model=256, d_ffn=1024,
95
+ dropout=0.1, activation="relu",
96
+ n_levels=4, n_heads=8, n_points=4):
97
+ super().__init__()
98
+
99
+ # self attention
100
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
101
+ self.dropout1 = nn.Dropout(dropout)
102
+ self.norm1 = nn.LayerNorm(d_model)
103
+
104
+ # ffn
105
+ self.linear1 = nn.Linear(d_model, d_ffn)
106
+ self.activation = _get_activation_fn(activation)
107
+ self.dropout2 = nn.Dropout(dropout)
108
+ self.linear2 = nn.Linear(d_ffn, d_model)
109
+ self.dropout3 = nn.Dropout(dropout)
110
+ self.norm2 = nn.LayerNorm(d_model)
111
+
112
+ @staticmethod
113
+ def with_pos_embed(tensor, pos):
114
+ return tensor if pos is None else tensor + pos
115
+
116
+ def forward_ffn(self, src):
117
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
118
+ src = src + self.dropout3(src2)
119
+ src = self.norm2(src)
120
+ return src
121
+
122
+ def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
123
+ # self attention
124
+ src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
125
+ src = src + self.dropout1(src2)
126
+ src = self.norm1(src)
127
+
128
+ # ffn
129
+ src = self.forward_ffn(src)
130
+
131
+ return src
132
+
133
+
134
+ class MSDeformAttnTransformerEncoder(nn.Module):
135
+ def __init__(self, encoder_layer, num_layers):
136
+ super().__init__()
137
+ self.layers = _get_clones(encoder_layer, num_layers)
138
+ self.num_layers = num_layers
139
+
140
+ @staticmethod
141
+ def get_reference_points(spatial_shapes, valid_ratios, device):
142
+ reference_points_list = []
143
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
144
+
145
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
146
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
147
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
148
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
149
+ ref = torch.stack((ref_x, ref_y), -1)
150
+ reference_points_list.append(ref)
151
+ reference_points = torch.cat(reference_points_list, 1)
152
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
153
+ return reference_points
154
+
155
+ def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
156
+ output = src
157
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
158
+ for _, layer in enumerate(self.layers):
159
+ output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
160
+
161
+ return output
162
+
163
+
164
+ @SEM_SEG_HEADS_REGISTRY.register()
165
+ class MSDeformAttnPixelDecoder(nn.Module):
166
+ @configurable
167
+ def __init__(
168
+ self,
169
+ input_shape: Dict[str, ShapeSpec],
170
+ *,
171
+ transformer_dropout: float,
172
+ transformer_nheads: int,
173
+ transformer_dim_feedforward: int,
174
+ transformer_enc_layers: int,
175
+ conv_dim: int,
176
+ mask_dim: int,
177
+ norm: Optional[Union[str, Callable]] = None,
178
+ # deformable transformer encoder args
179
+ transformer_in_features: List[str],
180
+ common_stride: int,
181
+ ):
182
+ """
183
+ NOTE: this interface is experimental.
184
+ Args:
185
+ input_shape: shapes (channels and stride) of the input features
186
+ transformer_dropout: dropout probability in transformer
187
+ transformer_nheads: number of heads in transformer
188
+ transformer_dim_feedforward: dimension of feedforward network
189
+ transformer_enc_layers: number of transformer encoder layers
190
+ conv_dims: number of output channels for the intermediate conv layers.
191
+ mask_dim: number of output channels for the final conv layer.
192
+ norm (str or callable): normalization for all conv layers
193
+ """
194
+ super().__init__()
195
+ transformer_input_shape = {
196
+ k: v for k, v in input_shape.items() if k in transformer_in_features
197
+ }
198
+
199
+ # this is the input shape of pixel decoder
200
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
201
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
202
+ self.feature_strides = [v.stride for k, v in input_shape]
203
+ self.feature_channels = [v.channels for k, v in input_shape]
204
+
205
+ # this is the input shape of the transformer encoder (it may use fewer features than the pixel decoder)
206
+ transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride)
207
+ self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5"
208
+ transformer_in_channels = [v.channels for k, v in transformer_input_shape]
209
+ self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers
210
+
211
+ self.transformer_num_feature_levels = len(self.transformer_in_features)
212
+ if self.transformer_num_feature_levels > 1:
213
+ input_proj_list = []
214
+ # from low resolution to high resolution (res5 -> res2)
215
+ for in_channels in transformer_in_channels[::-1]:
216
+ input_proj_list.append(nn.Sequential(
217
+ nn.Conv2d(in_channels, conv_dim, kernel_size=1),
218
+ nn.GroupNorm(32, conv_dim),
219
+ ))
220
+ self.input_proj = nn.ModuleList(input_proj_list)
221
+ else:
222
+ self.input_proj = nn.ModuleList([
223
+ nn.Sequential(
224
+ nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1),
225
+ nn.GroupNorm(32, conv_dim),
226
+ )])
227
+
228
+ for proj in self.input_proj:
229
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
230
+ nn.init.constant_(proj[0].bias, 0)
231
+
232
+ self.transformer = MSDeformAttnTransformerEncoderOnly(
233
+ d_model=conv_dim,
234
+ dropout=transformer_dropout,
235
+ nhead=transformer_nheads,
236
+ dim_feedforward=transformer_dim_feedforward,
237
+ num_encoder_layers=transformer_enc_layers,
238
+ num_feature_levels=self.transformer_num_feature_levels,
239
+ )
240
+ N_steps = conv_dim // 2
241
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
242
+
243
+ self.mask_dim = mask_dim
244
+ # use 1x1 conv instead
245
+ self.mask_features = Conv2d(
246
+ conv_dim,
247
+ mask_dim,
248
+ kernel_size=1,
249
+ stride=1,
250
+ padding=0,
251
+ )
252
+ weight_init.c2_xavier_fill(self.mask_features)
253
+
254
+ self.maskformer_num_feature_levels = 3 # always use 3 scales
255
+ self.common_stride = common_stride
256
+
257
+ # extra fpn levels
258
+ stride = min(self.transformer_feature_strides)
259
+ self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride))
260
+
261
+ lateral_convs = []
262
+ output_convs = []
263
+
264
+ use_bias = norm == ""
265
+ for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):
266
+ lateral_norm = get_norm(norm, conv_dim)
267
+ output_norm = get_norm(norm, conv_dim)
268
+
269
+ lateral_conv = Conv2d(
270
+ in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
271
+ )
272
+ output_conv = Conv2d(
273
+ conv_dim,
274
+ conv_dim,
275
+ kernel_size=3,
276
+ stride=1,
277
+ padding=1,
278
+ bias=use_bias,
279
+ norm=output_norm,
280
+ activation=F.relu,
281
+ )
282
+ weight_init.c2_xavier_fill(lateral_conv)
283
+ weight_init.c2_xavier_fill(output_conv)
284
+ self.add_module("adapter_{}".format(idx + 1), lateral_conv)
285
+ self.add_module("layer_{}".format(idx + 1), output_conv)
286
+
287
+ lateral_convs.append(lateral_conv)
288
+ output_convs.append(output_conv)
289
+ # Place convs into top-down order (from low to high resolution)
290
+ # to make the top-down computation in forward clearer.
291
+ self.lateral_convs = lateral_convs[::-1]
292
+ self.output_convs = output_convs[::-1]
293
+
294
+ @classmethod
295
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
296
+ ret = {}
297
+ ret["input_shape"] = {
298
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
299
+ }
300
+ ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
301
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
302
+ ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
303
+ ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
304
+ ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
305
+ # ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
306
+ ret["transformer_dim_feedforward"] = 1024 # use 1024 for deformable transformer encoder
307
+ ret[
308
+ "transformer_enc_layers"
309
+ ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config
310
+ ret["transformer_in_features"] = cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES
311
+ ret["common_stride"] = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE
312
+ return ret
313
+
314
+ @autocast(enabled=False)
315
+ def forward_features(self, features):
316
+ srcs = []
317
+ pos = []
318
+ # Reverse feature maps into top-down order (from low to high resolution)
319
+ for idx, f in enumerate(self.transformer_in_features[::-1]):
320
+ x = features[f].float() # deformable detr does not support half precision
321
+ srcs.append(self.input_proj[idx](x))
322
+ pos.append(self.pe_layer(x))
323
+
324
+ y, spatial_shapes, level_start_index = self.transformer(srcs, pos)
325
+ bs = y.shape[0]
326
+
327
+ split_size_or_sections = [None] * self.transformer_num_feature_levels
328
+ for i in range(self.transformer_num_feature_levels):
329
+ if i < self.transformer_num_feature_levels - 1:
330
+ split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
331
+ else:
332
+ split_size_or_sections[i] = y.shape[1] - level_start_index[i]
333
+ y = torch.split(y, split_size_or_sections, dim=1)
334
+
335
+ out = []
336
+ multi_scale_features = []
337
+ num_cur_levels = 0
338
+ for i, z in enumerate(y):
339
+ out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))
340
+
341
+ # append `out` with extra FPN levels
342
+ # Reverse feature maps into top-down order (from low to high resolution)
343
+ for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
344
+ x = features[f].float()
345
+ lateral_conv = self.lateral_convs[idx]
346
+ output_conv = self.output_convs[idx]
347
+ cur_fpn = lateral_conv(x)
348
+ # Following FPN implementation, we use nearest upsampling here
349
+ y = cur_fpn + F.interpolate(out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
350
+ y = output_conv(y)
351
+ out.append(y)
352
+
353
+ for o in out:
354
+ if num_cur_levels < self.maskformer_num_feature_levels:
355
+ multi_scale_features.append(o)
356
+ num_cur_levels += 1
357
+
358
+ return self.mask_features(out[-1]), out[0], multi_scale_features
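At the end of forward_features above, the flattened encoder output is sliced back into one map per level (using level_start_index) and each chunk is reshaped with its spatial_shapes entry. A shape-only sketch of that split with hypothetical sizes, standing in for the real encoder output:

import torch

bs, conv_dim = 2, 256
spatial_shapes = torch.as_tensor([(8, 8), (16, 16), (32, 32)], dtype=torch.long)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
y = torch.randn(bs, int(spatial_shapes.prod(1).sum()), conv_dim)  # stand-in for the encoder output

split_sizes = [
    int(level_start_index[i + 1] - level_start_index[i]) if i + 1 < len(spatial_shapes)
    else y.shape[1] - int(level_start_index[i])
    for i in range(len(spatial_shapes))
]
maps = [
    z.transpose(1, 2).view(bs, -1, int(h), int(w))
    for z, (h, w) in zip(torch.split(y, split_sizes, dim=1), spatial_shapes)
]
print([tuple(m.shape) for m in maps])  # [(2, 256, 8, 8), (2, 256, 16, 16), (2, 256, 32, 32)]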
mask2former/modeling/pixel_decoder/ops/functions/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from .ms_deform_attn_func import MSDeformAttnFunction
13
+
mask2former/modeling/pixel_decoder/ops/functions/ms_deform_attn_func.py ADDED
@@ -0,0 +1,72 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch.autograd import Function
19
+ from torch.autograd.function import once_differentiable
20
+
21
+ try:
22
+ import MultiScaleDeformableAttention as MSDA
23
+ except ModuleNotFoundError as e:
24
+ info_string = (
25
+ "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
26
+ "\t`cd mask2former/modeling/pixel_decoder/ops`\n"
27
+ "\t`sh make.sh`\n"
28
+ )
29
+ raise ModuleNotFoundError(info_string)
30
+
31
+
32
+ class MSDeformAttnFunction(Function):
33
+ @staticmethod
34
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
35
+ ctx.im2col_step = im2col_step
36
+ output = MSDA.ms_deform_attn_forward(
37
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
38
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
39
+ return output
40
+
41
+ @staticmethod
42
+ @once_differentiable
43
+ def backward(ctx, grad_output):
44
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
45
+ grad_value, grad_sampling_loc, grad_attn_weight = \
46
+ MSDA.ms_deform_attn_backward(
47
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
48
+
49
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
50
+
51
+
52
+ def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
53
+ # pure-PyTorch reference path, intended for debugging and testing;
54
+ # the CUDA version should be used for speed
55
+ N_, S_, M_, D_ = value.shape
56
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
57
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
58
+ sampling_grids = 2 * sampling_locations - 1
59
+ sampling_value_list = []
60
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
61
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
62
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
63
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
64
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
65
+ # N_*M_, D_, Lq_, P_
66
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
67
+ mode='bilinear', padding_mode='zeros', align_corners=False)
68
+ sampling_value_list.append(sampling_value_l_)
69
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
70
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
71
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
72
+ return output.transpose(1, 2).contiguous()
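ms_deform_attn_core_pytorch is the pure-PyTorch reference path that the attention module falls back to when the compiled op cannot run. A shape-only usage sketch with hypothetical sizes (N batch, S tokens, M heads, D channels per head, L levels, P points); note that importing this module still requires the compiled MultiScaleDeformableAttention extension because of the import guard above:

import torch
from mask2former.modeling.pixel_decoder.ops.functions.ms_deform_attn_func import ms_deform_attn_core_pytorch

N, M, D, Lq, P = 2, 8, 32, 100, 4
spatial_shapes = [(16, 16), (8, 8)]           # L = 2 levels
S = sum(h * w for h, w in spatial_shapes)     # 320 tokens in total

value = torch.randn(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, len(spatial_shapes), P, 2)  # normalized to [0, 1]
attention_weights = torch.rand(N, Lq, M, len(spatial_shapes), P)
attention_weights = attention_weights / attention_weights.sum((-2, -1), keepdim=True)

out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 100, 256]), i.e. (N, Lq, M * D)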
mask2former/modeling/pixel_decoder/ops/make.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/usr/bin/env bash
2
+ # ------------------------------------------------------------------------------------------------
3
+ # Deformable DETR
4
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------------------------------
7
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ # ------------------------------------------------------------------------------------------------
9
+
10
+ # Copyright (c) Facebook, Inc. and its affiliates.
11
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
12
+
13
+ python setup.py build install
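A quick, hedged way to confirm the build succeeded is to import the compiled extension by the name registered in setup.py; the ops directory also ships a test.py script with further checks:

# run after `sh make.sh` has completed
import MultiScaleDeformableAttention  # the extension built by setup.py
print("MultiScaleDeformableAttention imported successfully")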
mask2former/modeling/pixel_decoder/ops/modules/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from .ms_deform_attn import MSDeformAttn
mask2former/modeling/pixel_decoder/ops/modules/ms_deform_attn.py ADDED
@@ -0,0 +1,125 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import warnings
17
+ import math
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.nn.functional as F
22
+ from torch.nn.init import xavier_uniform_, constant_
23
+
24
+ from ..functions import MSDeformAttnFunction
25
+ from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch
26
+
27
+
28
+ def _is_power_of_2(n):
29
+ if (not isinstance(n, int)) or (n < 0):
30
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
31
+ return (n & (n-1) == 0) and n != 0
32
+
33
+
34
+ class MSDeformAttn(nn.Module):
35
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
36
+ """
37
+ Multi-Scale Deformable Attention Module
38
+ :param d_model hidden dimension
39
+ :param n_levels number of feature levels
40
+ :param n_heads number of attention heads
41
+ :param n_points number of sampling points per attention head per feature level
42
+ """
43
+ super().__init__()
44
+ if d_model % n_heads != 0:
45
+ raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
46
+ _d_per_head = d_model // n_heads
47
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
48
+ if not _is_power_of_2(_d_per_head):
49
+ warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
50
+ "which is more efficient in our CUDA implementation.")
51
+
52
+ self.im2col_step = 128
53
+
54
+ self.d_model = d_model
55
+ self.n_levels = n_levels
56
+ self.n_heads = n_heads
57
+ self.n_points = n_points
58
+
59
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
60
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
61
+ self.value_proj = nn.Linear(d_model, d_model)
62
+ self.output_proj = nn.Linear(d_model, d_model)
63
+
64
+ self._reset_parameters()
65
+
66
+ def _reset_parameters(self):
67
+ constant_(self.sampling_offsets.weight.data, 0.)
68
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
69
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
70
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
71
+ for i in range(self.n_points):
72
+ grid_init[:, :, i, :] *= i + 1
73
+ with torch.no_grad():
74
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
75
+ constant_(self.attention_weights.weight.data, 0.)
76
+ constant_(self.attention_weights.bias.data, 0.)
77
+ xavier_uniform_(self.value_proj.weight.data)
78
+ constant_(self.value_proj.bias.data, 0.)
79
+ xavier_uniform_(self.output_proj.weight.data)
80
+ constant_(self.output_proj.bias.data, 0.)
81
+
82
+ def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
83
+ """
84
+ :param query (N, Length_{query}, C)
85
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
86
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
87
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
88
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
89
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
90
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
91
+
92
+ :return output (N, Length_{query}, C)
93
+ """
94
+ N, Len_q, _ = query.shape
95
+ N, Len_in, _ = input_flatten.shape
96
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
97
+
98
+ value = self.value_proj(input_flatten)
99
+ if input_padding_mask is not None:
100
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
101
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
102
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
103
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
104
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
105
+ # N, Len_q, n_heads, n_levels, n_points, 2
106
+ if reference_points.shape[-1] == 2:
107
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
108
+ sampling_locations = reference_points[:, :, None, :, None, :] \
109
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
110
+ elif reference_points.shape[-1] == 4:
111
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
112
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
113
+ else:
114
+ raise ValueError(
115
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
116
+ try:
117
+ output = MSDeformAttnFunction.apply(
118
+ value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
119
+ except:
120
+ # CPU
121
+ output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
122
+ # # For FLOPs calculation only
123
+ # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
124
+ output = self.output_proj(output)
125
+ return output
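A shape-only sketch of calling MSDeformAttn directly, with hypothetical sizes; it assumes the ops package imports cleanly (the compiled extension is needed at import time, even though forward falls back to the pure-PyTorch path if the op itself fails, e.g. on CPU):

import torch
from mask2former.modeling.pixel_decoder.ops.modules import MSDeformAttn

N, C, Lq = 2, 256, 100
spatial_shapes = torch.as_tensor([(16, 16), (8, 8)], dtype=torch.long)  # 2 levels
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
Len_in = int(spatial_shapes.prod(1).sum())  # 320

attn = MSDeformAttn(d_model=C, n_levels=2, n_heads=8, n_points=4)
query = torch.randn(N, Lq, C)
reference_points = torch.rand(N, Lq, 2, 2)   # normalized (x, y) per level, in [0, 1]
input_flatten = torch.randn(N, Len_in, C)

out = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
print(out.shape)  # torch.Size([2, 100, 256])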
mask2former/modeling/pixel_decoder/ops/setup.py ADDED
@@ -0,0 +1,78 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ import os
13
+ import glob
14
+
15
+ import torch
16
+
17
+ from torch.utils.cpp_extension import CUDA_HOME
18
+ from torch.utils.cpp_extension import CppExtension
19
+ from torch.utils.cpp_extension import CUDAExtension
20
+
21
+ from setuptools import find_packages
22
+ from setuptools import setup
23
+
24
+ requirements = ["torch", "torchvision"]
25
+
26
+ def get_extensions():
27
+ this_dir = os.path.dirname(os.path.abspath(__file__))
28
+ extensions_dir = os.path.join(this_dir, "src")
29
+
30
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
31
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
32
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
33
+
34
+ sources = main_file + source_cpu
35
+ extension = CppExtension
36
+ extra_compile_args = {"cxx": []}
37
+ define_macros = []
38
+
39
+ # Build the CUDA extension when FORCE_CUDA is set or torch reports CUDA as available (and CUDA_HOME is found).
40
+ if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:
41
+ extension = CUDAExtension
42
+ sources += source_cuda
43
+ define_macros += [("WITH_CUDA", None)]
44
+ extra_compile_args["nvcc"] = [
45
+ "-DCUDA_HAS_FP16=1",
46
+ "-D__CUDA_NO_HALF_OPERATORS__",
47
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
48
+ "-D__CUDA_NO_HALF2_OPERATORS__",
49
+ ]
50
+ else:
51
+ if CUDA_HOME is None:
52
+ raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')
53
+ else:
54
+ raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().')
55
+
56
+ sources = [os.path.join(extensions_dir, s) for s in sources]
57
+ include_dirs = [extensions_dir]
58
+ ext_modules = [
59
+ extension(
60
+ "MultiScaleDeformableAttention",
61
+ sources,
62
+ include_dirs=include_dirs,
63
+ define_macros=define_macros,
64
+ extra_compile_args=extra_compile_args,
65
+ )
66
+ ]
67
+ return ext_modules
68
+
69
+ setup(
70
+ name="MultiScaleDeformableAttention",
71
+ version="1.0",
72
+ author="Weijie Su",
73
+ url="https://github.com/fundamentalvision/Deformable-DETR",
74
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
75
+ packages=find_packages(exclude=("configs", "tests",)),
76
+ ext_modules=get_extensions(),
77
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
78
+ )
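get_extensions above only builds the CUDA sources when FORCE_CUDA is set or torch.cuda.is_available() is True (and CUDA_HOME is found), which matters when the build machine has no visible GPU. A hedged sketch of driving the build from Python with FORCE_CUDA set, equivalent to running make.sh in that environment:

import os
import subprocess

env = dict(os.environ, FORCE_CUDA="1")  # read by get_extensions() above
subprocess.check_call(
    ["python", "setup.py", "build", "install"],
    cwd="mask2former/modeling/pixel_decoder/ops",  # directory containing this setup.py
    env=env,
)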
mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,46 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include <vector>
17
+
18
+ #include <ATen/ATen.h>
19
+ #include <ATen/cuda/CUDAContext.h>
20
+
21
+
22
+ at::Tensor
23
+ ms_deform_attn_cpu_forward(
24
+ const at::Tensor &value,
25
+ const at::Tensor &spatial_shapes,
26
+ const at::Tensor &level_start_index,
27
+ const at::Tensor &sampling_loc,
28
+ const at::Tensor &attn_weight,
29
+ const int im2col_step)
30
+ {
31
+ AT_ERROR("Not implement on cpu");
32
+ }
33
+
34
+ std::vector<at::Tensor>
35
+ ms_deform_attn_cpu_backward(
36
+ const at::Tensor &value,
37
+ const at::Tensor &spatial_shapes,
38
+ const at::Tensor &level_start_index,
39
+ const at::Tensor &sampling_loc,
40
+ const at::Tensor &attn_weight,
41
+ const at::Tensor &grad_output,
42
+ const int im2col_step)
43
+ {
44
+ AT_ERROR("Not implement on cpu");
45
+ }
46
+
mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,38 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+ #include <torch/extension.h>
18
+
19
+ at::Tensor
20
+ ms_deform_attn_cpu_forward(
21
+ const at::Tensor &value,
22
+ const at::Tensor &spatial_shapes,
23
+ const at::Tensor &level_start_index,
24
+ const at::Tensor &sampling_loc,
25
+ const at::Tensor &attn_weight,
26
+ const int im2col_step);
27
+
28
+ std::vector<at::Tensor>
29
+ ms_deform_attn_cpu_backward(
30
+ const at::Tensor &value,
31
+ const at::Tensor &spatial_shapes,
32
+ const at::Tensor &level_start_index,
33
+ const at::Tensor &sampling_loc,
34
+ const at::Tensor &attn_weight,
35
+ const at::Tensor &grad_output,
36
+ const int im2col_step);
37
+
38
+
mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,158 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include <vector>
17
+ #include "cuda/ms_deform_im2col_cuda.cuh"
18
+
19
+ #include <ATen/ATen.h>
20
+ #include <ATen/cuda/CUDAContext.h>
21
+ #include <cuda.h>
22
+ #include <cuda_runtime.h>
23
+
24
+
25
+ at::Tensor ms_deform_attn_cuda_forward(
26
+ const at::Tensor &value,
27
+ const at::Tensor &spatial_shapes,
28
+ const at::Tensor &level_start_index,
29
+ const at::Tensor &sampling_loc,
30
+ const at::Tensor &attn_weight,
31
+ const int im2col_step)
32
+ {
33
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
34
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
35
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
36
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
37
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
38
+
39
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
40
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
41
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
42
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
43
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
44
+
45
+ const int batch = value.size(0);
46
+ const int spatial_size = value.size(1);
47
+ const int num_heads = value.size(2);
48
+ const int channels = value.size(3);
49
+
50
+ const int num_levels = spatial_shapes.size(0);
51
+
52
+ const int num_query = sampling_loc.size(1);
53
+ const int num_point = sampling_loc.size(4);
54
+
55
+ const int im2col_step_ = std::min(batch, im2col_step);
56
+
57
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
58
+
59
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
60
+
61
+ const int batch_n = im2col_step_;
62
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
63
+ auto per_value_size = spatial_size * num_heads * channels;
64
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
65
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
66
+ for (int n = 0; n < batch/im2col_step_; ++n)
67
+ {
68
+ auto columns = output_n.select(0, n);
69
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
70
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
71
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
72
+ spatial_shapes.data<int64_t>(),
73
+ level_start_index.data<int64_t>(),
74
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
75
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
76
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
77
+ columns.data<scalar_t>());
78
+
79
+ }));
80
+ }
81
+
82
+ output = output.view({batch, num_query, num_heads*channels});
83
+
84
+ return output;
85
+ }
86
+
87
+
88
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
89
+ const at::Tensor &value,
90
+ const at::Tensor &spatial_shapes,
91
+ const at::Tensor &level_start_index,
92
+ const at::Tensor &sampling_loc,
93
+ const at::Tensor &attn_weight,
94
+ const at::Tensor &grad_output,
95
+ const int im2col_step)
96
+ {
97
+
98
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
99
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
100
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
101
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
102
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
103
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
104
+
105
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
106
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
107
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
108
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
109
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
110
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
111
+
112
+ const int batch = value.size(0);
113
+ const int spatial_size = value.size(1);
114
+ const int num_heads = value.size(2);
115
+ const int channels = value.size(3);
116
+
117
+ const int num_levels = spatial_shapes.size(0);
118
+
119
+ const int num_query = sampling_loc.size(1);
120
+ const int num_point = sampling_loc.size(4);
121
+
122
+ const int im2col_step_ = std::min(batch, im2col_step);
123
+
124
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
125
+
126
+ auto grad_value = at::zeros_like(value);
127
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
128
+ auto grad_attn_weight = at::zeros_like(attn_weight);
129
+
130
+ const int batch_n = im2col_step_;
131
+ auto per_value_size = spatial_size * num_heads * channels;
132
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
133
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
134
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
135
+
136
+ for (int n = 0; n < batch/im2col_step_; ++n)
137
+ {
138
+ auto grad_output_g = grad_output_n.select(0, n);
139
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
140
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
141
+ grad_output_g.data<scalar_t>(),
142
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
143
+ spatial_shapes.data<int64_t>(),
144
+ level_start_index.data<int64_t>(),
145
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
146
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
147
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
148
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
149
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
150
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
151
+
152
+ }));
153
+ }
154
+
155
+ return {
156
+ grad_value, grad_sampling_loc, grad_attn_weight
157
+ };
158
+ }
mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,35 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+ #include <torch/extension.h>
18
+
19
+ at::Tensor ms_deform_attn_cuda_forward(
20
+ const at::Tensor &value,
21
+ const at::Tensor &spatial_shapes,
22
+ const at::Tensor &level_start_index,
23
+ const at::Tensor &sampling_loc,
24
+ const at::Tensor &attn_weight,
25
+ const int im2col_step);
26
+
27
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
28
+ const at::Tensor &value,
29
+ const at::Tensor &spatial_shapes,
30
+ const at::Tensor &level_start_index,
31
+ const at::Tensor &sampling_loc,
32
+ const at::Tensor &attn_weight,
33
+ const at::Tensor &grad_output,
34
+ const int im2col_step);
35
+
mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1332 @@
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ /*!
13
+ * Copyright (c) Facebook, Inc. and its affiliates.
14
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
15
+ */
16
+
17
+ #include <cstdio>
18
+ #include <algorithm>
19
+ #include <cstring>
20
+
21
+ #include <ATen/ATen.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+
24
+ #include <THC/THCAtomics.cuh>
25
+
26
+ #define CUDA_KERNEL_LOOP(i, n) \
27
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
28
+ i < (n); \
29
+ i += blockDim.x * gridDim.x)
30
+
31
+ const int CUDA_NUM_THREADS = 1024;
32
+ inline int GET_BLOCKS(const int N, const int num_threads)
33
+ {
34
+ return (N + num_threads - 1) / num_threads;
35
+ }
36
+
37
+
38
+ template <typename scalar_t>
39
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
40
+ const int &height, const int &width, const int &nheads, const int &channels,
41
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
42
+ {
43
+ const int h_low = floor(h);
44
+ const int w_low = floor(w);
45
+ const int h_high = h_low + 1;
46
+ const int w_high = w_low + 1;
47
+
48
+ const scalar_t lh = h - h_low;
49
+ const scalar_t lw = w - w_low;
50
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
51
+
52
+ const int w_stride = nheads * channels;
53
+ const int h_stride = width * w_stride;
54
+ const int h_low_ptr_offset = h_low * h_stride;
55
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
56
+ const int w_low_ptr_offset = w_low * w_stride;
57
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
58
+ const int base_ptr = m * channels + c;
59
+
60
+ scalar_t v1 = 0;
61
+ if (h_low >= 0 && w_low >= 0)
62
+ {
63
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
64
+ v1 = bottom_data[ptr1];
65
+ }
66
+ scalar_t v2 = 0;
67
+ if (h_low >= 0 && w_high <= width - 1)
68
+ {
69
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
70
+ v2 = bottom_data[ptr2];
71
+ }
72
+ scalar_t v3 = 0;
73
+ if (h_high <= height - 1 && w_low >= 0)
74
+ {
75
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
76
+ v3 = bottom_data[ptr3];
77
+ }
78
+ scalar_t v4 = 0;
79
+ if (h_high <= height - 1 && w_high <= width - 1)
80
+ {
81
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
82
+ v4 = bottom_data[ptr4];
83
+ }
84
+
85
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
86
+
87
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
88
+ return val;
89
+ }
90
+
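ms_deform_attn_im2col_bilinear gathers the four integer neighbours of a fractional sampling point (h, w) and blends them with the usual bilinear weights, treating neighbours that fall outside the map as zero. A small pure-Python/PyTorch sketch of the same interpolation on a single-channel map, for illustration only:

import math
import torch

def bilinear_sample(feat: torch.Tensor, h: float, w: float) -> float:
    # feat: (H, W) single-channel map; (h, w) is a fractional location in pixel units
    H, W = feat.shape
    h_low, w_low = math.floor(h), math.floor(w)
    h_high, w_high = h_low + 1, w_low + 1
    lh, lw = h - h_low, w - w_low
    hh, hw = 1 - lh, 1 - lw

    def at(y, x):
        # out-of-range neighbours contribute 0, as in the CUDA helper
        return float(feat[y, x]) if 0 <= y < H and 0 <= x < W else 0.0

    return (hh * hw * at(h_low, w_low) + hh * lw * at(h_low, w_high)
            + lh * hw * at(h_high, w_low) + lh * lw * at(h_high, w_high))

feat = torch.arange(16.0).view(4, 4)
print(bilinear_sample(feat, 1.5, 2.5))  # 8.5, the average of feat[1, 2], feat[1, 3], feat[2, 2], feat[2, 3]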
91
+
92
+ template <typename scalar_t>
93
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
94
+ const int &height, const int &width, const int &nheads, const int &channels,
95
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
96
+ const scalar_t &top_grad,
97
+ const scalar_t &attn_weight,
98
+ scalar_t* &grad_value,
99
+ scalar_t* grad_sampling_loc,
100
+ scalar_t* grad_attn_weight)
101
+ {
102
+ const int h_low = floor(h);
103
+ const int w_low = floor(w);
104
+ const int h_high = h_low + 1;
105
+ const int w_high = w_low + 1;
106
+
107
+ const scalar_t lh = h - h_low;
108
+ const scalar_t lw = w - w_low;
109
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
110
+
111
+ const int w_stride = nheads * channels;
112
+ const int h_stride = width * w_stride;
113
+ const int h_low_ptr_offset = h_low * h_stride;
114
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
115
+ const int w_low_ptr_offset = w_low * w_stride;
116
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
117
+ const int base_ptr = m * channels + c;
118
+
119
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
120
+ const scalar_t top_grad_value = top_grad * attn_weight;
121
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
122
+
123
+ scalar_t v1 = 0;
124
+ if (h_low >= 0 && w_low >= 0)
125
+ {
126
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
127
+ v1 = bottom_data[ptr1];
128
+ grad_h_weight -= hw * v1;
129
+ grad_w_weight -= hh * v1;
130
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
131
+ }
132
+ scalar_t v2 = 0;
133
+ if (h_low >= 0 && w_high <= width - 1)
134
+ {
135
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
136
+ v2 = bottom_data[ptr2];
137
+ grad_h_weight -= lw * v2;
138
+ grad_w_weight += hh * v2;
139
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
140
+ }
141
+ scalar_t v3 = 0;
142
+ if (h_high <= height - 1 && w_low >= 0)
143
+ {
144
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
145
+ v3 = bottom_data[ptr3];
146
+ grad_h_weight += hw * v3;
147
+ grad_w_weight -= lh * v3;
148
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
149
+ }
150
+ scalar_t v4 = 0;
151
+ if (h_high <= height - 1 && w_high <= width - 1)
152
+ {
153
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
154
+ v4 = bottom_data[ptr4];
155
+ grad_h_weight += lw * v4;
156
+ grad_w_weight += lh * v4;
157
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
158
+ }
159
+
160
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
161
+ *grad_attn_weight = top_grad * val;
162
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
163
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
164
+ }
165
+
166
+
167
+ template <typename scalar_t>
168
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
169
+ const int &height, const int &width, const int &nheads, const int &channels,
170
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
171
+ const scalar_t &top_grad,
172
+ const scalar_t &attn_weight,
173
+ scalar_t* &grad_value,
174
+ scalar_t* grad_sampling_loc,
175
+ scalar_t* grad_attn_weight)
176
+ {
177
+ const int h_low = floor(h);
178
+ const int w_low = floor(w);
179
+ const int h_high = h_low + 1;
180
+ const int w_high = w_low + 1;
181
+
182
+ const scalar_t lh = h - h_low;
183
+ const scalar_t lw = w - w_low;
184
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
185
+
186
+ const int w_stride = nheads * channels;
187
+ const int h_stride = width * w_stride;
188
+ const int h_low_ptr_offset = h_low * h_stride;
189
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
190
+ const int w_low_ptr_offset = w_low * w_stride;
191
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
192
+ const int base_ptr = m * channels + c;
193
+
194
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
195
+ const scalar_t top_grad_value = top_grad * attn_weight;
196
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
197
+
198
+ scalar_t v1 = 0;
199
+ if (h_low >= 0 && w_low >= 0)
200
+ {
201
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
202
+ v1 = bottom_data[ptr1];
203
+ grad_h_weight -= hw * v1;
204
+ grad_w_weight -= hh * v1;
205
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
206
+ }
207
+ scalar_t v2 = 0;
208
+ if (h_low >= 0 && w_high <= width - 1)
209
+ {
210
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
211
+ v2 = bottom_data[ptr2];
212
+ grad_h_weight -= lw * v2;
213
+ grad_w_weight += hh * v2;
214
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
215
+ }
216
+ scalar_t v3 = 0;
217
+ if (h_high <= height - 1 && w_low >= 0)
218
+ {
219
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
220
+ v3 = bottom_data[ptr3];
221
+ grad_h_weight += hw * v3;
222
+ grad_w_weight -= lh * v3;
223
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
224
+ }
225
+ scalar_t v4 = 0;
226
+ if (h_high <= height - 1 && w_high <= width - 1)
227
+ {
228
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
229
+ v4 = bottom_data[ptr4];
230
+ grad_h_weight += lw * v4;
231
+ grad_w_weight += lh * v4;
232
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
233
+ }
234
+
235
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
236
+ atomicAdd(grad_attn_weight, top_grad * val);
237
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
238
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
239
+ }
240
+
241
+
242
+ template <typename scalar_t>
243
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
244
+ const scalar_t *data_value,
245
+ const int64_t *data_spatial_shapes,
246
+ const int64_t *data_level_start_index,
247
+ const scalar_t *data_sampling_loc,
248
+ const scalar_t *data_attn_weight,
249
+ const int batch_size,
250
+ const int spatial_size,
251
+ const int num_heads,
252
+ const int channels,
253
+ const int num_levels,
254
+ const int num_query,
255
+ const int num_point,
256
+ scalar_t *data_col)
257
+ {
258
+ CUDA_KERNEL_LOOP(index, n)
259
+ {
260
+ int _temp = index;
261
+ const int c_col = _temp % channels;
262
+ _temp /= channels;
263
+ const int sampling_index = _temp;
264
+ const int m_col = _temp % num_heads;
265
+ _temp /= num_heads;
266
+ const int q_col = _temp % num_query;
267
+ _temp /= num_query;
268
+ const int b_col = _temp;
269
+
270
+ scalar_t *data_col_ptr = data_col + index;
271
+ int data_weight_ptr = sampling_index * num_levels * num_point;
272
+ int data_loc_w_ptr = data_weight_ptr << 1;
273
+ const int qid_stride = num_heads * channels;
274
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
275
+ scalar_t col = 0;
276
+
277
+ for (int l_col=0; l_col < num_levels; ++l_col)
278
+ {
279
+ const int level_start_id = data_level_start_index[l_col];
280
+ const int spatial_h_ptr = l_col << 1;
281
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
282
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
283
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
284
+ for (int p_col=0; p_col < num_point; ++p_col)
285
+ {
286
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
287
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
288
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
289
+
290
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
291
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
292
+
293
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
294
+ {
295
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
296
+ }
297
+
298
+ data_weight_ptr += 1;
299
+ data_loc_w_ptr += 2;
300
+ }
301
+ }
302
+ *data_col_ptr = col;
303
+ }
304
+ }
305
+
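The repeated `_temp` modulo/divide steps at the top of each kernel unpack the flat thread index into (batch, query, head, channel); a small Python sketch of that decoding, illustrative only:

```python
# Illustrative only: how the flat CUDA thread index is decoded into
# (b_col, q_col, m_col, c_col) by the repeated modulo/divide steps above.
def decode_index(index, channels, num_heads, num_query):
    tmp = index
    c_col = tmp % channels          # channel
    tmp //= channels
    sampling_index = tmp            # shared by all channels of one (b, q, m)
    m_col = tmp % num_heads         # attention head
    tmp //= num_heads
    q_col = tmp % num_query         # query position
    tmp //= num_query
    b_col = tmp                     # batch element
    return b_col, q_col, m_col, c_col, sampling_index

# Example: with 32 channels, 8 heads and 100 queries, index 12345 maps to
# batch 0, query 48, head 1, channel 25.
print(decode_index(12345, channels=32, num_heads=8, num_query=100))
```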
306
+ template <typename scalar_t, unsigned int blockSize>
307
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
308
+ const scalar_t *grad_col,
309
+ const scalar_t *data_value,
310
+ const int64_t *data_spatial_shapes,
311
+ const int64_t *data_level_start_index,
312
+ const scalar_t *data_sampling_loc,
313
+ const scalar_t *data_attn_weight,
314
+ const int batch_size,
315
+ const int spatial_size,
316
+ const int num_heads,
317
+ const int channels,
318
+ const int num_levels,
319
+ const int num_query,
320
+ const int num_point,
321
+ scalar_t *grad_value,
322
+ scalar_t *grad_sampling_loc,
323
+ scalar_t *grad_attn_weight)
324
+ {
325
+ CUDA_KERNEL_LOOP(index, n)
326
+ {
327
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
328
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
329
+ unsigned int tid = threadIdx.x;
330
+ int _temp = index;
331
+ const int c_col = _temp % channels;
332
+ _temp /= channels;
333
+ const int sampling_index = _temp;
334
+ const int m_col = _temp % num_heads;
335
+ _temp /= num_heads;
336
+ const int q_col = _temp % num_query;
337
+ _temp /= num_query;
338
+ const int b_col = _temp;
339
+
340
+ const scalar_t top_grad = grad_col[index];
341
+
342
+ int data_weight_ptr = sampling_index * num_levels * num_point;
343
+ int data_loc_w_ptr = data_weight_ptr << 1;
344
+ const int grad_sampling_ptr = data_weight_ptr;
345
+ grad_sampling_loc += grad_sampling_ptr << 1;
346
+ grad_attn_weight += grad_sampling_ptr;
347
+ const int grad_weight_stride = 1;
348
+ const int grad_loc_stride = 2;
349
+ const int qid_stride = num_heads * channels;
350
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
351
+
352
+ for (int l_col=0; l_col < num_levels; ++l_col)
353
+ {
354
+ const int level_start_id = data_level_start_index[l_col];
355
+ const int spatial_h_ptr = l_col << 1;
356
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
357
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
358
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
359
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
360
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
361
+
362
+ for (int p_col=0; p_col < num_point; ++p_col)
363
+ {
364
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
365
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
366
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
367
+
368
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
369
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
370
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
371
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
372
+ *(cache_grad_attn_weight+threadIdx.x)=0;
373
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
374
+ {
375
+ ms_deform_attn_col2im_bilinear(
376
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
377
+ top_grad, weight, grad_value_ptr,
378
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
379
+ }
380
+
381
+ __syncthreads();
382
+ if (tid == 0)
383
+ {
384
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
385
+ int sid=2;
386
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
387
+ {
388
+ _grad_w += cache_grad_sampling_loc[sid];
389
+ _grad_h += cache_grad_sampling_loc[sid + 1];
390
+ _grad_a += cache_grad_attn_weight[tid];
391
+ sid += 2;
392
+ }
393
+
394
+
395
+ *grad_sampling_loc = _grad_w;
396
+ *(grad_sampling_loc + 1) = _grad_h;
397
+ *grad_attn_weight = _grad_a;
398
+ }
399
+ __syncthreads();
400
+
401
+ data_weight_ptr += 1;
402
+ data_loc_w_ptr += 2;
403
+ grad_attn_weight += grad_weight_stride;
404
+ grad_sampling_loc += grad_loc_stride;
405
+ }
406
+ }
407
+ }
408
+ }
409
+
410
+
411
+ template <typename scalar_t, unsigned int blockSize>
412
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
413
+ const scalar_t *grad_col,
414
+ const scalar_t *data_value,
415
+ const int64_t *data_spatial_shapes,
416
+ const int64_t *data_level_start_index,
417
+ const scalar_t *data_sampling_loc,
418
+ const scalar_t *data_attn_weight,
419
+ const int batch_size,
420
+ const int spatial_size,
421
+ const int num_heads,
422
+ const int channels,
423
+ const int num_levels,
424
+ const int num_query,
425
+ const int num_point,
426
+ scalar_t *grad_value,
427
+ scalar_t *grad_sampling_loc,
428
+ scalar_t *grad_attn_weight)
429
+ {
430
+ CUDA_KERNEL_LOOP(index, n)
431
+ {
432
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
433
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
434
+ unsigned int tid = threadIdx.x;
435
+ int _temp = index;
436
+ const int c_col = _temp % channels;
437
+ _temp /= channels;
438
+ const int sampling_index = _temp;
439
+ const int m_col = _temp % num_heads;
440
+ _temp /= num_heads;
441
+ const int q_col = _temp % num_query;
442
+ _temp /= num_query;
443
+ const int b_col = _temp;
444
+
445
+ const scalar_t top_grad = grad_col[index];
446
+
447
+ int data_weight_ptr = sampling_index * num_levels * num_point;
448
+ int data_loc_w_ptr = data_weight_ptr << 1;
449
+ const int grad_sampling_ptr = data_weight_ptr;
450
+ grad_sampling_loc += grad_sampling_ptr << 1;
451
+ grad_attn_weight += grad_sampling_ptr;
452
+ const int grad_weight_stride = 1;
453
+ const int grad_loc_stride = 2;
454
+ const int qid_stride = num_heads * channels;
455
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
456
+
457
+ for (int l_col=0; l_col < num_levels; ++l_col)
458
+ {
459
+ const int level_start_id = data_level_start_index[l_col];
460
+ const int spatial_h_ptr = l_col << 1;
461
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
462
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
463
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
464
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
465
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
466
+
467
+ for (int p_col=0; p_col < num_point; ++p_col)
468
+ {
469
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
470
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
471
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
472
+
473
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
474
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
475
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
476
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
477
+ *(cache_grad_attn_weight+threadIdx.x)=0;
478
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
479
+ {
480
+ ms_deform_attn_col2im_bilinear(
481
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
482
+ top_grad, weight, grad_value_ptr,
483
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
484
+ }
485
+
486
+ __syncthreads();
487
+
488
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
489
+ {
490
+ if (tid < s) {
491
+ const unsigned int xid1 = tid << 1;
492
+ const unsigned int xid2 = (tid + s) << 1;
493
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
494
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
495
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
496
+ }
497
+ __syncthreads();
498
+ }
499
+
500
+ if (tid == 0)
501
+ {
502
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
503
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
504
+ *grad_attn_weight = cache_grad_attn_weight[0];
505
+ }
506
+ __syncthreads();
507
+
508
+ data_weight_ptr += 1;
509
+ data_loc_w_ptr += 2;
510
+ grad_attn_weight += grad_weight_stride;
511
+ grad_sampling_loc += grad_loc_stride;
512
+ }
513
+ }
514
+ }
515
+ }
516
+
517
+
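The `_v2` kernel above swaps the serial accumulation performed by thread 0 in `_v1` for a shared-memory tree reduction: the active thread count is halved each step until the block's partial gradients have been folded into slot 0. A stripped-down Python sketch of the pattern (assumes a power-of-two block size, as the templated `blockSize` values used below guarantee):

```python
# Stripped-down sketch of the shared-memory tree reduction used above: each of
# blockSize threads holds one partial gradient, and pairs are summed while the
# active width is halved, so the block total ends up in slot 0.
def tree_reduce(partials):
    vals = list(partials)
    s = len(vals) // 2                  # assumes a power-of-two block size
    while s > 0:
        for tid in range(s):            # in CUDA these additions run in parallel
            vals[tid] += vals[tid + s]
        s //= 2
    return vals[0]

assert tree_reduce([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) == sum(range(1, 9))
```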
518
+ template <typename scalar_t>
519
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
520
+ const scalar_t *grad_col,
521
+ const scalar_t *data_value,
522
+ const int64_t *data_spatial_shapes,
523
+ const int64_t *data_level_start_index,
524
+ const scalar_t *data_sampling_loc,
525
+ const scalar_t *data_attn_weight,
526
+ const int batch_size,
527
+ const int spatial_size,
528
+ const int num_heads,
529
+ const int channels,
530
+ const int num_levels,
531
+ const int num_query,
532
+ const int num_point,
533
+ scalar_t *grad_value,
534
+ scalar_t *grad_sampling_loc,
535
+ scalar_t *grad_attn_weight)
536
+ {
537
+ CUDA_KERNEL_LOOP(index, n)
538
+ {
539
+ extern __shared__ int _s[];
540
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
541
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
542
+ unsigned int tid = threadIdx.x;
543
+ int _temp = index;
544
+ const int c_col = _temp % channels;
545
+ _temp /= channels;
546
+ const int sampling_index = _temp;
547
+ const int m_col = _temp % num_heads;
548
+ _temp /= num_heads;
549
+ const int q_col = _temp % num_query;
550
+ _temp /= num_query;
551
+ const int b_col = _temp;
552
+
553
+ const scalar_t top_grad = grad_col[index];
554
+
555
+ int data_weight_ptr = sampling_index * num_levels * num_point;
556
+ int data_loc_w_ptr = data_weight_ptr << 1;
557
+ const int grad_sampling_ptr = data_weight_ptr;
558
+ grad_sampling_loc += grad_sampling_ptr << 1;
559
+ grad_attn_weight += grad_sampling_ptr;
560
+ const int grad_weight_stride = 1;
561
+ const int grad_loc_stride = 2;
562
+ const int qid_stride = num_heads * channels;
563
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
564
+
565
+ for (int l_col=0; l_col < num_levels; ++l_col)
566
+ {
567
+ const int level_start_id = data_level_start_index[l_col];
568
+ const int spatial_h_ptr = l_col << 1;
569
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
570
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
571
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
572
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
573
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
574
+
575
+ for (int p_col=0; p_col < num_point; ++p_col)
576
+ {
577
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
578
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
579
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
580
+
581
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
582
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
583
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
584
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
585
+ *(cache_grad_attn_weight+threadIdx.x)=0;
586
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
587
+ {
588
+ ms_deform_attn_col2im_bilinear(
589
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
590
+ top_grad, weight, grad_value_ptr,
591
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
592
+ }
593
+
594
+ __syncthreads();
595
+ if (tid == 0)
596
+ {
597
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
598
+ int sid=2;
599
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
600
+ {
601
+ _grad_w += cache_grad_sampling_loc[sid];
602
+ _grad_h += cache_grad_sampling_loc[sid + 1];
603
+ _grad_a += cache_grad_attn_weight[tid];
604
+ sid += 2;
605
+ }
606
+
607
+
608
+ *grad_sampling_loc = _grad_w;
609
+ *(grad_sampling_loc + 1) = _grad_h;
610
+ *grad_attn_weight = _grad_a;
611
+ }
612
+ __syncthreads();
613
+
614
+ data_weight_ptr += 1;
615
+ data_loc_w_ptr += 2;
616
+ grad_attn_weight += grad_weight_stride;
617
+ grad_sampling_loc += grad_loc_stride;
618
+ }
619
+ }
620
+ }
621
+ }
622
+
623
+ template <typename scalar_t>
624
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
625
+ const scalar_t *grad_col,
626
+ const scalar_t *data_value,
627
+ const int64_t *data_spatial_shapes,
628
+ const int64_t *data_level_start_index,
629
+ const scalar_t *data_sampling_loc,
630
+ const scalar_t *data_attn_weight,
631
+ const int batch_size,
632
+ const int spatial_size,
633
+ const int num_heads,
634
+ const int channels,
635
+ const int num_levels,
636
+ const int num_query,
637
+ const int num_point,
638
+ scalar_t *grad_value,
639
+ scalar_t *grad_sampling_loc,
640
+ scalar_t *grad_attn_weight)
641
+ {
642
+ CUDA_KERNEL_LOOP(index, n)
643
+ {
644
+ extern __shared__ int _s[];
645
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
646
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
647
+ unsigned int tid = threadIdx.x;
648
+ int _temp = index;
649
+ const int c_col = _temp % channels;
650
+ _temp /= channels;
651
+ const int sampling_index = _temp;
652
+ const int m_col = _temp % num_heads;
653
+ _temp /= num_heads;
654
+ const int q_col = _temp % num_query;
655
+ _temp /= num_query;
656
+ const int b_col = _temp;
657
+
658
+ const scalar_t top_grad = grad_col[index];
659
+
660
+ int data_weight_ptr = sampling_index * num_levels * num_point;
661
+ int data_loc_w_ptr = data_weight_ptr << 1;
662
+ const int grad_sampling_ptr = data_weight_ptr;
663
+ grad_sampling_loc += grad_sampling_ptr << 1;
664
+ grad_attn_weight += grad_sampling_ptr;
665
+ const int grad_weight_stride = 1;
666
+ const int grad_loc_stride = 2;
667
+ const int qid_stride = num_heads * channels;
668
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
669
+
670
+ for (int l_col=0; l_col < num_levels; ++l_col)
671
+ {
672
+ const int level_start_id = data_level_start_index[l_col];
673
+ const int spatial_h_ptr = l_col << 1;
674
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
675
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
676
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
677
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
678
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
679
+
680
+ for (int p_col=0; p_col < num_point; ++p_col)
681
+ {
682
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
683
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
684
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
685
+
686
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
687
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
688
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
689
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
690
+ *(cache_grad_attn_weight+threadIdx.x)=0;
691
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
692
+ {
693
+ ms_deform_attn_col2im_bilinear(
694
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
695
+ top_grad, weight, grad_value_ptr,
696
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
697
+ }
698
+
699
+ __syncthreads();
700
+
701
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
702
+ {
703
+ if (tid < s) {
704
+ const unsigned int xid1 = tid << 1;
705
+ const unsigned int xid2 = (tid + s) << 1;
706
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
707
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
708
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
709
+ if (tid + (s << 1) < spre)
710
+ {
711
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
712
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
713
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
714
+ }
715
+ }
716
+ __syncthreads();
717
+ }
718
+
719
+ if (tid == 0)
720
+ {
721
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
722
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
723
+ *grad_attn_weight = cache_grad_attn_weight[0];
724
+ }
725
+ __syncthreads();
726
+
727
+ data_weight_ptr += 1;
728
+ data_loc_w_ptr += 2;
729
+ grad_attn_weight += grad_weight_stride;
730
+ grad_sampling_loc += grad_loc_stride;
731
+ }
732
+ }
733
+ }
734
+ }
735
+
736
+ template <typename scalar_t>
737
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
738
+ const scalar_t *grad_col,
739
+ const scalar_t *data_value,
740
+ const int64_t *data_spatial_shapes,
741
+ const int64_t *data_level_start_index,
742
+ const scalar_t *data_sampling_loc,
743
+ const scalar_t *data_attn_weight,
744
+ const int batch_size,
745
+ const int spatial_size,
746
+ const int num_heads,
747
+ const int channels,
748
+ const int num_levels,
749
+ const int num_query,
750
+ const int num_point,
751
+ scalar_t *grad_value,
752
+ scalar_t *grad_sampling_loc,
753
+ scalar_t *grad_attn_weight)
754
+ {
755
+ CUDA_KERNEL_LOOP(index, n)
756
+ {
757
+ extern __shared__ int _s[];
758
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
759
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
760
+ unsigned int tid = threadIdx.x;
761
+ int _temp = index;
762
+ const int c_col = _temp % channels;
763
+ _temp /= channels;
764
+ const int sampling_index = _temp;
765
+ const int m_col = _temp % num_heads;
766
+ _temp /= num_heads;
767
+ const int q_col = _temp % num_query;
768
+ _temp /= num_query;
769
+ const int b_col = _temp;
770
+
771
+ const scalar_t top_grad = grad_col[index];
772
+
773
+ int data_weight_ptr = sampling_index * num_levels * num_point;
774
+ int data_loc_w_ptr = data_weight_ptr << 1;
775
+ const int grad_sampling_ptr = data_weight_ptr;
776
+ grad_sampling_loc += grad_sampling_ptr << 1;
777
+ grad_attn_weight += grad_sampling_ptr;
778
+ const int grad_weight_stride = 1;
779
+ const int grad_loc_stride = 2;
780
+ const int qid_stride = num_heads * channels;
781
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
782
+
783
+ for (int l_col=0; l_col < num_levels; ++l_col)
784
+ {
785
+ const int level_start_id = data_level_start_index[l_col];
786
+ const int spatial_h_ptr = l_col << 1;
787
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
788
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
789
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
790
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
791
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
792
+
793
+ for (int p_col=0; p_col < num_point; ++p_col)
794
+ {
795
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
796
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
797
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
798
+
799
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
800
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
801
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
802
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
803
+ *(cache_grad_attn_weight+threadIdx.x)=0;
804
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
805
+ {
806
+ ms_deform_attn_col2im_bilinear(
807
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
808
+ top_grad, weight, grad_value_ptr,
809
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
810
+ }
811
+
812
+ __syncthreads();
813
+
814
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
815
+ {
816
+ if (tid < s) {
817
+ const unsigned int xid1 = tid << 1;
818
+ const unsigned int xid2 = (tid + s) << 1;
819
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
820
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
821
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
822
+ if (tid + (s << 1) < spre)
823
+ {
824
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
825
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
826
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
827
+ }
828
+ }
829
+ __syncthreads();
830
+ }
831
+
832
+ if (tid == 0)
833
+ {
834
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
835
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
836
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
837
+ }
838
+ __syncthreads();
839
+
840
+ data_weight_ptr += 1;
841
+ data_loc_w_ptr += 2;
842
+ grad_attn_weight += grad_weight_stride;
843
+ grad_sampling_loc += grad_loc_stride;
844
+ }
845
+ }
846
+ }
847
+ }
848
+
849
+
850
+ template <typename scalar_t>
851
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
852
+ const scalar_t *grad_col,
853
+ const scalar_t *data_value,
854
+ const int64_t *data_spatial_shapes,
855
+ const int64_t *data_level_start_index,
856
+ const scalar_t *data_sampling_loc,
857
+ const scalar_t *data_attn_weight,
858
+ const int batch_size,
859
+ const int spatial_size,
860
+ const int num_heads,
861
+ const int channels,
862
+ const int num_levels,
863
+ const int num_query,
864
+ const int num_point,
865
+ scalar_t *grad_value,
866
+ scalar_t *grad_sampling_loc,
867
+ scalar_t *grad_attn_weight)
868
+ {
869
+ CUDA_KERNEL_LOOP(index, n)
870
+ {
871
+ int _temp = index;
872
+ const int c_col = _temp % channels;
873
+ _temp /= channels;
874
+ const int sampling_index = _temp;
875
+ const int m_col = _temp % num_heads;
876
+ _temp /= num_heads;
877
+ const int q_col = _temp % num_query;
878
+ _temp /= num_query;
879
+ const int b_col = _temp;
880
+
881
+ const scalar_t top_grad = grad_col[index];
882
+
883
+ int data_weight_ptr = sampling_index * num_levels * num_point;
884
+ int data_loc_w_ptr = data_weight_ptr << 1;
885
+ const int grad_sampling_ptr = data_weight_ptr;
886
+ grad_sampling_loc += grad_sampling_ptr << 1;
887
+ grad_attn_weight += grad_sampling_ptr;
888
+ const int grad_weight_stride = 1;
889
+ const int grad_loc_stride = 2;
890
+ const int qid_stride = num_heads * channels;
891
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
892
+
893
+ for (int l_col=0; l_col < num_levels; ++l_col)
894
+ {
895
+ const int level_start_id = data_level_start_index[l_col];
896
+ const int spatial_h_ptr = l_col << 1;
897
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
898
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
899
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
900
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
901
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
902
+
903
+ for (int p_col=0; p_col < num_point; ++p_col)
904
+ {
905
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
906
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
907
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
908
+
909
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
910
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
911
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
912
+ {
913
+ ms_deform_attn_col2im_bilinear_gm(
914
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
915
+ top_grad, weight, grad_value_ptr,
916
+ grad_sampling_loc, grad_attn_weight);
917
+ }
918
+ data_weight_ptr += 1;
919
+ data_loc_w_ptr += 2;
920
+ grad_attn_weight += grad_weight_stride;
921
+ grad_sampling_loc += grad_loc_stride;
922
+ }
923
+ }
924
+ }
925
+ }
926
+
927
+
928
+ template <typename scalar_t>
929
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
930
+ const scalar_t* data_value,
931
+ const int64_t* data_spatial_shapes,
932
+ const int64_t* data_level_start_index,
933
+ const scalar_t* data_sampling_loc,
934
+ const scalar_t* data_attn_weight,
935
+ const int batch_size,
936
+ const int spatial_size,
937
+ const int num_heads,
938
+ const int channels,
939
+ const int num_levels,
940
+ const int num_query,
941
+ const int num_point,
942
+ scalar_t* data_col)
943
+ {
944
+ const int num_kernels = batch_size * num_query * num_heads * channels;
945
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
946
+ const int num_threads = CUDA_NUM_THREADS;
947
+ ms_deformable_im2col_gpu_kernel<scalar_t>
948
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
949
+ 0, stream>>>(
950
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
951
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
952
+
953
+ cudaError_t err = cudaGetLastError();
954
+ if (err != cudaSuccess)
955
+ {
956
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
957
+ }
958
+
959
+ }
960
+
961
+ template <typename scalar_t>
962
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
963
+ const scalar_t* grad_col,
964
+ const scalar_t* data_value,
965
+ const int64_t * data_spatial_shapes,
966
+ const int64_t * data_level_start_index,
967
+ const scalar_t * data_sampling_loc,
968
+ const scalar_t * data_attn_weight,
969
+ const int batch_size,
970
+ const int spatial_size,
971
+ const int num_heads,
972
+ const int channels,
973
+ const int num_levels,
974
+ const int num_query,
975
+ const int num_point,
976
+ scalar_t* grad_value,
977
+ scalar_t* grad_sampling_loc,
978
+ scalar_t* grad_attn_weight)
979
+ {
980
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
981
+ const int num_kernels = batch_size * num_query * num_heads * channels;
982
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
983
+ if (channels > 1024)
984
+ {
985
+ if ((channels & 1023) == 0)
986
+ {
987
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
988
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
989
+ num_threads*3*sizeof(scalar_t), stream>>>(
990
+ num_kernels,
991
+ grad_col,
992
+ data_value,
993
+ data_spatial_shapes,
994
+ data_level_start_index,
995
+ data_sampling_loc,
996
+ data_attn_weight,
997
+ batch_size,
998
+ spatial_size,
999
+ num_heads,
1000
+ channels,
1001
+ num_levels,
1002
+ num_query,
1003
+ num_point,
1004
+ grad_value,
1005
+ grad_sampling_loc,
1006
+ grad_attn_weight);
1007
+ }
1008
+ else
1009
+ {
1010
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1011
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1012
+ 0, stream>>>(
1013
+ num_kernels,
1014
+ grad_col,
1015
+ data_value,
1016
+ data_spatial_shapes,
1017
+ data_level_start_index,
1018
+ data_sampling_loc,
1019
+ data_attn_weight,
1020
+ batch_size,
1021
+ spatial_size,
1022
+ num_heads,
1023
+ channels,
1024
+ num_levels,
1025
+ num_query,
1026
+ num_point,
1027
+ grad_value,
1028
+ grad_sampling_loc,
1029
+ grad_attn_weight);
1030
+ }
1031
+ }
1032
+ else{
1033
+ switch(channels)
1034
+ {
1035
+ case 1:
1036
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1037
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1038
+ 0, stream>>>(
1039
+ num_kernels,
1040
+ grad_col,
1041
+ data_value,
1042
+ data_spatial_shapes,
1043
+ data_level_start_index,
1044
+ data_sampling_loc,
1045
+ data_attn_weight,
1046
+ batch_size,
1047
+ spatial_size,
1048
+ num_heads,
1049
+ channels,
1050
+ num_levels,
1051
+ num_query,
1052
+ num_point,
1053
+ grad_value,
1054
+ grad_sampling_loc,
1055
+ grad_attn_weight);
1056
+ break;
1057
+ case 2:
1058
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1059
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1060
+ 0, stream>>>(
1061
+ num_kernels,
1062
+ grad_col,
1063
+ data_value,
1064
+ data_spatial_shapes,
1065
+ data_level_start_index,
1066
+ data_sampling_loc,
1067
+ data_attn_weight,
1068
+ batch_size,
1069
+ spatial_size,
1070
+ num_heads,
1071
+ channels,
1072
+ num_levels,
1073
+ num_query,
1074
+ num_point,
1075
+ grad_value,
1076
+ grad_sampling_loc,
1077
+ grad_attn_weight);
1078
+ break;
1079
+ case 4:
1080
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1081
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1082
+ 0, stream>>>(
1083
+ num_kernels,
1084
+ grad_col,
1085
+ data_value,
1086
+ data_spatial_shapes,
1087
+ data_level_start_index,
1088
+ data_sampling_loc,
1089
+ data_attn_weight,
1090
+ batch_size,
1091
+ spatial_size,
1092
+ num_heads,
1093
+ channels,
1094
+ num_levels,
1095
+ num_query,
1096
+ num_point,
1097
+ grad_value,
1098
+ grad_sampling_loc,
1099
+ grad_attn_weight);
1100
+ break;
1101
+ case 8:
1102
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1103
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1104
+ 0, stream>>>(
1105
+ num_kernels,
1106
+ grad_col,
1107
+ data_value,
1108
+ data_spatial_shapes,
1109
+ data_level_start_index,
1110
+ data_sampling_loc,
1111
+ data_attn_weight,
1112
+ batch_size,
1113
+ spatial_size,
1114
+ num_heads,
1115
+ channels,
1116
+ num_levels,
1117
+ num_query,
1118
+ num_point,
1119
+ grad_value,
1120
+ grad_sampling_loc,
1121
+ grad_attn_weight);
1122
+ break;
1123
+ case 16:
1124
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1125
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1126
+ 0, stream>>>(
1127
+ num_kernels,
1128
+ grad_col,
1129
+ data_value,
1130
+ data_spatial_shapes,
1131
+ data_level_start_index,
1132
+ data_sampling_loc,
1133
+ data_attn_weight,
1134
+ batch_size,
1135
+ spatial_size,
1136
+ num_heads,
1137
+ channels,
1138
+ num_levels,
1139
+ num_query,
1140
+ num_point,
1141
+ grad_value,
1142
+ grad_sampling_loc,
1143
+ grad_attn_weight);
1144
+ break;
1145
+ case 32:
1146
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1147
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1148
+ 0, stream>>>(
1149
+ num_kernels,
1150
+ grad_col,
1151
+ data_value,
1152
+ data_spatial_shapes,
1153
+ data_level_start_index,
1154
+ data_sampling_loc,
1155
+ data_attn_weight,
1156
+ batch_size,
1157
+ spatial_size,
1158
+ num_heads,
1159
+ channels,
1160
+ num_levels,
1161
+ num_query,
1162
+ num_point,
1163
+ grad_value,
1164
+ grad_sampling_loc,
1165
+ grad_attn_weight);
1166
+ break;
1167
+ case 64:
1168
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1169
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1170
+ 0, stream>>>(
1171
+ num_kernels,
1172
+ grad_col,
1173
+ data_value,
1174
+ data_spatial_shapes,
1175
+ data_level_start_index,
1176
+ data_sampling_loc,
1177
+ data_attn_weight,
1178
+ batch_size,
1179
+ spatial_size,
1180
+ num_heads,
1181
+ channels,
1182
+ num_levels,
1183
+ num_query,
1184
+ num_point,
1185
+ grad_value,
1186
+ grad_sampling_loc,
1187
+ grad_attn_weight);
1188
+ break;
1189
+ case 128:
1190
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1191
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1192
+ 0, stream>>>(
1193
+ num_kernels,
1194
+ grad_col,
1195
+ data_value,
1196
+ data_spatial_shapes,
1197
+ data_level_start_index,
1198
+ data_sampling_loc,
1199
+ data_attn_weight,
1200
+ batch_size,
1201
+ spatial_size,
1202
+ num_heads,
1203
+ channels,
1204
+ num_levels,
1205
+ num_query,
1206
+ num_point,
1207
+ grad_value,
1208
+ grad_sampling_loc,
1209
+ grad_attn_weight);
1210
+ break;
1211
+ case 256:
1212
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1213
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1214
+ 0, stream>>>(
1215
+ num_kernels,
1216
+ grad_col,
1217
+ data_value,
1218
+ data_spatial_shapes,
1219
+ data_level_start_index,
1220
+ data_sampling_loc,
1221
+ data_attn_weight,
1222
+ batch_size,
1223
+ spatial_size,
1224
+ num_heads,
1225
+ channels,
1226
+ num_levels,
1227
+ num_query,
1228
+ num_point,
1229
+ grad_value,
1230
+ grad_sampling_loc,
1231
+ grad_attn_weight);
1232
+ break;
1233
+ case 512:
1234
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1235
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1236
+ 0, stream>>>(
1237
+ num_kernels,
1238
+ grad_col,
1239
+ data_value,
1240
+ data_spatial_shapes,
1241
+ data_level_start_index,
1242
+ data_sampling_loc,
1243
+ data_attn_weight,
1244
+ batch_size,
1245
+ spatial_size,
1246
+ num_heads,
1247
+ channels,
1248
+ num_levels,
1249
+ num_query,
1250
+ num_point,
1251
+ grad_value,
1252
+ grad_sampling_loc,
1253
+ grad_attn_weight);
1254
+ break;
1255
+ case 1024:
1256
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1257
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1258
+ 0, stream>>>(
1259
+ num_kernels,
1260
+ grad_col,
1261
+ data_value,
1262
+ data_spatial_shapes,
1263
+ data_level_start_index,
1264
+ data_sampling_loc,
1265
+ data_attn_weight,
1266
+ batch_size,
1267
+ spatial_size,
1268
+ num_heads,
1269
+ channels,
1270
+ num_levels,
1271
+ num_query,
1272
+ num_point,
1273
+ grad_value,
1274
+ grad_sampling_loc,
1275
+ grad_attn_weight);
1276
+ break;
1277
+ default:
1278
+ if (channels < 64)
1279
+ {
1280
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1281
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1282
+ num_threads*3*sizeof(scalar_t), stream>>>(
1283
+ num_kernels,
1284
+ grad_col,
1285
+ data_value,
1286
+ data_spatial_shapes,
1287
+ data_level_start_index,
1288
+ data_sampling_loc,
1289
+ data_attn_weight,
1290
+ batch_size,
1291
+ spatial_size,
1292
+ num_heads,
1293
+ channels,
1294
+ num_levels,
1295
+ num_query,
1296
+ num_point,
1297
+ grad_value,
1298
+ grad_sampling_loc,
1299
+ grad_attn_weight);
1300
+ }
1301
+ else
1302
+ {
1303
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1304
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1305
+ num_threads*3*sizeof(scalar_t), stream>>>(
1306
+ num_kernels,
1307
+ grad_col,
1308
+ data_value,
1309
+ data_spatial_shapes,
1310
+ data_level_start_index,
1311
+ data_sampling_loc,
1312
+ data_attn_weight,
1313
+ batch_size,
1314
+ spatial_size,
1315
+ num_heads,
1316
+ channels,
1317
+ num_levels,
1318
+ num_query,
1319
+ num_point,
1320
+ grad_value,
1321
+ grad_sampling_loc,
1322
+ grad_attn_weight);
1323
+ }
1324
+ }
1325
+ }
1326
+ cudaError_t err = cudaGetLastError();
1327
+ if (err != cudaSuccess)
1328
+ {
1329
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1330
+ }
1331
+
1332
+ }
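Taken together, these kernels implement the sampling-and-weighting core of multi-scale deformable attention: bilinearly sample each value map at the predicted locations, then take an attention-weighted sum over levels and points. A pure-PyTorch sketch of the same computation is shown below; it follows the spirit of the `ms_deform_attn_core_pytorch` fallback referenced by `test.py` further down rather than reproducing it verbatim.

```python
# Pure-PyTorch sketch of the computation these kernels accelerate: bilinearly
# sample the multi-scale value maps at the predicted locations, then take an
# attention-weighted sum over levels and points. Shapes mirror test.py below;
# align_corners=False matches the `loc * size - 0.5` convention in the kernels.
import torch
import torch.nn.functional as F

def ms_deform_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights):
    # value:              (N, S, M, D)         flattened multi-scale feature maps
    # spatial_shapes:     (L, 2)               (H_l, W_l) per level, sum(H*W) == S
    # sampling_locations: (N, Lq, M, L, P, 2)  normalized (x, y) in [0, 1]
    # attention_weights:  (N, Lq, M, L, P)
    N, S, M, D = value.shape
    _, Lq, _, L, P, _ = sampling_locations.shape
    value_per_level = value.split([int(H * W) for H, W in spatial_shapes], dim=1)
    sampled = []
    for lvl, (H, W) in enumerate(spatial_shapes):
        # (N, H*W, M, D) -> (N*M, D, H, W) for grid_sample
        v = value_per_level[lvl].permute(0, 2, 3, 1).reshape(N * M, D, int(H), int(W))
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2), rescaled from [0, 1] to [-1, 1]
        grid = (2 * sampling_locations[:, :, :, lvl] - 1).permute(0, 2, 1, 3, 4)
        grid = grid.reshape(N * M, Lq, P, 2)
        sampled.append(F.grid_sample(v, grid, mode="bilinear",
                                     padding_mode="zeros", align_corners=False))
    # (N*M, D, Lq, L*P), weighted and summed over levels and points
    sampled = torch.stack(sampled, dim=-2).flatten(-2)
    weights = attention_weights.permute(0, 2, 1, 3, 4).reshape(N * M, 1, Lq, L * P)
    out = (sampled * weights).sum(-1)                  # (N*M, D, Lq)
    return out.view(N, M * D, Lq).transpose(1, 2)      # (N, Lq, M*D)
```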
mask2former/modeling/pixel_decoder/ops/src/ms_deform_attn.h ADDED
@@ -0,0 +1,67 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+
18
+ #include "cpu/ms_deform_attn_cpu.h"
19
+
20
+ #ifdef WITH_CUDA
21
+ #include "cuda/ms_deform_attn_cuda.h"
22
+ #endif
23
+
24
+
25
+ at::Tensor
26
+ ms_deform_attn_forward(
27
+ const at::Tensor &value,
28
+ const at::Tensor &spatial_shapes,
29
+ const at::Tensor &level_start_index,
30
+ const at::Tensor &sampling_loc,
31
+ const at::Tensor &attn_weight,
32
+ const int im2col_step)
33
+ {
34
+ if (value.type().is_cuda())
35
+ {
36
+ #ifdef WITH_CUDA
37
+ return ms_deform_attn_cuda_forward(
38
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
39
+ #else
40
+ AT_ERROR("Not compiled with GPU support");
41
+ #endif
42
+ }
43
+ AT_ERROR("Not implemented on the CPU");
44
+ }
45
+
46
+ std::vector<at::Tensor>
47
+ ms_deform_attn_backward(
48
+ const at::Tensor &value,
49
+ const at::Tensor &spatial_shapes,
50
+ const at::Tensor &level_start_index,
51
+ const at::Tensor &sampling_loc,
52
+ const at::Tensor &attn_weight,
53
+ const at::Tensor &grad_output,
54
+ const int im2col_step)
55
+ {
56
+ if (value.type().is_cuda())
57
+ {
58
+ #ifdef WITH_CUDA
59
+ return ms_deform_attn_cuda_backward(
60
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
61
+ #else
62
+ AT_ERROR("Not compiled with GPU support");
63
+ #endif
64
+ }
65
+ AT_ERROR("Not implemented on the CPU");
66
+ }
67
+
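On the Python side these two entry points are wrapped in a `torch.autograd.Function` (the repo ships this in `ops/functions/ms_deform_attn_func.py`). A condensed sketch of that wrapper follows; the compiled extension name `MultiScaleDeformableAttention` is an assumption of this sketch rather than something stated in this header.

```python
# Condensed sketch of how the two entry points above are typically exposed to
# autograd; the actual wrapper lives in ops/functions/ms_deform_attn_func.py and
# the extension name "MultiScaleDeformableAttention" is an assumption here.
import MultiScaleDeformableAttention as MSDA  # compiled from ops/src
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class MSDeformAttnFunction(Function):
    @staticmethod
    def forward(ctx, value, spatial_shapes, level_start_index,
                sampling_loc, attn_weight, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, spatial_shapes, level_start_index,
            sampling_loc, attn_weight, ctx.im2col_step)
        ctx.save_for_backward(value, spatial_shapes, level_start_index,
                              sampling_loc, attn_weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        value, spatial_shapes, level_start_index, sampling_loc, attn_weight = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
            value, spatial_shapes, level_start_index,
            sampling_loc, attn_weight, grad_output, ctx.im2col_step)
        # Only the tensor inputs with gradients get one back; the shape/index
        # tensors and the integer im2col_step return None.
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
```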
mask2former/modeling/pixel_decoder/ops/src/vision.cpp ADDED
@@ -0,0 +1,21 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include "ms_deform_attn.h"
17
+
18
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
19
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
20
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
21
+ }
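`vision.cpp` only registers the two operators with pybind11; the module itself is normally built via `ops/setup.py` (see `make.sh` earlier in this commit). For quick experiments the same sources can also be JIT-compiled, roughly as follows; the paths and the `WITH_CUDA` define are assumptions based on the layout above, not a documented build path.

```python
# Rough JIT-compilation sketch; the supported path is
# `cd mask2former/modeling/pixel_decoder/ops && sh make.sh`, which runs setup.py.
# Paths and defines below are assumptions based on the source layout above.
from torch.utils.cpp_extension import load

ops_src = "mask2former/modeling/pixel_decoder/ops/src"
msda = load(
    name="MultiScaleDeformableAttention",
    sources=[
        f"{ops_src}/vision.cpp",
        f"{ops_src}/cpu/ms_deform_attn_cpu.cpp",
        f"{ops_src}/cuda/ms_deform_attn_cuda.cu",
    ],
    extra_cflags=["-DWITH_CUDA"],
    extra_cuda_cflags=["-DWITH_CUDA"],
    extra_include_paths=[ops_src],
    verbose=True,
)
# msda.ms_deform_attn_forward / msda.ms_deform_attn_backward are then callable.
```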
mask2former/modeling/pixel_decoder/ops/test.py ADDED
@@ -0,0 +1,92 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import time
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.autograd import gradcheck
20
+
21
+ from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
22
+
23
+
24
+ N, M, D = 1, 2, 2
25
+ Lq, L, P = 2, 2, 2
26
+ shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
27
+ level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
28
+ S = sum([(H*W).item() for H, W in shapes])
29
+
30
+
31
+ torch.manual_seed(3)
32
+
33
+
34
+ @torch.no_grad()
35
+ def check_forward_equal_with_pytorch_double():
36
+ value = torch.rand(N, S, M, D).cuda() * 0.01
37
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
38
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
39
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
40
+ im2col_step = 2
41
+ output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
42
+ output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
43
+ fwdok = torch.allclose(output_cuda, output_pytorch)
44
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
45
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
46
+
47
+ print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
48
+
49
+
50
+ @torch.no_grad()
51
+ def check_forward_equal_with_pytorch_float():
52
+ value = torch.rand(N, S, M, D).cuda() * 0.01
53
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
54
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
55
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
56
+ im2col_step = 2
57
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
58
+ output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
59
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
60
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
61
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
62
+
63
+ print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
64
+
65
+
66
+ def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
67
+
68
+ value = torch.rand(N, S, M, channels).cuda() * 0.01
69
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
70
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
71
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
72
+ im2col_step = 2
73
+ func = MSDeformAttnFunction.apply
74
+
75
+ value.requires_grad = grad_value
76
+ sampling_locations.requires_grad = grad_sampling_loc
77
+ attention_weights.requires_grad = grad_attn_weight
78
+
79
+ gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
80
+
81
+ print(f'* {gradok} check_gradient_numerical(D={channels})')
82
+
83
+
84
+ if __name__ == '__main__':
85
+ check_forward_equal_with_pytorch_double()
86
+ check_forward_equal_with_pytorch_float()
87
+
88
+ for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
89
+ check_gradient_numerical(channels, True, True, True)
90
+
91
+
92
+
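A detail worth noting in the tests above: the random attention weights are normalized so that, for every query and head, they sum to one over all levels and points, which is why the normalization chains two `keepdim` sums. A standalone, CPU-only check of that invariant:

```python
# Standalone check of the weight normalization used by the tests: for every
# (batch, query, head), the attention weights over all levels and points sum to 1.
import torch

N, Lq, M, L, P = 1, 2, 2, 2, 2
w = torch.rand(N, Lq, M, L, P) + 1e-5
w = w / w.sum(-1, keepdim=True).sum(-2, keepdim=True)
assert torch.allclose(w.flatten(-2).sum(-1), torch.ones(N, Lq, M))
```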
mask2former/modeling/transformer_decoder/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .maskformer_transformer_decoder import StandardTransformerDecoder
3
+ from .mask2former_transformer_decoder import MultiScaleMaskedTransformerDecoder
4
+ from .opd_transformer_decoder import OPDMultiScaleMaskedTransformerDecoder
mask2former/modeling/transformer_decoder/mask2former_transformer_decoder.py ADDED
@@ -0,0 +1,461 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ import logging
4
+ import fvcore.nn.weight_init as weight_init
5
+ from typing import Optional
6
+ import torch
7
+ from torch import nn, Tensor
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.layers import Conv2d
12
+
13
+ from .position_encoding import PositionEmbeddingSine
14
+ from .maskformer_transformer_decoder import TRANSFORMER_DECODER_REGISTRY
15
+
16
+
17
+ class SelfAttentionLayer(nn.Module):
18
+
19
+ def __init__(self, d_model, nhead, dropout=0.0,
20
+ activation="relu", normalize_before=False):
21
+ super().__init__()
22
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
23
+
24
+ self.norm = nn.LayerNorm(d_model)
25
+ self.dropout = nn.Dropout(dropout)
26
+
27
+ self.activation = _get_activation_fn(activation)
28
+ self.normalize_before = normalize_before
29
+
30
+ self._reset_parameters()
31
+
32
+ def _reset_parameters(self):
33
+ for p in self.parameters():
34
+ if p.dim() > 1:
35
+ nn.init.xavier_uniform_(p)
36
+
37
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
38
+ return tensor if pos is None else tensor + pos
39
+
40
+ def forward_post(self, tgt,
41
+ tgt_mask: Optional[Tensor] = None,
42
+ tgt_key_padding_mask: Optional[Tensor] = None,
43
+ query_pos: Optional[Tensor] = None):
44
+ q = k = self.with_pos_embed(tgt, query_pos)
45
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
46
+ key_padding_mask=tgt_key_padding_mask)[0]
47
+ tgt = tgt + self.dropout(tgt2)
48
+ tgt = self.norm(tgt)
49
+
50
+ return tgt
51
+
52
+ def forward_pre(self, tgt,
53
+ tgt_mask: Optional[Tensor] = None,
54
+ tgt_key_padding_mask: Optional[Tensor] = None,
55
+ query_pos: Optional[Tensor] = None):
56
+ tgt2 = self.norm(tgt)
57
+ q = k = self.with_pos_embed(tgt2, query_pos)
58
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
59
+ key_padding_mask=tgt_key_padding_mask)[0]
60
+ tgt = tgt + self.dropout(tgt2)
61
+
62
+ return tgt
63
+
64
+ def forward(self, tgt,
65
+ tgt_mask: Optional[Tensor] = None,
66
+ tgt_key_padding_mask: Optional[Tensor] = None,
67
+ query_pos: Optional[Tensor] = None):
68
+ if self.normalize_before:
69
+ return self.forward_pre(tgt, tgt_mask,
70
+ tgt_key_padding_mask, query_pos)
71
+ return self.forward_post(tgt, tgt_mask,
72
+ tgt_key_padding_mask, query_pos)
73
+
74
+
75
+ class CrossAttentionLayer(nn.Module):
76
+
77
+ def __init__(self, d_model, nhead, dropout=0.0,
78
+ activation="relu", normalize_before=False):
79
+ super().__init__()
80
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
81
+
82
+ self.norm = nn.LayerNorm(d_model)
83
+ self.dropout = nn.Dropout(dropout)
84
+
85
+ self.activation = _get_activation_fn(activation)
86
+ self.normalize_before = normalize_before
87
+
88
+ self._reset_parameters()
89
+
90
+ def _reset_parameters(self):
91
+ for p in self.parameters():
92
+ if p.dim() > 1:
93
+ nn.init.xavier_uniform_(p)
94
+
95
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
96
+ return tensor if pos is None else tensor + pos
97
+
98
+ def forward_post(self, tgt, memory,
99
+ memory_mask: Optional[Tensor] = None,
100
+ memory_key_padding_mask: Optional[Tensor] = None,
101
+ pos: Optional[Tensor] = None,
102
+ query_pos: Optional[Tensor] = None):
103
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
104
+ key=self.with_pos_embed(memory, pos),
105
+ value=memory, attn_mask=memory_mask,
106
+ key_padding_mask=memory_key_padding_mask)[0]
107
+ tgt = tgt + self.dropout(tgt2)
108
+ tgt = self.norm(tgt)
109
+
110
+ return tgt
111
+
112
+ def forward_pre(self, tgt, memory,
113
+ memory_mask: Optional[Tensor] = None,
114
+ memory_key_padding_mask: Optional[Tensor] = None,
115
+ pos: Optional[Tensor] = None,
116
+ query_pos: Optional[Tensor] = None):
117
+ tgt2 = self.norm(tgt)
118
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
119
+ key=self.with_pos_embed(memory, pos),
120
+ value=memory, attn_mask=memory_mask,
121
+ key_padding_mask=memory_key_padding_mask)[0]
122
+ tgt = tgt + self.dropout(tgt2)
123
+
124
+ return tgt
125
+
126
+ def forward(self, tgt, memory,
127
+ memory_mask: Optional[Tensor] = None,
128
+ memory_key_padding_mask: Optional[Tensor] = None,
129
+ pos: Optional[Tensor] = None,
130
+ query_pos: Optional[Tensor] = None):
131
+ if self.normalize_before:
132
+ return self.forward_pre(tgt, memory, memory_mask,
133
+ memory_key_padding_mask, pos, query_pos)
134
+ return self.forward_post(tgt, memory, memory_mask,
135
+ memory_key_padding_mask, pos, query_pos)
136
+
137
+
138
+ class FFNLayer(nn.Module):
139
+
140
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
141
+ activation="relu", normalize_before=False):
142
+ super().__init__()
143
+ # Implementation of Feedforward model
144
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
145
+ self.dropout = nn.Dropout(dropout)
146
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
147
+
148
+ self.norm = nn.LayerNorm(d_model)
149
+
150
+ self.activation = _get_activation_fn(activation)
151
+ self.normalize_before = normalize_before
152
+
153
+ self._reset_parameters()
154
+
155
+ def _reset_parameters(self):
156
+ for p in self.parameters():
157
+ if p.dim() > 1:
158
+ nn.init.xavier_uniform_(p)
159
+
160
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
161
+ return tensor if pos is None else tensor + pos
162
+
163
+ def forward_post(self, tgt):
164
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
165
+ tgt = tgt + self.dropout(tgt2)
166
+ tgt = self.norm(tgt)
167
+ return tgt
168
+
169
+ def forward_pre(self, tgt):
170
+ tgt2 = self.norm(tgt)
171
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
172
+ tgt = tgt + self.dropout(tgt2)
173
+ return tgt
174
+
175
+ def forward(self, tgt):
176
+ if self.normalize_before:
177
+ return self.forward_pre(tgt)
178
+ return self.forward_post(tgt)
179
+
180
+
181
+ def _get_activation_fn(activation):
182
+ """Return an activation function given a string"""
183
+ if activation == "relu":
184
+ return F.relu
185
+ if activation == "gelu":
186
+ return F.gelu
187
+ if activation == "glu":
188
+ return F.glu
189
+ raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
190
+
191
+
192
+ class MLP(nn.Module):
193
+ """ Very simple multi-layer perceptron (also called FFN)"""
194
+
195
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
196
+ super().__init__()
197
+ self.num_layers = num_layers
198
+ h = [hidden_dim] * (num_layers - 1)
199
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
200
+
201
+ def forward(self, x):
202
+ for i, layer in enumerate(self.layers):
203
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
204
+ return x
205
+
206
+
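The `MLP` above is later reused as the mask-embedding head (`mask_embed`, three layers mapping `hidden_dim` to `mask_dim`). A quick, illustrative shape check follows; the 256-dimensional sizes are placeholder values, not read from the configs in this commit.

```python
# Quick shape check for the MLP defined above (uses that class directly);
# the 256-dimensional sizes are illustrative, not taken from this commit's configs.
import torch

mlp = MLP(input_dim=256, hidden_dim=256, output_dim=256, num_layers=3)
queries = torch.rand(100, 1, 256)        # (num_queries, batch, hidden_dim)
assert mlp(queries).shape == (100, 1, 256)
```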
207
+ @TRANSFORMER_DECODER_REGISTRY.register()
208
+ class MultiScaleMaskedTransformerDecoder(nn.Module):
209
+
210
+ _version = 2
211
+
212
+ def _load_from_state_dict(
213
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
214
+ ):
215
+ version = local_metadata.get("version", None)
216
+ if version is None or version < 2:
217
+ # Do not warn if train from scratch
218
+ scratch = True
219
+ logger = logging.getLogger(__name__)
220
+ for k in list(state_dict.keys()):
221
+ newk = k
222
+ if "static_query" in k:
223
+ newk = k.replace("static_query", "query_feat")
224
+ if newk != k:
225
+ state_dict[newk] = state_dict[k]
226
+ del state_dict[k]
227
+ scratch = False
228
+
229
+ if not scratch:
230
+ logger.warning(
231
+ f"Weight format of {self.__class__.__name__} has changed! "
232
+ "Please upgrade your models. Applying automatic conversion now ..."
233
+ )
234
+
235
+ @configurable
236
+ def __init__(
237
+ self,
238
+ in_channels,
239
+ mask_classification=True,
240
+ *,
241
+ num_classes: int,
242
+ hidden_dim: int,
243
+ num_queries: int,
244
+ nheads: int,
245
+ dim_feedforward: int,
246
+ dec_layers: int,
247
+ pre_norm: bool,
248
+ mask_dim: int,
249
+ enforce_input_project: bool,
250
+ ):
251
+ """
252
+ NOTE: this interface is experimental.
253
+ Args:
254
+ in_channels: channels of the input features
255
+ mask_classification: whether to add mask classifier or not
256
+ num_classes: number of classes
257
+ hidden_dim: Transformer feature dimension
258
+ num_queries: number of queries
259
+ nheads: number of heads
260
+ dim_feedforward: feature dimension in feedforward network
261
+ enc_layers: number of Transformer encoder layers
262
+ dec_layers: number of Transformer decoder layers
263
+ pre_norm: whether to use pre-LayerNorm or not
264
+ mask_dim: mask feature dimension
265
+ enforce_input_project: add input project 1x1 conv even if input
266
+ channels and hidden dim is identical
267
+ """
268
+ super().__init__()
269
+
270
+ assert mask_classification, "Only support mask classification model"
271
+ self.mask_classification = mask_classification
272
+
273
+ # positional encoding
274
+ N_steps = hidden_dim // 2
275
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
276
+
277
+ # define Transformer decoder here
278
+ self.num_heads = nheads
279
+ self.num_layers = dec_layers
280
+ self.transformer_self_attention_layers = nn.ModuleList()
281
+ self.transformer_cross_attention_layers = nn.ModuleList()
282
+ self.transformer_ffn_layers = nn.ModuleList()
283
+
284
+ for _ in range(self.num_layers):
285
+ self.transformer_self_attention_layers.append(
286
+ SelfAttentionLayer(
287
+ d_model=hidden_dim,
288
+ nhead=nheads,
289
+ dropout=0.0,
290
+ normalize_before=pre_norm,
291
+ )
292
+ )
293
+
294
+ self.transformer_cross_attention_layers.append(
295
+ CrossAttentionLayer(
296
+ d_model=hidden_dim,
297
+ nhead=nheads,
298
+ dropout=0.0,
299
+ normalize_before=pre_norm,
300
+ )
301
+ )
302
+
303
+ self.transformer_ffn_layers.append(
304
+ FFNLayer(
305
+ d_model=hidden_dim,
306
+ dim_feedforward=dim_feedforward,
307
+ dropout=0.0,
308
+ normalize_before=pre_norm,
309
+ )
310
+ )
311
+
312
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
313
+
314
+ self.num_queries = num_queries
315
+ # learnable query features
316
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
317
+ # learnable query p.e.
318
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
319
+
320
+ # level embedding (we always use 3 scales)
321
+ self.num_feature_levels = 3
322
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
323
+ self.input_proj = nn.ModuleList()
324
+ for _ in range(self.num_feature_levels):
325
+ if in_channels != hidden_dim or enforce_input_project:
326
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
327
+ weight_init.c2_xavier_fill(self.input_proj[-1])
328
+ else:
329
+ self.input_proj.append(nn.Sequential())
330
+
331
+ # output FFNs
332
+ if self.mask_classification:
333
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
334
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
335
+
336
+ @classmethod
337
+ def from_config(cls, cfg, in_channels, mask_classification):
338
+ ret = {}
339
+ ret["in_channels"] = in_channels
340
+ ret["mask_classification"] = mask_classification
341
+
342
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
343
+ ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
344
+ ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
345
+ # Transformer parameters:
346
+ ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
347
+ ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
348
+
349
+ # NOTE: because we add learnable query features which require supervision,
350
+ # we subtract 1 from the number of decoder layers to be consistent with our loss
351
+ # implementation: that is, number of auxiliary losses is always
352
+ # equal to number of decoder layers. With learnable query features, the number of
353
+ # auxiliary losses equals number of decoders plus 1.
354
+ assert cfg.MODEL.MASK_FORMER.DEC_LAYERS >= 1
355
+ ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1
356
+ ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
357
+ ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
358
+
359
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
360
+
361
+ return ret
362
+
363
+ def forward(self, x, mask_features, mask = None):
364
+ # x is a list of multi-scale feature
365
+ assert len(x) == self.num_feature_levels
366
+ src = []
367
+ pos = []
368
+ size_list = []
369
+
370
+ # disable mask, it does not affect performance
371
+ del mask
372
+
373
+ for i in range(self.num_feature_levels):
374
+ size_list.append(x[i].shape[-2:])
375
+ pos.append(self.pe_layer(x[i], None).flatten(2))
376
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
377
+
378
+ # flatten NxCxHxW to HWxNxC
379
+ pos[-1] = pos[-1].permute(2, 0, 1)
380
+ src[-1] = src[-1].permute(2, 0, 1)
381
+
382
+ _, bs, _ = src[0].shape
383
+
384
+ # QxNxC
385
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
386
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
387
+
388
+ predictions_class = []
389
+ predictions_mask = []
390
+
391
+ # prediction heads on learnable query features
392
+ outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
393
+ predictions_class.append(outputs_class)
394
+ predictions_mask.append(outputs_mask)
395
+
396
+ for i in range(self.num_layers):
397
+ level_index = i % self.num_feature_levels
398
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
399
+ # attention: cross-attention first
400
+ output = self.transformer_cross_attention_layers[i](
401
+ output, src[level_index],
402
+ memory_mask=attn_mask,
403
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
404
+ pos=pos[level_index], query_pos=query_embed
405
+ )
406
+
407
+ output = self.transformer_self_attention_layers[i](
408
+ output, tgt_mask=None,
409
+ tgt_key_padding_mask=None,
410
+ query_pos=query_embed
411
+ )
412
+
413
+ # FFN
414
+ output = self.transformer_ffn_layers[i](
415
+ output
416
+ )
417
+
418
+ outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels])
419
+ predictions_class.append(outputs_class)
420
+ predictions_mask.append(outputs_mask)
421
+
422
+ assert len(predictions_class) == self.num_layers + 1
423
+
424
+ out = {
425
+ 'pred_logits': predictions_class[-1],
426
+ 'pred_masks': predictions_mask[-1],
427
+ 'aux_outputs': self._set_aux_loss(
428
+ predictions_class if self.mask_classification else None, predictions_mask
429
+ )
430
+ }
431
+ return out
432
+
433
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):
434
+ decoder_output = self.decoder_norm(output)
435
+ decoder_output = decoder_output.transpose(0, 1)
436
+ outputs_class = self.class_embed(decoder_output)
437
+ mask_embed = self.mask_embed(decoder_output)
438
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
439
+
440
+ # NOTE: prediction is of higher-resolution
441
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
442
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
443
+ # must use bool type
444
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
445
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
446
+ attn_mask = attn_mask.detach()
447
+
448
+ return outputs_class, outputs_mask, attn_mask
449
+
450
+ @torch.jit.unused
451
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks):
452
+ # this is a workaround to make torchscript happy, as torchscript
453
+ # doesn't support dictionary with non-homogeneous values, such
454
+ # as a dict having both a Tensor and a list.
455
+ if self.mask_classification:
456
+ return [
457
+ {"pred_logits": a, "pred_masks": b}
458
+ for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
459
+ ]
460
+ else:
461
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
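For reference, a minimal sketch (not part of this commit) of the attention-mask construction performed by forward_prediction_heads above; the 0.5 threshold and the reshaping follow the code, while the tensor sizes are arbitrary toy values.

import torch
import torch.nn.functional as F

B, Q, H, W = 2, 100, 64, 64            # batch, queries, mask resolution (toy values)
num_heads = 8
target_size = (16, 16)                 # resolution of the feature level being attended

outputs_mask = torch.randn(B, Q, H, W)  # stand-in for einsum("bqc,bchw->bqhw", ...)

# Downsample the predicted masks to the current feature-level resolution.
attn_mask = F.interpolate(outputs_mask, size=target_size, mode="bilinear", align_corners=False)

# [B, Q, h, w] -> [B, Q, h*w] -> [B, heads, Q, h*w] -> [B*heads, Q, h*w];
# True marks positions a query is NOT allowed to attend to (sigmoid < 0.5).
attn_mask = (
    attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5
).bool()

assert attn_mask.shape == (B * num_heads, Q, target_size[0] * target_size[1])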
mask2former/modeling/transformer_decoder/maskformer_transformer_decoder.py ADDED
@@ -0,0 +1,188 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ import fvcore.nn.weight_init as weight_init
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from detectron2.config import configurable
9
+ from detectron2.layers import Conv2d
10
+ from detectron2.utils.registry import Registry
11
+
12
+ from .position_encoding import PositionEmbeddingSine
13
+ from .transformer import Transformer
14
+
15
+
16
+ TRANSFORMER_DECODER_REGISTRY = Registry("TRANSFORMER_MODULE")
17
+ TRANSFORMER_DECODER_REGISTRY.__doc__ = """
18
+ Registry for transformer module in MaskFormer.
19
+ """
20
+
21
+
22
+ def build_transformer_decoder(cfg, in_channels, mask_classification=True):
23
+ """
24
+ Build a transformer decoder from `cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME`.
25
+ """
26
+ name = cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME
27
+ return TRANSFORMER_DECODER_REGISTRY.get(name)(cfg, in_channels, mask_classification)
28
+
29
+
30
+ @TRANSFORMER_DECODER_REGISTRY.register()
31
+ class StandardTransformerDecoder(nn.Module):
32
+ @configurable
33
+ def __init__(
34
+ self,
35
+ in_channels,
36
+ mask_classification=True,
37
+ *,
38
+ num_classes: int,
39
+ hidden_dim: int,
40
+ num_queries: int,
41
+ nheads: int,
42
+ dropout: float,
43
+ dim_feedforward: int,
44
+ enc_layers: int,
45
+ dec_layers: int,
46
+ pre_norm: bool,
47
+ deep_supervision: bool,
48
+ mask_dim: int,
49
+ enforce_input_project: bool,
50
+ ):
51
+ """
52
+ NOTE: this interface is experimental.
53
+ Args:
54
+ in_channels: channels of the input features
55
+ mask_classification: whether to add mask classifier or not
56
+ num_classes: number of classes
57
+ hidden_dim: Transformer feature dimension
58
+ num_queries: number of queries
59
+ nheads: number of heads
60
+ dropout: dropout in Transformer
61
+ dim_feedforward: feature dimension in feedforward network
62
+ enc_layers: number of Transformer encoder layers
63
+ dec_layers: number of Transformer decoder layers
64
+ pre_norm: whether to use pre-LayerNorm or not
65
+ deep_supervision: whether to add supervision to every decoder layer
66
+ mask_dim: mask feature dimension
67
+ enforce_input_project: add input project 1x1 conv even if input
68
+ channels and hidden dim is identical
69
+ """
70
+ super().__init__()
71
+
72
+ self.mask_classification = mask_classification
73
+
74
+ # positional encoding
75
+ N_steps = hidden_dim // 2
76
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
77
+
78
+ transformer = Transformer(
79
+ d_model=hidden_dim,
80
+ dropout=dropout,
81
+ nhead=nheads,
82
+ dim_feedforward=dim_feedforward,
83
+ num_encoder_layers=enc_layers,
84
+ num_decoder_layers=dec_layers,
85
+ normalize_before=pre_norm,
86
+ return_intermediate_dec=deep_supervision,
87
+ )
88
+
89
+ self.num_queries = num_queries
90
+ self.transformer = transformer
91
+ hidden_dim = transformer.d_model
92
+
93
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
94
+
95
+ if in_channels != hidden_dim or enforce_input_project:
96
+ self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1)
97
+ weight_init.c2_xavier_fill(self.input_proj)
98
+ else:
99
+ self.input_proj = nn.Sequential()
100
+ self.aux_loss = deep_supervision
101
+
102
+ # output FFNs
103
+ if self.mask_classification:
104
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
105
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
106
+
107
+ @classmethod
108
+ def from_config(cls, cfg, in_channels, mask_classification):
109
+ ret = {}
110
+ ret["in_channels"] = in_channels
111
+ ret["mask_classification"] = mask_classification
112
+
113
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
114
+ ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
115
+ ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
116
+ # Transformer parameters:
117
+ ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
118
+ ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
119
+ ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
120
+ ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
121
+ ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
122
+ ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
123
+ ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
124
+ ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
125
+
126
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
127
+
128
+ return ret
129
+
130
+ def forward(self, x, mask_features, mask=None):
131
+ if mask is not None:
132
+ mask = F.interpolate(mask[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
133
+ pos = self.pe_layer(x, mask)
134
+
135
+ src = x
136
+ hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)
137
+
138
+ if self.mask_classification:
139
+ outputs_class = self.class_embed(hs)
140
+ out = {"pred_logits": outputs_class[-1]}
141
+ else:
142
+ out = {}
143
+
144
+ if self.aux_loss:
145
+ # [l, bs, queries, embed]
146
+ mask_embed = self.mask_embed(hs)
147
+ outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
148
+ out["pred_masks"] = outputs_seg_masks[-1]
149
+ out["aux_outputs"] = self._set_aux_loss(
150
+ outputs_class if self.mask_classification else None, outputs_seg_masks
151
+ )
152
+ else:
153
+ # FIXME h_boxes takes the last one computed, keep this in mind
154
+ # [bs, queries, embed]
155
+ mask_embed = self.mask_embed(hs[-1])
156
+ outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
157
+ out["pred_masks"] = outputs_seg_masks
158
+ return out
159
+
160
+ @torch.jit.unused
161
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks):
162
+ # this is a workaround to make torchscript happy, as torchscript
163
+ # doesn't support dictionary with non-homogeneous values, such
164
+ # as a dict having both a Tensor and a list.
165
+ if self.mask_classification:
166
+ return [
167
+ {"pred_logits": a, "pred_masks": b}
168
+ for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
169
+ ]
170
+ else:
171
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
172
+
173
+
174
+ class MLP(nn.Module):
175
+ """Very simple multi-layer perceptron (also called FFN)"""
176
+
177
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
178
+ super().__init__()
179
+ self.num_layers = num_layers
180
+ h = [hidden_dim] * (num_layers - 1)
181
+ self.layers = nn.ModuleList(
182
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
183
+ )
184
+
185
+ def forward(self, x):
186
+ for i, layer in enumerate(self.layers):
187
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
188
+ return x
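As a side note (not part of the commit), the registry mechanism above works roughly as sketched below; the DEMO_* names and TinyDecoder class are purely illustrative, and only detectron2's Registry API (register/get) is assumed.

from detectron2.utils.registry import Registry
from torch import nn

DEMO_DECODER_REGISTRY = Registry("DEMO_TRANSFORMER_MODULE")  # hypothetical registry

@DEMO_DECODER_REGISTRY.register()
class TinyDecoder(nn.Module):  # hypothetical decoder, registered under its class name
    def __init__(self, hidden_dim: int = 256):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)

def build_demo_decoder(name: str, **kwargs) -> nn.Module:
    # Mirrors build_transformer_decoder: look the class up by name, then instantiate it.
    return DEMO_DECODER_REGISTRY.get(name)(**kwargs)

decoder = build_demo_decoder("TinyDecoder", hidden_dim=128)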
mask2former/modeling/transformer_decoder/opd_transformer_decoder.py ADDED
@@ -0,0 +1,520 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ import logging
4
+ import fvcore.nn.weight_init as weight_init
5
+ from typing import Optional
6
+ import torch
7
+ from torch import nn, Tensor
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.layers import Conv2d
12
+
13
+ from .position_encoding import PositionEmbeddingSine
14
+ from .maskformer_transformer_decoder import TRANSFORMER_DECODER_REGISTRY
15
+ from .mask2former_transformer_decoder import (
16
+ SelfAttentionLayer,
17
+ CrossAttentionLayer,
18
+ FFNLayer,
19
+ MLP,
20
+ )
21
+ from ..criterion import convert_to_filled_tensor
22
+
23
+
24
+ @TRANSFORMER_DECODER_REGISTRY.register()
25
+ class OPDMultiScaleMaskedTransformerDecoder(nn.Module):
26
+
27
+ _version = 2
28
+
29
+ def _load_from_state_dict(
30
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
31
+ ):
32
+ version = local_metadata.get("version", None)
33
+ if version is None or version < 2:
34
+ # Do not warn if train from scratch
35
+ scratch = True
36
+ logger = logging.getLogger(__name__)
37
+ for k in list(state_dict.keys()):
38
+ newk = k
39
+ if "static_query" in k:
40
+ newk = k.replace("static_query", "query_feat")
41
+ if newk != k:
42
+ state_dict[newk] = state_dict[k]
43
+ del state_dict[k]
44
+ scratch = False
45
+
46
+ if not scratch:
47
+ logger.warning(
48
+ f"Weight format of {self.__class__.__name__} has changed! "
49
+ "Please upgrade your models. Applying automatic conversion now ..."
50
+ )
51
+
52
+ @configurable
53
+ def __init__(
54
+ self,
55
+ in_channels,
56
+ mask_classification=True,
57
+ *,
58
+ num_classes: int,
59
+ hidden_dim: int,
60
+ num_queries: int,
61
+ nheads: int,
62
+ dim_feedforward: int,
63
+ dec_layers: int,
64
+ pre_norm: bool,
65
+ mask_dim: int,
66
+ enforce_input_project: bool,
67
+ # OPD
68
+ motionnet_type,
69
+ obj_method
70
+ ):
71
+ """
72
+ NOTE: this interface is experimental.
73
+ Args:
74
+ in_channels: channels of the input features
75
+ mask_classification: whether to add mask classifier or not
76
+ num_classes: number of classes
77
+ hidden_dim: Transformer feature dimension
78
+ num_queries: number of queries
79
+ nheads: number of heads
80
+ dim_feedforward: feature dimension in feedforward network
81
+ enc_layers: number of Transformer encoder layers
82
+ dec_layers: number of Transformer decoder layers
83
+ pre_norm: whether to use pre-LayerNorm or not
84
+ mask_dim: mask feature dimension
85
+ enforce_input_project: add input project 1x1 conv even if input
86
+ channels and hidden dim is identical
87
+ """
88
+ super().__init__()
89
+
90
+ # OPD
91
+ self.motionnet_type = motionnet_type
92
+ self.num_classes = num_classes
93
+ self.obj_method = obj_method
94
+
95
+ assert mask_classification, "Only support mask classification model"
96
+ self.mask_classification = mask_classification
97
+
98
+ # positional encoding
99
+ N_steps = hidden_dim // 2
100
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
101
+
102
+ # define Transformer decoder here
103
+ self.num_heads = nheads
104
+ self.num_layers = dec_layers
105
+ self.transformer_self_attention_layers = nn.ModuleList()
106
+ self.transformer_cross_attention_layers = nn.ModuleList()
107
+ self.transformer_ffn_layers = nn.ModuleList()
108
+
109
+ for _ in range(self.num_layers):
110
+ self.transformer_self_attention_layers.append(
111
+ SelfAttentionLayer(
112
+ d_model=hidden_dim,
113
+ nhead=nheads,
114
+ dropout=0.0,
115
+ normalize_before=pre_norm,
116
+ )
117
+ )
118
+
119
+ self.transformer_cross_attention_layers.append(
120
+ CrossAttentionLayer(
121
+ d_model=hidden_dim,
122
+ nhead=nheads,
123
+ dropout=0.0,
124
+ normalize_before=pre_norm,
125
+ )
126
+ )
127
+
128
+ self.transformer_ffn_layers.append(
129
+ FFNLayer(
130
+ d_model=hidden_dim,
131
+ dim_feedforward=dim_feedforward,
132
+ dropout=0.0,
133
+ normalize_before=pre_norm,
134
+ )
135
+ )
136
+
137
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
138
+
139
+ self.num_queries = num_queries
140
+ # learnable query features
141
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
142
+ # learnable query p.e.
143
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
144
+
145
+ # level embedding (we always use 3 scales)
146
+ self.num_feature_levels = 3
147
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
148
+ self.input_proj = nn.ModuleList()
149
+ for _ in range(self.num_feature_levels):
150
+ if in_channels != hidden_dim or enforce_input_project:
151
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
152
+ weight_init.c2_xavier_fill(self.input_proj[-1])
153
+ else:
154
+ self.input_proj.append(nn.Sequential())
155
+
156
+ # output FFNs
157
+ if self.mask_classification:
158
+ self.class_embed = nn.Sequential(
159
+ nn.Linear(hidden_dim, 32),
160
+ nn.ReLU(inplace=True),
161
+ nn.Linear(32, num_classes + 1),
162
+ )
163
+ # OPD Changes
164
+ self.mtype_embed = nn.Sequential(
165
+ nn.Linear(hidden_dim, 32),
166
+ nn.ReLU(inplace=True),
167
+ nn.Linear(32, 2),
168
+ )
169
+ self.morigin_embed = nn.Sequential(
170
+ nn.Linear(hidden_dim, 32),
171
+ nn.ReLU(inplace=True),
172
+ nn.Linear(32, 3),
173
+ )
174
+ self.maxis_embed = nn.Sequential(
175
+ nn.Linear(hidden_dim, 32),
176
+ nn.ReLU(inplace=True),
177
+ nn.Linear(32, 3),
178
+ )
179
+ self.mstate_embed = nn.Sequential(
180
+ nn.Linear(hidden_dim, 32),
181
+ nn.ReLU(inplace=True),
182
+ nn.Linear(32, 1),
183
+ )
184
+ self.mstatemax_embed = nn.Sequential(
185
+ nn.Linear(hidden_dim, 32),
186
+ nn.ReLU(inplace=True),
187
+ nn.Linear(32, 1),
188
+ )
189
+ if self.motionnet_type == "BMOC_V0":
190
+ # Define the layers for the extrinsic prediction
191
+ self.extrinsic_feature_layer = nn.Sequential(
192
+ # 16 * 256 * 64 * 64
193
+ nn.Conv2d(256, 256, 3, 2, 1), # 16 * 256 * 32 * 32
194
+ nn.BatchNorm2d(256),
195
+ nn.ReLU(inplace=True),
196
+ nn.MaxPool2d(2, 2), # 16 * 256 * 16 * 16
197
+ nn.Conv2d(256, 256, 3, 2, 1), # 16 * 256 * 8 * 8
198
+ nn.BatchNorm2d(256),
199
+ nn.ReLU(inplace=True),
200
+ nn.MaxPool2d(2, 2), # 16 * 256 * 4 * 4
201
+ nn.Conv2d(256, 64, 1), # 16 * 64 * 4 * 4
202
+ nn.BatchNorm2d(64),
203
+ nn.ReLU(inplace=True),
204
+ nn.Flatten() # 16 * 1024
205
+ )
206
+ for layer in self.extrinsic_feature_layer:
207
+ if isinstance(layer, nn.Conv2d):
208
+ nn.init.kaiming_normal_(
209
+ layer.weight, mode="fan_out", nonlinearity="relu"
210
+ )
211
+ self.extrinsic_pred_layer = nn.Sequential(
212
+ nn.Linear(768, 512),
213
+ # nn.Linear(768, 512),
214
+ nn.ReLU(inplace=True),
215
+ nn.Linear(512, 128),
216
+ nn.ReLU(inplace=True),
217
+ nn.Linear(128, 32),
218
+ nn.ReLU(inplace=True),
219
+ nn.Linear(32, 12), # 16 * 12
220
+ )
221
+ elif self.motionnet_type == "BMOC_V1":
222
+ self.extrinsic_embed = nn.Sequential(
223
+ nn.Linear(hidden_dim, 32),
224
+ nn.ReLU(inplace=True),
225
+ nn.Linear(32, 12),
226
+ )
227
+ elif self.motionnet_type == "BMOC_V2":
228
+ self.extrinsic_embed = nn.Sequential(
229
+ nn.Linear(hidden_dim, 32),
230
+ nn.ReLU(inplace=True),
231
+ nn.Linear(32, 7),
232
+ )
233
+ elif self.motionnet_type == "BMOC_V3":
234
+ self.extrinsic_embed = nn.Sequential(
235
+ nn.Linear(hidden_dim, 32),
236
+ nn.ReLU(inplace=True),
237
+ nn.Linear(32, 9),
238
+ )
239
+ elif self.motionnet_type in ("BMOC_V4", "BMOC_V5", "BMOC_V6"):
240
+ if self.motionnet_type == "BMOC_V5":
241
+ self.mask_weight_layer = SelfAttentionLayer(
242
+ d_model=hidden_dim,
243
+ nhead=nheads,
244
+ dropout=0.0,
245
+ normalize_before=pre_norm,
246
+ )
247
+ # Define the layers for the extrinsic prediction
248
+ self.extrinsic_feature_layer = nn.Sequential(
249
+ nn.BatchNorm2d(256),
250
+ # 16 * 256 * 64 * 64
251
+ nn.Conv2d(256, 256, 3, 2, 1), # 16 * 256 * 32 * 32
252
+ nn.BatchNorm2d(256),
253
+ nn.ReLU(inplace=True),
254
+ nn.MaxPool2d(2, 2), # 16 * 256 * 16 * 16
255
+ nn.Conv2d(256, 256, 3, 2, 1), # 16 * 256 * 8 * 8
256
+ nn.BatchNorm2d(256),
257
+ nn.ReLU(inplace=True),
258
+ nn.MaxPool2d(2, 2), # 16 * 256 * 4 * 4
259
+ nn.Conv2d(256, 64, 1), # 16 * 64 * 4 * 4
260
+ nn.BatchNorm2d(64),
261
+ nn.ReLU(inplace=True),
262
+ nn.Flatten() # 16 * 1024
263
+ )
264
+ for layer in self.extrinsic_feature_layer:
265
+ if isinstance(layer, nn.Conv2d):
266
+ nn.init.kaiming_normal_(
267
+ layer.weight, mode="fan_out", nonlinearity="relu"
268
+ )
269
+ if self.motionnet_type == "BMOC_V4" or self.motionnet_type == "BMOC_V5":
270
+ self.extrinsic_pred_layer = nn.Sequential(
271
+ nn.Linear(1024, 512),
272
+ nn.ReLU(inplace=True),
273
+ nn.Linear(512, 128),
274
+ nn.ReLU(inplace=True),
275
+ nn.Linear(128, 32),
276
+ nn.ReLU(inplace=True),
277
+ nn.Linear(32, 7), # 16 * 7
278
+ )
279
+ elif self.motionnet_type == "BMOC_V6":
280
+ self.extrinsic_pred_layer = nn.Sequential(
281
+ # nn.Linear(1024, 512),
282
+ nn.Linear(768, 512),
283
+ nn.ReLU(inplace=True),
284
+ nn.Linear(512, 128),
285
+ nn.ReLU(inplace=True),
286
+ nn.Linear(128, 32),
287
+ nn.ReLU(inplace=True),
288
+ nn.Linear(32, 12), # 16 * 12
289
+ )
290
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
291
+
292
+ @classmethod
293
+ def from_config(cls, cfg, in_channels, mask_classification):
294
+ ret = {}
295
+ ret["in_channels"] = in_channels
296
+ ret["mask_classification"] = mask_classification
297
+
298
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
299
+ ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
300
+ ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
301
+ # Transformer parameters:
302
+ ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
303
+ ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
304
+
305
+ # NOTE: because we add learnable query features which require supervision,
306
+ # we subtract 1 from the number of decoder layers to be consistent with our loss
307
+ # implementation: that is, number of auxiliary losses is always
308
+ # equal to number of decoder layers. With learnable query features, the number of
309
+ # auxiliary losses equals number of decoders plus 1.
310
+ assert cfg.MODEL.MASK_FORMER.DEC_LAYERS >= 1
311
+ ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1
312
+ ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
313
+ ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
314
+
315
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
316
+
317
+ # OPD
318
+ ret["motionnet_type"] = cfg.MODEL.MOTIONNET.TYPE
319
+
320
+ ret['obj_method'] = cfg.OBJ_DETECT
321
+
322
+ return ret
323
+
324
+ def forward(self, x, mask_features, mask = None):
325
+ # x is a list of multi-scale feature
326
+ assert len(x) == self.num_feature_levels
327
+ src = []
328
+ pos = []
329
+ size_list = []
330
+
331
+ # disable mask, it does not affect performance
332
+ # if not self.obj_method:
333
+ # del mask
334
+ # import pdb
335
+ # pdb.set_trace()
336
+
337
+ for i in range(self.num_feature_levels):
338
+ size_list.append(x[i].shape[-2:])
339
+ pos.append(self.pe_layer(x[i], None).flatten(2))
340
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
341
+
342
+ # flatten NxCxHxW to HWxNxC
343
+ pos[-1] = pos[-1].permute(2, 0, 1)
344
+ src[-1] = src[-1].permute(2, 0, 1)
345
+
346
+ _, bs, _ = src[0].shape
347
+
348
+ # QxNxC
349
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
350
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
351
+
352
+ predictions_class = []
353
+ predictions_mask = []
354
+ # OPD
355
+ predictions_mtype = []
356
+ predictions_morigin = []
357
+ predictions_maxis = []
358
+ predictions_mstate = []
359
+ predictions_mstatemax = []
360
+
361
+ if self.motionnet_type in ("BMOC_V1", "BMOC_V2", "BMOC_V3", "BMOC_V4", "BMOC_V5", "BMOC_V6"):
362
+ predictions_extrinsic = []
363
+
364
+
365
+ # prediction heads on learnable query features
366
+ outputs_class, outputs_mask, attn_mask, outputs_mtype, outputs_morigin, outputs_maxis, outputs_extrinsic, outputs_mstate, outputs_mstatemax = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], query_embed=query_embed, mask=mask)
367
+ predictions_class.append(outputs_class)
368
+ predictions_mask.append(outputs_mask)
369
+ # OPD
370
+ predictions_mtype.append(outputs_mtype)
371
+ predictions_morigin.append(outputs_morigin)
372
+ predictions_maxis.append(outputs_maxis)
373
+ predictions_mstate.append(outputs_mstate)
374
+ predictions_mstatemax.append(outputs_mstatemax)
375
+
376
+ if self.motionnet_type in ("BMOC_V1", "BMOC_V2", "BMOC_V3", "BMOC_V4", "BMOC_V5", "BMOC_V6"):
377
+ predictions_extrinsic.append(outputs_extrinsic)
378
+
379
+ for i in range(self.num_layers):
380
+ level_index = i % self.num_feature_levels
381
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
382
+ # attention: cross-attention first
383
+ output = self.transformer_cross_attention_layers[i](
384
+ output, src[level_index],
385
+ memory_mask=attn_mask,
386
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
387
+ pos=pos[level_index], query_pos=query_embed
388
+ )
389
+
390
+ output = self.transformer_self_attention_layers[i](
391
+ output, tgt_mask=None,
392
+ tgt_key_padding_mask=None,
393
+ query_pos=query_embed
394
+ )
395
+
396
+ # FFN
397
+ output = self.transformer_ffn_layers[i](
398
+ output
399
+ )
400
+
401
+ outputs_class, outputs_mask, attn_mask, outputs_mtype, outputs_morigin, outputs_maxis, outputs_extrinsic, outputs_mstate, outputs_mstatemax = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], query_embed=query_embed)
402
+ predictions_class.append(outputs_class)
403
+ predictions_mask.append(outputs_mask)
404
+ # OPD
405
+ predictions_mtype.append(outputs_mtype)
406
+ predictions_morigin.append(outputs_morigin)
407
+ predictions_maxis.append(outputs_maxis)
408
+ predictions_mstate.append(outputs_mstate)
409
+ predictions_mstatemax.append(outputs_mstatemax)
410
+
411
+ if self.motionnet_type in ("BMOC_V1", "BMOC_V2", "BMOC_V3", "BMOC_V4", "BMOC_V5", "BMOC_V6"):
412
+ predictions_extrinsic.append(outputs_extrinsic)
413
+
414
+ assert len(predictions_class) == self.num_layers + 1
415
+ if self.mask_classification:
416
+ if self.motionnet_type == "BMOC_V0" or self.motionnet_type == "BMCC":
417
+ aux_outputs = self._set_aux_loss(
418
+ predictions_class, predictions_mask, predictions_mtype, predictions_morigin, predictions_maxis, None, predictions_mstate, predictions_mstatemax
419
+ )
420
+ elif self.motionnet_type in ("BMOC_V1", "BMOC_V2", "BMOC_V3", "BMOC_V4", "BMOC_V5", "BMOC_V6"):
421
+ aux_outputs = self._set_aux_loss(
422
+ predictions_class, predictions_mask, predictions_mtype, predictions_morigin, predictions_maxis, predictions_extrinsic, predictions_mstate, predictions_mstatemax
423
+ )
424
+
425
+ else:
426
+ aux_outputs = self._set_aux_loss(
427
+ None, predictions_mask, None, None, None, None, None, None
428
+ )
429
+ # OPD
430
+ if self.motionnet_type == "BMOC_V0":
431
+ extrinsic_feature = self.extrinsic_feature_layer(mask_features)
432
+ predictions_extrinsic = self.extrinsic_pred_layer(extrinsic_feature)
433
+
434
+ out = {
435
+ 'pred_logits': predictions_class[-1],
436
+ 'pred_masks': predictions_mask[-1],
437
+ # OPD
438
+ 'pred_mtypes': predictions_mtype[-1],
439
+ 'pred_morigins': predictions_morigin[-1],
440
+ 'pred_maxises': predictions_maxis[-1],
441
+ 'aux_outputs': aux_outputs,
442
+ 'pred_mstates': predictions_mstate[-1],
443
+ 'pred_mstatemaxs': predictions_mstatemax[-1],
444
+ }
445
+ if self.motionnet_type == "BMOC_V0":
446
+ out['pred_extrinsics'] = predictions_extrinsic
447
+ elif self.motionnet_type in ("BMOC_V1", "BMOC_V2", "BMOC_V3", "BMOC_V4", "BMOC_V5", "BMOC_V6"):
448
+ out['pred_extrinsics'] = predictions_extrinsic[-1]
449
+
450
+ return out
451
+
452
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, query_embed, mask = None):
453
+ decoder_output = self.decoder_norm(output)
454
+ decoder_output = decoder_output.transpose(0, 1)
455
+ outputs_class = self.class_embed(decoder_output)
456
+ # OPD Changes
457
+ outputs_mtype = self.mtype_embed(decoder_output)
458
+ outputs_morigin = self.morigin_embed(decoder_output)
459
+ outputs_maxis = self.maxis_embed(decoder_output)
460
+ outputs_mstate = self.mstate_embed(decoder_output)
461
+ outputs_mstatemax = self.mstatemax_embed(decoder_output)
462
+
463
+ if self.motionnet_type in ("BMOC_V1", "BMOC_V2", "BMOC_V3"):
464
+ outputs_extrinsic = self.extrinsic_embed(decoder_output)
465
+ elif self.motionnet_type == "BMOC_V0" or self.motionnet_type == "BMCC":
466
+ outputs_extrinsic = None
467
+
468
+ mask_embed = self.mask_embed(decoder_output)
469
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
470
+
471
+ # import pdb
472
+ # pdb.set_trace()
473
+ # TODO: Add different variants of using object mask to get the extrinsic
474
+
475
+ if self.motionnet_type == "BMOC_V4" or self.motionnet_type == "BMOC_V6":
476
+ binary_mask = (outputs_mask > 0).float()
477
+ weighted_masked_feature = mask_features + torch.einsum("bqhw,bchw->bchw", binary_mask, mask_features)
478
+ extrinsic_feature = self.extrinsic_feature_layer(weighted_masked_feature)
479
+ outputs_extrinsic = self.extrinsic_pred_layer(extrinsic_feature)
480
+ elif self.motionnet_type == "BMOC_V5":
481
+ # Get one weight for each query
482
+ mask_weights = torch.transpose(self.mask_weight_layer(
483
+ torch.transpose(mask_embed, 0, 1), tgt_mask=None,
484
+ tgt_key_padding_mask=None,
485
+ query_pos=query_embed
486
+ ), 0, 1).mean(2)
487
+ binary_mask = (outputs_mask > 0).float()
488
+ weighted_mask = torch.einsum("bq,bqhw->bqhw", mask_weights, binary_mask)
489
+ weighted_masked_feature = mask_features + torch.einsum("bqhw,bchw->bchw", weighted_mask, mask_features)
490
+ extrinsic_feature = self.extrinsic_feature_layer(weighted_masked_feature)
491
+ outputs_extrinsic = self.extrinsic_pred_layer(extrinsic_feature)
492
+
493
+ # NOTE: prediction is of higher-resolution
494
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
495
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
496
+ # must use bool type
497
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
498
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
499
+ attn_mask = attn_mask.detach()
500
+
501
+ return outputs_class, outputs_mask, attn_mask, outputs_mtype, outputs_morigin, outputs_maxis, outputs_extrinsic, outputs_mstate, outputs_mstatemax
502
+
503
+ @torch.jit.unused
504
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks, predictions_mtype, predictions_morigin, predictions_maxis, predictions_extrinsic, predictions_mstate, predictions_mstatemax):
505
+ # this is a workaround to make torchscript happy, as torchscript
506
+ # doesn't support dictionary with non-homogeneous values, such
507
+ # as a dict having both a Tensor and a list.
508
+ if self.mask_classification:
509
+ if self.motionnet_type == "BMOC_V0" or self.motionnet_type == "BMCC":
510
+ return [
511
+ {"pred_logits": a, "pred_masks": b, "pred_mtypes": c, "pred_morigins": d, "pred_maxises": e, "pred_mstates": f, "pred_mstatemaxs": g}
512
+ for a, b, c, d, e, f, g in zip(outputs_class[:-1], outputs_seg_masks[:-1], predictions_mtype[:-1], predictions_morigin[:-1], predictions_maxis[:-1], predictions_mstate[:-1], predictions_mstatemax[:-1])
513
+ ]
514
+ elif self.motionnet_type in ("BMOC_V1", "BMOC_V2", "BMOC_V3", "BMOC_V4", "BMOC_V5", "BMOC_V6"):
515
+ return [
516
+ {"pred_logits": a, "pred_masks": b, "pred_mtypes": c, "pred_morigins": d, "pred_maxises": e, "pred_extrinsics": f, "pred_mstates": g, "pred_mstatemaxs": h}
517
+ for a, b, c, d, e, f, g, h in zip(outputs_class[:-1], outputs_seg_masks[:-1], predictions_mtype[:-1], predictions_morigin[:-1], predictions_maxis[:-1], predictions_extrinsic[:-1], predictions_mstate[:-1], predictions_mstatemax[:-1])
518
+ ]
519
+ else:
520
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
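A small sketch (not part of the commit) of the mask-weighted feature aggregation used by the BMOC_V4 / BMOC_V6 extrinsic branch in forward_prediction_heads above; all tensor sizes are toy values.

import torch

B, Q, C, H, W = 2, 100, 256, 64, 64      # toy sizes
outputs_mask = torch.randn(B, Q, H, W)   # per-query mask logits
mask_features = torch.randn(B, C, H, W)  # shared mask features

binary_mask = (outputs_mask > 0).float()
# The einsum sums over the query dimension, so each pixel's features are scaled by the
# number of query masks that cover it, and the result is added back onto the original features.
weighted = mask_features + torch.einsum("bqhw,bchw->bchw", binary_mask, mask_features)

assert weighted.shape == (B, C, H, W)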
mask2former/modeling/transformer_decoder/position_encoding.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
3
+ """
4
+ Various positional encodings for the transformer.
5
+ """
6
+ import math
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+
12
+ class PositionEmbeddingSine(nn.Module):
13
+ """
14
+ This is a more standard version of the position embedding, very similar to the one
15
+ used by the Attention is all you need paper, generalized to work on images.
16
+ """
17
+
18
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
19
+ super().__init__()
20
+ self.num_pos_feats = num_pos_feats
21
+ self.temperature = temperature
22
+ self.normalize = normalize
23
+ if scale is not None and normalize is False:
24
+ raise ValueError("normalize should be True if scale is passed")
25
+ if scale is None:
26
+ scale = 2 * math.pi
27
+ self.scale = scale
28
+
29
+ def forward(self, x, mask=None):
30
+ if mask is None:
31
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
32
+ not_mask = ~mask
33
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
34
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
35
+ if self.normalize:
36
+ eps = 1e-6
37
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
38
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
39
+
40
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
41
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
42
+
43
+ pos_x = x_embed[:, :, :, None] / dim_t
44
+ pos_y = y_embed[:, :, :, None] / dim_t
45
+ pos_x = torch.stack(
46
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
47
+ ).flatten(3)
48
+ pos_y = torch.stack(
49
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
50
+ ).flatten(3)
51
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
52
+ return pos
53
+
54
+ def __repr__(self, _repr_indent=4):
55
+ head = "Positional encoding " + self.__class__.__name__
56
+ body = [
57
+ "num_pos_feats: {}".format(self.num_pos_feats),
58
+ "temperature: {}".format(self.temperature),
59
+ "normalize: {}".format(self.normalize),
60
+ "scale: {}".format(self.scale),
61
+ ]
62
+ # _repr_indent = 4
63
+ lines = [head] + [" " * _repr_indent + line for line in body]
64
+ return "\n".join(lines)
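A minimal usage sketch (not part of the commit): with num_pos_feats = hidden_dim // 2 the encoding returned by PositionEmbeddingSine has hidden_dim channels, which is how the decoders above size their pe_layer; the sizes below are toy values.

import torch
from mask2former.modeling.transformer_decoder.position_encoding import PositionEmbeddingSine

hidden_dim = 256
pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)

feats = torch.randn(2, hidden_dim, 32, 32)   # N x C x H x W feature map (toy sizes)
pos = pe_layer(feats, mask=None)             # sin/cos over y and x, concatenated on dim 1

assert pos.shape == (2, hidden_dim, 32, 32)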
mask2former/modeling/transformer_decoder/transformer.py ADDED
@@ -0,0 +1,369 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
3
+ """
4
+ Transformer class.
5
+
6
+ Copy-paste from torch.nn.Transformer with modifications:
7
+ * positional encodings are passed in MHattention
8
+ * extra LN at the end of encoder is removed
9
+ * decoder returns a stack of activations from all decoding layers
10
+ """
11
+ import copy
12
+ from typing import List, Optional
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch import Tensor, nn
17
+
18
+
19
+ class Transformer(nn.Module):
20
+ def __init__(
21
+ self,
22
+ d_model=512,
23
+ nhead=8,
24
+ num_encoder_layers=6,
25
+ num_decoder_layers=6,
26
+ dim_feedforward=2048,
27
+ dropout=0.1,
28
+ activation="relu",
29
+ normalize_before=False,
30
+ return_intermediate_dec=False,
31
+ ):
32
+ super().__init__()
33
+
34
+ encoder_layer = TransformerEncoderLayer(
35
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
36
+ )
37
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
38
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
39
+
40
+ decoder_layer = TransformerDecoderLayer(
41
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
42
+ )
43
+ decoder_norm = nn.LayerNorm(d_model)
44
+ self.decoder = TransformerDecoder(
45
+ decoder_layer,
46
+ num_decoder_layers,
47
+ decoder_norm,
48
+ return_intermediate=return_intermediate_dec,
49
+ )
50
+
51
+ self._reset_parameters()
52
+
53
+ self.d_model = d_model
54
+ self.nhead = nhead
55
+
56
+ def _reset_parameters(self):
57
+ for p in self.parameters():
58
+ if p.dim() > 1:
59
+ nn.init.xavier_uniform_(p)
60
+
61
+ def forward(self, src, mask, query_embed, pos_embed):
62
+ # flatten NxCxHxW to HWxNxC
63
+ bs, c, h, w = src.shape
64
+ src = src.flatten(2).permute(2, 0, 1)
65
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
66
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
67
+ if mask is not None:
68
+ mask = mask.flatten(1)
69
+
70
+ tgt = torch.zeros_like(query_embed)
71
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
72
+ hs = self.decoder(
73
+ tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
74
+ )
75
+ return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
76
+
77
+
78
+ class TransformerEncoder(nn.Module):
79
+ def __init__(self, encoder_layer, num_layers, norm=None):
80
+ super().__init__()
81
+ self.layers = _get_clones(encoder_layer, num_layers)
82
+ self.num_layers = num_layers
83
+ self.norm = norm
84
+
85
+ def forward(
86
+ self,
87
+ src,
88
+ mask: Optional[Tensor] = None,
89
+ src_key_padding_mask: Optional[Tensor] = None,
90
+ pos: Optional[Tensor] = None,
91
+ ):
92
+ output = src
93
+
94
+ for layer in self.layers:
95
+ output = layer(
96
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
97
+ )
98
+
99
+ if self.norm is not None:
100
+ output = self.norm(output)
101
+
102
+ return output
103
+
104
+
105
+ class TransformerDecoder(nn.Module):
106
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
107
+ super().__init__()
108
+ self.layers = _get_clones(decoder_layer, num_layers)
109
+ self.num_layers = num_layers
110
+ self.norm = norm
111
+ self.return_intermediate = return_intermediate
112
+
113
+ def forward(
114
+ self,
115
+ tgt,
116
+ memory,
117
+ tgt_mask: Optional[Tensor] = None,
118
+ memory_mask: Optional[Tensor] = None,
119
+ tgt_key_padding_mask: Optional[Tensor] = None,
120
+ memory_key_padding_mask: Optional[Tensor] = None,
121
+ pos: Optional[Tensor] = None,
122
+ query_pos: Optional[Tensor] = None,
123
+ ):
124
+ output = tgt
125
+
126
+ intermediate = []
127
+
128
+ for layer in self.layers:
129
+ output = layer(
130
+ output,
131
+ memory,
132
+ tgt_mask=tgt_mask,
133
+ memory_mask=memory_mask,
134
+ tgt_key_padding_mask=tgt_key_padding_mask,
135
+ memory_key_padding_mask=memory_key_padding_mask,
136
+ pos=pos,
137
+ query_pos=query_pos,
138
+ )
139
+ if self.return_intermediate:
140
+ intermediate.append(self.norm(output))
141
+
142
+ if self.norm is not None:
143
+ output = self.norm(output)
144
+ if self.return_intermediate:
145
+ intermediate.pop()
146
+ intermediate.append(output)
147
+
148
+ if self.return_intermediate:
149
+ return torch.stack(intermediate)
150
+
151
+ return output.unsqueeze(0)
152
+
153
+
154
+ class TransformerEncoderLayer(nn.Module):
155
+ def __init__(
156
+ self,
157
+ d_model,
158
+ nhead,
159
+ dim_feedforward=2048,
160
+ dropout=0.1,
161
+ activation="relu",
162
+ normalize_before=False,
163
+ ):
164
+ super().__init__()
165
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
166
+ # Implementation of Feedforward model
167
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
168
+ self.dropout = nn.Dropout(dropout)
169
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
170
+
171
+ self.norm1 = nn.LayerNorm(d_model)
172
+ self.norm2 = nn.LayerNorm(d_model)
173
+ self.dropout1 = nn.Dropout(dropout)
174
+ self.dropout2 = nn.Dropout(dropout)
175
+
176
+ self.activation = _get_activation_fn(activation)
177
+ self.normalize_before = normalize_before
178
+
179
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
180
+ return tensor if pos is None else tensor + pos
181
+
182
+ def forward_post(
183
+ self,
184
+ src,
185
+ src_mask: Optional[Tensor] = None,
186
+ src_key_padding_mask: Optional[Tensor] = None,
187
+ pos: Optional[Tensor] = None,
188
+ ):
189
+ q = k = self.with_pos_embed(src, pos)
190
+ src2 = self.self_attn(
191
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
192
+ )[0]
193
+ src = src + self.dropout1(src2)
194
+ src = self.norm1(src)
195
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
196
+ src = src + self.dropout2(src2)
197
+ src = self.norm2(src)
198
+ return src
199
+
200
+ def forward_pre(
201
+ self,
202
+ src,
203
+ src_mask: Optional[Tensor] = None,
204
+ src_key_padding_mask: Optional[Tensor] = None,
205
+ pos: Optional[Tensor] = None,
206
+ ):
207
+ src2 = self.norm1(src)
208
+ q = k = self.with_pos_embed(src2, pos)
209
+ src2 = self.self_attn(
210
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
211
+ )[0]
212
+ src = src + self.dropout1(src2)
213
+ src2 = self.norm2(src)
214
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
215
+ src = src + self.dropout2(src2)
216
+ return src
217
+
218
+ def forward(
219
+ self,
220
+ src,
221
+ src_mask: Optional[Tensor] = None,
222
+ src_key_padding_mask: Optional[Tensor] = None,
223
+ pos: Optional[Tensor] = None,
224
+ ):
225
+ if self.normalize_before:
226
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
227
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
228
+
229
+
230
+ class TransformerDecoderLayer(nn.Module):
231
+ def __init__(
232
+ self,
233
+ d_model,
234
+ nhead,
235
+ dim_feedforward=2048,
236
+ dropout=0.1,
237
+ activation="relu",
238
+ normalize_before=False,
239
+ ):
240
+ super().__init__()
241
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
242
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
243
+ # Implementation of Feedforward model
244
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
245
+ self.dropout = nn.Dropout(dropout)
246
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
247
+
248
+ self.norm1 = nn.LayerNorm(d_model)
249
+ self.norm2 = nn.LayerNorm(d_model)
250
+ self.norm3 = nn.LayerNorm(d_model)
251
+ self.dropout1 = nn.Dropout(dropout)
252
+ self.dropout2 = nn.Dropout(dropout)
253
+ self.dropout3 = nn.Dropout(dropout)
254
+
255
+ self.activation = _get_activation_fn(activation)
256
+ self.normalize_before = normalize_before
257
+
258
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
259
+ return tensor if pos is None else tensor + pos
260
+
261
+ def forward_post(
262
+ self,
263
+ tgt,
264
+ memory,
265
+ tgt_mask: Optional[Tensor] = None,
266
+ memory_mask: Optional[Tensor] = None,
267
+ tgt_key_padding_mask: Optional[Tensor] = None,
268
+ memory_key_padding_mask: Optional[Tensor] = None,
269
+ pos: Optional[Tensor] = None,
270
+ query_pos: Optional[Tensor] = None,
271
+ ):
272
+ q = k = self.with_pos_embed(tgt, query_pos)
273
+ tgt2 = self.self_attn(
274
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
275
+ )[0]
276
+ tgt = tgt + self.dropout1(tgt2)
277
+ tgt = self.norm1(tgt)
278
+ tgt2 = self.multihead_attn(
279
+ query=self.with_pos_embed(tgt, query_pos),
280
+ key=self.with_pos_embed(memory, pos),
281
+ value=memory,
282
+ attn_mask=memory_mask,
283
+ key_padding_mask=memory_key_padding_mask,
284
+ )[0]
285
+ tgt = tgt + self.dropout2(tgt2)
286
+ tgt = self.norm2(tgt)
287
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
288
+ tgt = tgt + self.dropout3(tgt2)
289
+ tgt = self.norm3(tgt)
290
+ return tgt
291
+
292
+ def forward_pre(
293
+ self,
294
+ tgt,
295
+ memory,
296
+ tgt_mask: Optional[Tensor] = None,
297
+ memory_mask: Optional[Tensor] = None,
298
+ tgt_key_padding_mask: Optional[Tensor] = None,
299
+ memory_key_padding_mask: Optional[Tensor] = None,
300
+ pos: Optional[Tensor] = None,
301
+ query_pos: Optional[Tensor] = None,
302
+ ):
303
+ tgt2 = self.norm1(tgt)
304
+ q = k = self.with_pos_embed(tgt2, query_pos)
305
+ tgt2 = self.self_attn(
306
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
307
+ )[0]
308
+ tgt = tgt + self.dropout1(tgt2)
309
+ tgt2 = self.norm2(tgt)
310
+ tgt2 = self.multihead_attn(
311
+ query=self.with_pos_embed(tgt2, query_pos),
312
+ key=self.with_pos_embed(memory, pos),
313
+ value=memory,
314
+ attn_mask=memory_mask,
315
+ key_padding_mask=memory_key_padding_mask,
316
+ )[0]
317
+ tgt = tgt + self.dropout2(tgt2)
318
+ tgt2 = self.norm3(tgt)
319
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
320
+ tgt = tgt + self.dropout3(tgt2)
321
+ return tgt
322
+
323
+ def forward(
324
+ self,
325
+ tgt,
326
+ memory,
327
+ tgt_mask: Optional[Tensor] = None,
328
+ memory_mask: Optional[Tensor] = None,
329
+ tgt_key_padding_mask: Optional[Tensor] = None,
330
+ memory_key_padding_mask: Optional[Tensor] = None,
331
+ pos: Optional[Tensor] = None,
332
+ query_pos: Optional[Tensor] = None,
333
+ ):
334
+ if self.normalize_before:
335
+ return self.forward_pre(
336
+ tgt,
337
+ memory,
338
+ tgt_mask,
339
+ memory_mask,
340
+ tgt_key_padding_mask,
341
+ memory_key_padding_mask,
342
+ pos,
343
+ query_pos,
344
+ )
345
+ return self.forward_post(
346
+ tgt,
347
+ memory,
348
+ tgt_mask,
349
+ memory_mask,
350
+ tgt_key_padding_mask,
351
+ memory_key_padding_mask,
352
+ pos,
353
+ query_pos,
354
+ )
355
+
356
+
357
+ def _get_clones(module, N):
358
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
359
+
360
+
361
+ def _get_activation_fn(activation):
362
+ """Return an activation function given a string"""
363
+ if activation == "relu":
364
+ return F.relu
365
+ if activation == "gelu":
366
+ return F.gelu
367
+ if activation == "glu":
368
+ return F.glu
369
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
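A toy sketch (not part of the commit) of how this DETR-style Transformer is driven by StandardTransformerDecoder: flattened image features plus a sine positional embedding and learned query embeddings; every size below is a toy value.

import torch
from mask2former.modeling.transformer_decoder.transformer import Transformer
from mask2former.modeling.transformer_decoder.position_encoding import PositionEmbeddingSine

d_model, num_queries, bs, h, w = 256, 100, 2, 16, 16
transformer = Transformer(d_model=d_model, nhead=8, num_encoder_layers=1,
                          num_decoder_layers=1, return_intermediate_dec=True)

src = torch.randn(bs, d_model, h, w)                       # image features
pos = PositionEmbeddingSine(d_model // 2, normalize=True)(src)
query_embed = torch.randn(num_queries, d_model)            # stand-in for nn.Embedding weights

hs, memory = transformer(src, mask=None, query_embed=query_embed, pos_embed=pos)
# hs: [num_decoder_layers, bs, num_queries, d_model]; memory: [bs, d_model, h, w]
assert hs.shape == (1, bs, num_queries, d_model)
assert memory.shape == (bs, d_model, h, w)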
mask2former/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .motion_visualizer import MotionVisualizer
mask2former/utils/misc.py ADDED
@@ -0,0 +1,111 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
+ """
+ Misc functions, including distributed helpers.
+
+ Mostly copy-paste from torchvision references.
+ """
+ from typing import List, Optional
+
+ import torch
+ import torch.distributed as dist
+ import torchvision
+ from torch import Tensor
+
+
+ def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+ class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ if torchvision._is_tracing():
+ # nested_tensor_from_tensor_list() does not export well to ONNX
+ # call _onnx_nested_tensor_from_tensor_list() instead
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], : img.shape[2]] = False
+ else:
+ raise ValueError("not supported")
+ return NestedTensor(tensor, mask)
+
+
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+ @torch.jit.unused
+ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+ max_size = []
+ for i in range(tensor_list[0].dim()):
+ max_size_i = torch.max(
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
+ ).to(torch.int64)
+ max_size.append(max_size_i)
+ max_size = tuple(max_size)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # m[: img.shape[1], :img.shape[2]] = False
+ # which is not yet supported in onnx
+ padded_imgs = []
+ padded_masks = []
+ for img in tensor_list:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+ padded_imgs.append(padded_img)
+
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+ padded_masks.append(padded_mask.to(torch.bool))
+
+ tensor = torch.stack(padded_imgs)
+ mask = torch.stack(padded_masks)
+
+ return NestedTensor(tensor, mask=mask)
+
+
+ def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
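For reference, a minimal usage sketch of the DETR-style padding/mask convention implemented in misc.py above (not part of the commit; the import path is assumed from the repository layout):

import torch
from mask2former.utils.misc import nested_tensor_from_tensor_list  # path assumed

# Two images of different sizes are padded to a common (max) size.
imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
tensors, mask = nested_tensor_from_tensor_list(imgs).decompose()
print(tensors.shape)  # torch.Size([2, 3, 512, 640])
print(mask.shape)     # torch.Size([2, 512, 640]); True marks padded pixels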
mask2former/utils/motion_visualizer.py ADDED
@@ -0,0 +1,676 @@
+ from fvcore.common.file_io import PathManager
+ from detectron2.utils.visualizer import (
+ Visualizer,
+ ColorMode,
+ _create_text_labels,
+ GenericMask,
+ )
+ from detectron2.structures import (
+ BitMasks,
+ Boxes,
+ BoxMode,
+ Keypoints,
+ PolygonMasks,
+ RotatedBoxes,
+ )
+ from detectron2.utils.colormap import random_color
+
+ from PIL import Image
+ import numpy as np
+ from numpy.linalg import norm
+ import math
+
+ MOTION_TYPE = {0: "rotation", 1: "translation"}
+ _COLORS_CAT = {
+ 0: np.array([166, 206, 227]) / 255,
+ 1: np.array([31, 120, 180]) / 255,
+ 2: np.array([202, 178, 214]) / 255,
+ 3: np.array([106, 61, 154]) / 255,
+ 4: np.array([178, 223, 138]) / 255,
+ 5: np.array([51, 160, 44]) / 255,
+ }
+ _COLORS_LEVEL = {
+ 0: np.array([0, 255, 0]) / 255,
+ 1: np.array([255, 128, 0]) / 255,
+ 2: np.array([255, 0, 0]) / 255,
+ }
+
+
+ def getFocalLength(FOV, height, width=None):
+ # FOV is in radians and should be the vertical angle
+ if width is None:
+ f = height / (2 * math.tan(FOV / 2))
+ return f
+ else:
+ fx = height / (2 * math.tan(FOV / 2))
+ fy = fx / height * width
+ return (fx, fy)
+
+
+ def camera_to_image(point, is_real=False, intrinsic_matrix=None):
+ point_camera = np.array(point)
+ # Calculate the camera intrinsic parameters (they are fixed in this project)
+ if not is_real:
+ # Below is for the MotionNet synthetic dataset intrinsics
+ FOV = 50
+ img_width = img_height = 256
+ fx, fy = getFocalLength(FOV / 180 * math.pi, img_height, img_width)
+ cy = img_height / 2
+ cx = img_width / 2
+ x = point_camera[0] * fx / (-point_camera[2]) + cx
+ y = -(point_camera[1] * fy / (-point_camera[2])) + cy
+ else:
+ # Below is for the MotionREAL dataset
+ point_2d = np.dot(intrinsic_matrix, point_camera[:3])
+ x = point_2d[0] / point_2d[2]
+ y = point_2d[1] / point_2d[2]
+
+ return (x, y)
+
+
+ def rotation_from_vectors(source, dest):
+ a, b = (source / np.linalg.norm(source)).reshape(3), (
+ dest / np.linalg.norm(dest)
+ ).reshape(3)
+ v = np.cross(a, b)
+ c = np.dot(a, b)
+ s = np.linalg.norm(v)
+ kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
+ rmat = np.eye(3) + kmat + np.matmul(kmat, kmat) * ((1 - c) / (s ** 2))
+ return rmat
+
+
+ def rotatePoint(x, y, angle, scale):
+ rad = np.pi * angle / 180
+ x2 = np.cos(rad) * x - np.sin(rad) * y
+ y2 = np.sin(rad) * x + np.cos(rad) * y
+ return [x2 * scale, y2 * scale]
+
+
+ def circlePoints(axis, radius=0.5, num=50):
+ angles = np.linspace(0, 2 * np.pi, num, endpoint=False)
+ x_vec = np.cos(angles) * radius
+ y_vec = np.sin(angles) * radius
+ z_vec = np.zeros_like(x_vec) + 0.5
+ points = np.stack((x_vec, y_vec, z_vec), axis=0)
+ rot = rotation_from_vectors(np.array([0, 0, 1]), np.asarray(axis))
+ points = np.matmul(rot, points)
+ return points
+
+
+ def get_iou(bb1, bb2):
+ x_left = max(bb1[0], bb2[0])
+ y_top = max(bb1[1], bb2[1])
+ x_right = min(bb1[0] + bb1[2], bb2[0] + bb2[2])
+ y_bottom = min(bb1[1] + bb1[3], bb2[1] + bb2[3])
+
+ if x_right < x_left or y_bottom < y_top:
+ return 0.0
+
+ area = (x_right - x_left) * (y_bottom - y_top)
+
+ bb1_area = bb1[2] * bb1[3]
+ bb2_area = bb2[2] * bb2[3]
+ iou = area / float(bb1_area + bb2_area - area)
+ return iou
+
+
+ class MotionVisualizer(Visualizer):
+ def draw_gt_instance(self, anno, part_id_json, is_real=False, intrinsic_matrix=None, line_length=1):
+ # All annotations are already in the camera coordinate system
+ masks = [anno["segmentation"]]
+ boxes = [BoxMode.convert(anno["bbox"], anno["bbox_mode"], BoxMode.XYXY_ABS)]
+ labels = [anno["category_id"]]
+ colors = None
+ if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
+ "thing_colors"
+ ):
+ colors = [
+ self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
+ for c in labels
+ ]
+
+ origins = [anno["motion"]["current_origin"]]
+ # Calculate the 2d origin (only one origin is drawn)
+ origins_4d = [origin[:] + [1] for origin in origins]
+ origin_2d = [camera_to_image(origin, is_real, intrinsic_matrix) for origin in origins_4d]
+
+ axises = [anno["motion"]["current_axis"]]
+ new_point = list(np.array(origins[0]) + line_length * np.array(axises[0]))
+ new_point = new_point[:] + [1]
+ new_point = camera_to_image(new_point, is_real, intrinsic_matrix)
+
+ arrow_p0 = rotatePoint(
+ new_point[0] - origin_2d[0][0], new_point[1] - origin_2d[0][1], 30, 0.1
+ )
+ arrow_p1 = rotatePoint(
+ new_point[0] - origin_2d[0][0], new_point[1] - origin_2d[0][1], -30, 0.1
+ )
+ circle_p = circlePoints(axises[0], 0.1, 50)
+ circle_p = line_length * circle_p + np.repeat(
+ np.asarray(origins[0])[:, np.newaxis], 50, axis=1
+ )
+ circle_p = circle_p.transpose()
+ circle_p_2d = np.asarray([camera_to_image(p, is_real, intrinsic_matrix) for p in circle_p])
+
+ self.draw_line(
+ [origin_2d[0][0], new_point[0]],
+ [origin_2d[0][1], new_point[1]],
+ color=_COLORS_LEVEL[0],
+ linewidth=2,
+ )
+ self.draw_line(
+ [new_point[0] - arrow_p0[0], new_point[0]],
+ [new_point[1] - arrow_p0[1], new_point[1]],
+ color=_COLORS_LEVEL[0],
+ linewidth=2,
+ )
+ self.draw_line(
+ [new_point[0] - arrow_p1[0], new_point[0]],
+ [new_point[1] - arrow_p1[1], new_point[1]],
+ color=_COLORS_LEVEL[0],
+ linewidth=2,
+ )
+ self.draw_polygon(
+ circle_p_2d, color=_COLORS_LEVEL[0], edge_color=_COLORS_LEVEL[0], alpha=0.0
+ )
+
+ mtype = 0 if anno["motion"]["type"] == "rotation" else 1
+
+ if not mtype:
+ self.draw_circle(origin_2d[0], color=_COLORS_LEVEL[0], radius=5)
+
+ names = self.metadata.get("thing_classes", None)
+ if names:
+ labels = [names[i] + "_" + anno["motion"]["type"] for i in labels]
+ labels = [
+ "{}".format(i) + ("|crowd" if a.get("iscrowd", 0) else "")
+ for i, a in zip(labels, [anno])
+ ]
+
+ cat_id = anno["category_id"]
+ self.overlay_instances(
+ labels=labels,
+ boxes=boxes,
+ masks=masks,
+ assigned_colors=[_COLORS_CAT[cat_id * 2 + mtype]],
+ )
+
+ part_id_json["partId"] = anno["motion"]["partId"]
+ part_id_json["type"] = anno["motion"]["type"]
+ part_id_json["category_id"] = anno["category_id"]
+
+ return self.output
+
+ def draw_prior(self, anno):
+ # All annotations are already in the camera coordinate system
+ labels = [0]
+
+ origin = anno["start"]
+ origin_2d = anno["start_2d"]
+ new_point = anno["end_2d"]
+
+ axises = [anno["axises"]]
+ print(axises)
+
+ projection = anno["projMat"]
+
+ arrow_p0 = rotatePoint(
+ new_point[0] - origin_2d[0], new_point[1] - origin_2d[1], 30, 0.1
+ )
+ arrow_p1 = rotatePoint(
+ new_point[0] - origin_2d[0], new_point[1] - origin_2d[1], -30, 0.1
+ )
+
+ circle_p = circlePoints(axises[0], 0.1, 50)
+ circle_p = circle_p + np.repeat(np.asarray(origin)[:, np.newaxis], 50, axis=1)
+ # circle_p = circle_p.transpose()
+ circle_p = np.vstack((circle_p, np.ones(circle_p.shape[1])))
+ circle_p_2d = np.dot(projection, circle_p)
+ circle_p_2d = circle_p_2d / circle_p_2d[3, :]
+ circle_p_2d = circle_p_2d[:2, :]
+ circle_p_2d[0, :] = (circle_p_2d[0, :] + 1) / 2 * anno["img_size"]
+ circle_p_2d[1, :] = (-circle_p_2d[1, :] + 1) / 2 * anno["img_size"]
+ circle_p_2d = circle_p_2d.transpose()
+
+ axis_diff = anno["error"]
+ if axis_diff <= 2:
+ axis_color = _COLORS_LEVEL[0]
+ elif axis_diff > 2 and axis_diff <= 10:
+ axis_color = _COLORS_LEVEL[1]
+ elif axis_diff > 10:
+ axis_color = _COLORS_LEVEL[2]
+
+ print(axis_diff)
+
+ self.draw_line(
+ [origin_2d[0], new_point[0]],
+ [origin_2d[1], new_point[1]],
+ color=axis_color,
+ linewidth=2,
+ )
+ self.draw_line(
+ [new_point[0] - arrow_p0[0], new_point[0]],
+ [new_point[1] - arrow_p0[1], new_point[1]],
+ color=axis_color,
+ linewidth=2,
+ )
+ self.draw_line(
+ [new_point[0] - arrow_p1[0], new_point[0]],
+ [new_point[1] - arrow_p1[1], new_point[1]],
+ color=axis_color,
+ linewidth=2,
+ )
+ self.draw_polygon(
+ circle_p_2d, color=axis_color, edge_color=axis_color, alpha=0.0
+ )
+
+ mtype = 1
+
+ if not mtype:
+ self.draw_circle(origin_2d, color=_COLORS_LEVEL[0], radius=5)
+
+ cat_id = 0
+ labels = [
+ "{}".format(i) + ("|crowd" if a.get("iscrowd", 0) else "")
+ for i, a in zip(labels, [anno])
+ ]
+ # self.overlay_instances(
+ # labels=labels, boxes=None, masks=None, assigned_colors=[_COLORS_CAT[cat_id*2+mtype]]
+ # )
+
+ return self.output
+
+ def draw_pred_instance(self, prediction, d, match, is_real=False, intrinsic_matrix=None, line_length=1, no_mask=False, diagonal_length=-1):
+ if "annotations" in d:
+ boxes = prediction.get("bbox", None)
+
+ anno = None
+ annos = d["annotations"]
+ max_iou = -1
+ if not len(annos):
+ return None
+
+ for gt_anno in annos:
+ iou = get_iou(gt_anno["bbox"], boxes)
+ if np.isnan(iou):
+ return False
+ if iou > max_iou:
+ max_iou = iou
+ anno = gt_anno
+ else:
+ max_iou = -1
+ boxes = prediction.get("bbox", None)
+ anno = d
+ boxes = prediction.get("bbox", None)
+ iou = get_iou(anno["bbox"], boxes)
+ if iou > max_iou:
+ max_iou = iou
+
+ boxes = [BoxMode.convert(boxes, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)]
+
+ # Based on the motion type, decide whether to visualize the predicted motion origin or the gt motion origin
+ # For a translation joint, the motion origin is meaningless
+ pred_type = prediction["mtype"]
+ if pred_type == 1:
+ pred_origin = anno["motion"]["current_origin"]
+ else:
+ pred_origin = prediction["morigin"]
+
+ # Prepare the predicted origin and predicted axis
+ pred_origin_4d = pred_origin + [1]
+ pred_origin_2d = camera_to_image(pred_origin_4d, is_real, intrinsic_matrix)
+ pred_axis = np.array(prediction["maxis"])
+ pred_axis = list(pred_axis / norm(pred_axis))
+ pred_new_point = list(np.array(pred_origin) + line_length * np.array(pred_axis))
+ pred_new_point = pred_new_point + [1]
+ pred_new_point = camera_to_image(pred_new_point, is_real, intrinsic_matrix)
+
+ # Prepare the gt origin and gt axis
+ gt_origin = anno["motion"]["current_origin"]
+ gt_origin_4d = gt_origin + [1]
+ gt_origin_2d = camera_to_image(gt_origin_4d, is_real, intrinsic_matrix)
+ gt_axis = anno["motion"][
+ "current_axis"
+ ] # gt_axis has been normalized in the annotation
+ gt_new_point = list(np.array(gt_origin) + line_length * np.array(gt_axis))
+ gt_new_point = gt_new_point + [1]
+ gt_new_point = camera_to_image(gt_new_point, is_real, intrinsic_matrix)
+
+ # Calculate the axis and origin errors to determine the colors used to visualize the axis and origin
+ axis_diff = (
+ np.arccos(
+ np.abs(
+ np.dot(np.array(gt_axis), np.array(pred_axis))
+ / (norm(pred_axis) * norm(gt_axis))
+ )
+ )
+ / np.pi
+ * 180.0
+ )
+ if axis_diff <= 5:
+ axis_color = _COLORS_LEVEL[0]
+ elif axis_diff > 5 and axis_diff <= 10:
+ axis_color = _COLORS_LEVEL[1]
+ elif axis_diff > 10:
+ axis_color = _COLORS_LEVEL[2]
+
+ if diagonal_length == -1:
+ raise ValueError("diagonal length error")
+
+ origin_diff = np.linalg.norm(
+ np.cross(np.array(pred_origin) - np.array(gt_origin), np.array(gt_axis))
+ ) / np.linalg.norm(gt_axis) / diagonal_length
+ if origin_diff <= 0.1:
+ origin_color = _COLORS_LEVEL[0]
+ elif origin_diff > 0.1 and origin_diff <= 0.25:
+ origin_color = _COLORS_LEVEL[1]
+ elif origin_diff > 0.25:
+ origin_color = _COLORS_LEVEL[2]
+
+ # Visualize gt
+ gt_color = np.array([0, 0, 255]) / 255
+ gt_arrow_p0 = rotatePoint(
+ gt_new_point[0] - gt_origin_2d[0],
+ gt_new_point[1] - gt_origin_2d[1],
+ 30,
+ 0.1,
+ )
+ gt_arrow_p1 = rotatePoint(
+ gt_new_point[0] - gt_origin_2d[0],
+ gt_new_point[1] - gt_origin_2d[1],
+ -30,
+ 0.1,
+ )
+ gt_circle_p = circlePoints(gt_axis, 0.1, 50)
+ gt_circle_p = line_length * gt_circle_p + np.repeat(
+ np.asarray(gt_origin)[:, np.newaxis], 50, axis=1
+ )
+ gt_circle_p = gt_circle_p.transpose()
+ gt_circle_p_2d = np.asarray([camera_to_image(p, is_real, intrinsic_matrix) for p in gt_circle_p])
+ self.draw_line(
+ [gt_origin_2d[0], gt_new_point[0]],
+ [gt_origin_2d[1], gt_new_point[1]],
+ color=gt_color,
+ linewidth=2,
+ )
+ self.draw_line(
+ [gt_new_point[0] - gt_arrow_p0[0], gt_new_point[0]],
+ [gt_new_point[1] - gt_arrow_p0[1], gt_new_point[1]],
+ color=gt_color,
+ linewidth=2,
+ )
+ self.draw_line(
+ [gt_new_point[0] - gt_arrow_p1[0], gt_new_point[0]],
+ [gt_new_point[1] - gt_arrow_p1[1], gt_new_point[1]],
+ color=gt_color,
+ linewidth=2,
+ )
+ self.draw_polygon(
+ gt_circle_p_2d, color=gt_color, edge_color=gt_color, alpha=0.0
+ )
+ if pred_type == 0:
+ # self.draw_text("origin_error: {:.3f}".format(origin_diff), (origin_2d[0][0], origin_2d[0][1]-10*text_y_offset), color="c")
+ self.draw_circle(gt_origin_2d, color=gt_color, radius=5)
+
+ # Visualize the predicted axis
+ pred_arrow_p0 = rotatePoint(
+ pred_new_point[0] - pred_origin_2d[0],
+ pred_new_point[1] - pred_origin_2d[1],
+ 30,
+ 0.1,
+ )
+ pred_arrow_p1 = rotatePoint(
+ pred_new_point[0] - pred_origin_2d[0],
+ pred_new_point[1] - pred_origin_2d[1],
+ -30,
+ 0.1,
+ )
+ pred_circle_p = circlePoints(pred_axis, 0.1, 50)
+ pred_circle_p = line_length * pred_circle_p + np.repeat(
+ np.asarray(pred_origin)[:, np.newaxis], 50, axis=1
+ )
+ pred_circle_p = pred_circle_p.transpose()
+ pred_circle_p_2d = np.asarray([camera_to_image(p, is_real, intrinsic_matrix) for p in pred_circle_p])
+ # text_y_offset = 1 if (new_point[1]-origin_2d[0][1]) > 0 else -1
+ # self.draw_text("axis_error: {:.3f}".format(axis_diff), (origin_2d[0][0], origin_2d[0][1]-20*text_y_offset), color="tan")
+ self.draw_line(
+ [pred_origin_2d[0], pred_new_point[0]],
+ [pred_origin_2d[1], pred_new_point[1]],
+ color=axis_color,
+ linewidth=2,
+ )
+ self.draw_line(
+ [pred_new_point[0] - pred_arrow_p0[0], pred_new_point[0]],
+ [pred_new_point[1] - pred_arrow_p0[1], pred_new_point[1]],
+ color=axis_color,
+ linewidth=2,
+ )
+ self.draw_line(
+ [pred_new_point[0] - pred_arrow_p1[0], pred_new_point[0]],
+ [pred_new_point[1] - pred_arrow_p1[1], pred_new_point[1]],
+ color=axis_color,
+ linewidth=2,
+ )
+ self.draw_polygon(
+ pred_circle_p_2d, color=axis_color, edge_color=axis_color, alpha=0.0
+ )
+ if pred_type == 0:
+ # self.draw_text("origin_error: {:.3f}".format(origin_diff), (origin_2d[0][0], origin_2d[0][1]-10*text_y_offset), color="c")
+ self.draw_circle(pred_origin_2d, color=origin_color, radius=5)
+
+ # Assign color to the segmentation
+ cat_id = prediction.get("category_id", None)
+ color_cat = _COLORS_CAT[cat_id * 2 + pred_type]
+
+ scores = [prediction.get("score", None)]
+ classes = [prediction.get("category_id", None)]
+ labels = _create_text_labels_motion(
+ classes,
+ scores,
+ self.metadata.get("thing_classes", None),
+ MOTION_TYPE[pred_type],
+ )
+ keypoints = prediction.get("keypoints", None)
+ if prediction.get("segmentation"):
+ import pycocotools.mask as mask_util
+
+ masks = [prediction.get("segmentation")]
+ else:
+ masks = None
+
+ if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
+ "thing_colors"
+ ):
+ colors = [
+ self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
+ for c in classes
+ ]
+ alpha = 0.8
+ else:
+ colors = [color_cat]
+ alpha = 0.5
+
+ if self._instance_mode == ColorMode.IMAGE_BW:
+ self.output.img = self._create_grayscale_image(
+ (mask_util.decode(prediction.get("segmentation")).any() > 0).numpy()
+ )
+ alpha = 0.3
+ # import pdb
+ # pdb.set_trace()
+ match["iou"] = max_iou
+ # Add the gt information
+ match["gt"] = {}
+ match["gt"]["partId"] = anno["motion"]["partId"]
+ match["gt"]["label"] = anno["motion"]["part_label"]
+ match["gt"]["type"] = anno["motion"]["type"]
+ match["gt"]["category_id"] = anno["category_id"]
+ match["gt"]["origin"] = gt_origin
+ match["gt"]["axis"] = gt_axis
+ # add the prediction information
+ match["pred"] = {}
+ match["pred"]["score"] = scores[0]
+ match["pred"]["type"] = pred_type
+ match["pred"]["category_id"] = cat_id
+ match["pred"]["origin"] = pred_origin
+ match["pred"]["axis"] = pred_axis
+ # add additional information
+ match["axis_error"] = axis_diff
+ match["origin_error"] = origin_diff
+ match["match"] = (
+ int(pred_type)
+ == int(
+ list(MOTION_TYPE.keys())[
+ list(MOTION_TYPE.values()).index(anno["motion"]["type"])
+ ]
+ )
+ ) and (cat_id == anno["category_id"])
+
+ if no_mask:
+ masks = None
+
+ self.overlay_instances(
+ masks=masks,
+ boxes=boxes,
+ labels=labels,
+ keypoints=keypoints,
+ assigned_colors=colors,
+ alpha=alpha,
+ )
+ return self.output
+
+ def draw_pred_only(self, prediction, prob):
+ scores = prediction.scores if prediction.has("scores") else None
+ if scores.numpy()[0] < prob:
+ return None
+
+ origins = list(prediction.morigin.numpy())
+ origins = [list(origin) for origin in origins]
+
+ axises = list(prediction.maxis.numpy())
+ axises = [list(axis) for axis in axises]
+
+ types = list(prediction.mtype.numpy())
+ classes = prediction.pred_classes if prediction.has("pred_classes") else None
+
+ color_cat = _COLORS_CAT[classes.numpy()[0] * 2 + types[0]]
+
+ origins_4d = [origin[:] + [1] for origin in origins]
+ origin_2d = [camera_to_image(origin) for origin in origins_4d]
+
+ new_point = list(np.array(origins[0]) + np.array(axises[0]))
+ new_point = new_point[:] + [1]
+ new_point = camera_to_image(new_point)
+
+ axis_color = _COLORS_LEVEL[0]
+ origin_color = _COLORS_LEVEL[0]
+
+ arrow_p0 = rotatePoint(
+ new_point[0] - origin_2d[0][0], new_point[1] - origin_2d[0][1], 30, 0.1
+ )
+ arrow_p1 = rotatePoint(
+ new_point[0] - origin_2d[0][0], new_point[1] - origin_2d[0][1], -30, 0.1
+ )
+ circle_p = circlePoints(axises[0], 0.1, 50)
+ circle_p = circle_p + np.repeat(
+ np.asarray(origins[0])[:, np.newaxis], 50, axis=1
+ )
+ circle_p = circle_p.transpose()
+ circle_p_2d = np.asarray([camera_to_image(p) for p in circle_p])
+
+ # text_y_offset = 1 if (new_point[1]-origin_2d[0][1]) > 0 else -1
+ # self.draw_text("axis_error: {:.3f}".format(axis_diff), (origin_2d[0][0], origin_2d[0][1]-20*text_y_offset), color="tan")
+ self.draw_line(
+ [origin_2d[0][0], new_point[0]],
+ [origin_2d[0][1], new_point[1]],
+ color=axis_color,
+ linewidth=2,
+ )
+ self.draw_line(
+ [new_point[0] - arrow_p0[0], new_point[0]],
+ [new_point[1] - arrow_p0[1], new_point[1]],
+ color=axis_color,
+ linewidth=2,
+ )
+ self.draw_line(
+ [new_point[0] - arrow_p1[0], new_point[0]],
+ [new_point[1] - arrow_p1[1], new_point[1]],
+ color=axis_color,
+ linewidth=2,
+ )
+ self.draw_polygon(
+ circle_p_2d, color=axis_color, edge_color=axis_color, alpha=0.0
+ )
+
+ if types[0] == 0:
+ # self.draw_text("origin_error: {:.3f}".format(origin_diff), (origin_2d[0][0], origin_2d[0][1]-10*text_y_offset), color="c")
+ self.draw_circle(origin_2d[0], color=origin_color, radius=5)
+
+ boxes = prediction.pred_boxes if prediction.has("pred_boxes") else None
+ labels = _create_text_labels_motion(
+ classes,
+ scores,
+ self.metadata.get("thing_classes", None),
+ MOTION_TYPE[types[0]],
+ )
+ keypoints = (
+ prediction.pred_keypoints if prediction.has("pred_keypoints") else None
+ )
+
+ if prediction.has("pred_masks"):
+ masks = np.asarray(prediction.pred_masks)
+ masks = [
+ GenericMask(x, self.output.height, self.output.width) for x in masks
+ ]
+ else:
+ masks = None
+
+ if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
+ "thing_colors"
+ ):
+ colors = [
+ self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
+ for c in classes
+ ]
+ alpha = 0.8
+ else:
+ colors = [color_cat]
+ alpha = 0.5
+
+ if self._instance_mode == ColorMode.IMAGE_BW:
+ self.output.img = self._create_grayscale_image(
+ (prediction.pred_masks.any(dim=0) > 0).numpy()
+ )
+ alpha = 0.3
+
+ self.overlay_instances(
+ masks=masks,
+ boxes=boxes,
+ labels=labels,
+ keypoints=keypoints,
+ assigned_colors=colors,
+ alpha=alpha,
+ )
+ return self.output
+
+
+ def _create_text_labels_motion(classes, scores, class_names, motion_type):
+ """
+ Args:
+ classes (list[int] or None):
+ scores (list[float] or None):
+ class_names (list[str] or None):
+
+ Returns:
+ list[str] or None
+ """
+ labels = None
+ if classes is not None and class_names is not None and len(class_names) > 1:
+ labels = [class_names[i] for i in classes]
+ labels = [label + "_" + motion_type for label in labels]
+ if scores is not None:
+ if labels is None:
+ labels = ["{:.0f}%".format(s * 100) for s in scores]
+ else:
+ labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
+ return labels
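As a reference for the color thresholds above, here is a self-contained sketch of the two error metrics computed in draw_pred_instance (numpy only; the sample values are made up for illustration):

import numpy as np
from numpy.linalg import norm

gt_axis = np.array([0.0, 1.0, 0.0])
pred_axis = np.array([0.05, 0.99, 0.0])
gt_origin = np.array([0.1, 0.2, -1.0])
pred_origin = np.array([0.15, 0.2, -1.05])
diagonal_length = 1.5  # diagonal of the part's 3D bounding box (hypothetical value)

# Angle between the two axes in degrees, ignoring direction (hence the abs).
axis_error = np.degrees(np.arccos(np.abs(np.dot(gt_axis, pred_axis)) / (norm(gt_axis) * norm(pred_axis))))

# Distance from the predicted origin to the gt axis line, normalized by the diagonal.
origin_error = norm(np.cross(pred_origin - gt_origin, gt_axis)) / norm(gt_axis) / diagonal_length

# axis_error <= 5 degrees and origin_error <= 0.1 are drawn in green; larger errors in orange/red.
print(axis_error, origin_error)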
mask2former/utils/tranform.py ADDED
@@ -0,0 +1,169 @@
+ import torch
+ from torch.nn import functional as F
+
+ import numpy as np
+ from scipy.spatial.distance import cdist, euclidean
+
+ def geometric_median(X, eps=1e-5):
+ y = np.mean(X, 0)
+
+ while True:
+ D = cdist(X, [y])
+ nonzeros = (D != 0)[:, 0]
+
+ Dinv = 1 / D[nonzeros]
+ Dinvs = np.sum(Dinv)
+ W = Dinv / Dinvs
+ T = np.sum(W * X[nonzeros], 0)
+
+ num_zeros = len(X) - np.sum(nonzeros)
+ if num_zeros == 0:
+ y1 = T
+ elif num_zeros == len(X):
+ return y
+ else:
+ R = (T - y) * Dinvs
+ r = np.linalg.norm(R)
+ rinv = 0 if r == 0 else num_zeros/r
+ y1 = max(0, 1-rinv)*T + min(1, rinv)*y
+
+ if euclidean(y, y1) < eps:
+ return y1
+
+ y = y1
+
+ # Transformation code from pytorch3d https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
+ """
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+ using Gram--Schmidt orthogonalization per Section B of [1].
+ Args:
+ d6: 6D rotation representation, of size (*, 6)
+
+ Returns:
+ batch of rotation matrices of size (*, 3, 3)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+
+ a1, a2 = d6[..., :3], d6[..., 3:]
+ b1 = F.normalize(a1, dim=-1)
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+ b2 = F.normalize(b2, dim=-1)
+ b3 = torch.cross(b1, b2, dim=-1)
+ return torch.stack((b1, b2, b3), dim=-2)
+
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
+ by dropping the last row. Note that 6D representation is not unique.
+ Args:
+ matrix: batch of rotation matrices of size (*, 3, 3)
+
+ Returns:
+ 6D rotation representation, of size (*, 6)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ batch_dim = matrix.size()[:-2]
+ return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
+
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ return ret
+
+
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+
+ return quat_candidates[
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
+ ].reshape(batch_dim + (4,))
+
+ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
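A quick round-trip sanity check for the conversion helpers above can look like the following sketch (not part of the commit; the import path is assumed, and note the module file is spelled "tranform"):

import torch
from mask2former.utils.tranform import (  # path assumed
    matrix_to_quaternion, quaternion_to_matrix, matrix_to_rotation_6d, rotation_6d_to_matrix,
)

q = torch.randn(8, 4)
q = q / q.norm(dim=-1, keepdim=True)   # random unit quaternions
R = quaternion_to_matrix(q)            # (8, 3, 3) rotation matrices

# Both representations should reproduce the original rotations up to numerical precision.
assert torch.allclose(quaternion_to_matrix(matrix_to_quaternion(R)), R, atol=1e-4)
assert torch.allclose(rotation_6d_to_matrix(matrix_to_rotation_6d(R)), R, atol=1e-4)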
pre-requirements.txt ADDED
@@ -0,0 +1,6 @@
+ numpy==1.25.2
+ Pillow==10.0.1
+ torch==2.0.1
+ torchaudio==2.0.2
+ torchvision==0.15.2
+ urllib3==1.26.16
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ h5py==3.9.0
+ imageio==2.31.3
+ open3d==0.17.0
+ opencv-python==4.8.0.76
+ pandas==2.1.0
+ pycocotools==2.0.7
+ scikit-image==0.21.0
+ scikit-learn==1.3.0
+ scipy==1.11.2
+ timm==0.9.7
+ detectron2 @ git+https://github.com/facebookresearch/detectron2.git@fc9c33b1f6e5d4c37bbb46dde19af41afc1ddb2a