seokju cho committed on
Commit
f8f62f3
1 Parent(s): 7722584

initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. INSTALL.md +20 -0
  2. R-101.pkl +3 -0
  3. README.md +48 -12
  4. app.py +130 -0
  5. assets/fig1.png +0 -0
  6. cat_seg/__init__.py +19 -0
  7. cat_seg/__pycache__/__init__.cpython-38.pyc +0 -0
  8. cat_seg/__pycache__/cat_sam_model.cpython-38.pyc +0 -0
  9. cat_seg/__pycache__/cat_seg_model.cpython-38.pyc +0 -0
  10. cat_seg/__pycache__/cat_seg_panoptic.cpython-38.pyc +0 -0
  11. cat_seg/__pycache__/config.cpython-38.pyc +0 -0
  12. cat_seg/__pycache__/pancat_model.cpython-38.pyc +0 -0
  13. cat_seg/__pycache__/test_time_augmentation.cpython-38.pyc +0 -0
  14. cat_seg/cat_seg_model.py +386 -0
  15. cat_seg/config.py +93 -0
  16. cat_seg/data/__init__.py +2 -0
  17. cat_seg/data/__pycache__/__init__.cpython-38.pyc +0 -0
  18. cat_seg/data/dataset_mappers/__init__.py +1 -0
  19. cat_seg/data/dataset_mappers/__pycache__/__init__.cpython-38.pyc +0 -0
  20. cat_seg/data/dataset_mappers/__pycache__/detr_panoptic_dataset_mapper.cpython-38.pyc +0 -0
  21. cat_seg/data/dataset_mappers/__pycache__/mask_former_panoptic_dataset_mapper.cpython-38.pyc +0 -0
  22. cat_seg/data/dataset_mappers/__pycache__/mask_former_semantic_dataset_mapper.cpython-38.pyc +0 -0
  23. cat_seg/data/dataset_mappers/detr_panoptic_dataset_mapper.py +180 -0
  24. cat_seg/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py +165 -0
  25. cat_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py +186 -0
  26. cat_seg/data/datasets/__init__.py +8 -0
  27. cat_seg/data/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  28. cat_seg/data/datasets/__pycache__/register_ade20k_150.cpython-38.pyc +0 -0
  29. cat_seg/data/datasets/__pycache__/register_ade20k_847.cpython-38.pyc +0 -0
  30. cat_seg/data/datasets/__pycache__/register_ade_panoptic.cpython-38.pyc +0 -0
  31. cat_seg/data/datasets/__pycache__/register_coco_panoptic.cpython-38.pyc +0 -0
  32. cat_seg/data/datasets/__pycache__/register_coco_stuff.cpython-38.pyc +0 -0
  33. cat_seg/data/datasets/__pycache__/register_pascal_20.cpython-38.pyc +0 -0
  34. cat_seg/data/datasets/__pycache__/register_pascal_59.cpython-38.pyc +0 -0
  35. cat_seg/data/datasets/__pycache__/register_pascal_context.cpython-38.pyc +0 -0
  36. cat_seg/data/datasets/register_ade20k_150.py +28 -0
  37. cat_seg/data/datasets/register_ade20k_847.py +0 -0
  38. cat_seg/data/datasets/register_coco_stuff.py +216 -0
  39. cat_seg/data/datasets/register_pascal_20.py +53 -0
  40. cat_seg/data/datasets/register_pascal_59.py +81 -0
  41. cat_seg/modeling/__init__.py +3 -0
  42. cat_seg/modeling/__pycache__/__init__.cpython-38.pyc +0 -0
  43. cat_seg/modeling/__pycache__/criterion.cpython-38.pyc +0 -0
  44. cat_seg/modeling/__pycache__/matcher.cpython-38.pyc +0 -0
  45. cat_seg/modeling/backbone/__init__.py +1 -0
  46. cat_seg/modeling/backbone/__pycache__/__init__.cpython-38.pyc +0 -0
  47. cat_seg/modeling/backbone/__pycache__/image_encoder.cpython-38.pyc +0 -0
  48. cat_seg/modeling/backbone/__pycache__/swin.cpython-38.pyc +0 -0
  49. cat_seg/modeling/backbone/swin.py +768 -0
  50. cat_seg/modeling/heads/__init__.py +1 -0
INSTALL.md ADDED
@@ -0,0 +1,20 @@
+ ## Installation
+
+ ### Requirements
+ - Linux or macOS with Python ≥ 3.6
+ - PyTorch ≥ 1.7 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
+ Install them together at [pytorch.org](https://pytorch.org) to make sure of this. Also check that the
+ PyTorch version matches the one required by Detectron2.
+ - Detectron2: follow the [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html).
+ - OpenCV is optional, but needed for the demo and visualization.
+ - `pip install -r requirements.txt`
+
+ An example installation is shown below:
+
+ ```
+ git clone https://github.com/~~~/CAT-Seg.git
+ cd CAT-Seg
+ conda create -n catseg python=3.8
+ conda activate catseg
+ pip install -r requirements.txt
+ ```
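After installing, a quick sanity check is to import the core dependencies and print their versions. This is an optional, minimal sketch (the script name is just a suggestion, not part of the repository):

```python
# check_env.py -- optional sanity check for the environment set up above
import torch
import torchvision
import detectron2

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("detectron2:", detectron2.__version__)
print("CUDA available:", torch.cuda.is_available())
```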
R-101.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1156c77bff95ecb027060b5c83391b45bf159acd7f5bf7eacb656be0c1f0ab55
+ size 178666803
README.md CHANGED
@@ -1,12 +1,48 @@
- ---
- title: SAM CAT Seg
- emoji: 📚
- colorFrom: purple
- colorTo: blue
- sdk: gradio
- sdk_version: 3.29.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # CAT-Seg🐱: Cost Aggregation for Open-Vocabulary Semantic Segmentation
+
+ This is our official implementation of CAT-Seg🐱!
+
+ [[arXiv](#)] [[Project](#)]<br>
+ by [Seokju Cho](https://seokju-cho.github.io/)\*, [Heeseong Shin](https://github.com/hsshin98)\*, [Sunghwan Hong](https://sunghwanhong.github.io), Seungjun An, Seungjun Lee, [Anurag Arnab](https://anuragarnab.github.io), [Paul Hongsuck Seo](https://phseo.github.io), [Seungryong Kim](https://cvlab.korea.ac.kr)
+
+
+ ## Introduction
+ ![](assets/fig1.png)
+ We introduce cost aggregation to open-vocabulary semantic segmentation, which jointly aggregates both image and text modalities within the matching cost.
+
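To make the idea of the matching cost concrete, the cost volume can be viewed as the cosine similarity between every dense image embedding and every class text embedding. The sketch below is not the actual CAT-Seg code; the tensor names and sizes are made up for illustration (the 24×24 grid mirrors the FEATURE_RESOLUTION default in cat_seg/config.py):

```python
import torch
import torch.nn.functional as F

# illustrative sizes: D-dim CLIP features on a 24x24 grid, C candidate classes
D, H, W, C = 512, 24, 24, 171
image_feats = F.normalize(torch.randn(D, H, W), dim=0)  # dense image embeddings
text_feats = F.normalize(torch.randn(C, D), dim=1)      # one embedding per class prompt

# cost volume: cosine similarity of each spatial location to each class
cost = torch.einsum("dhw,cd->chw", image_feats, text_feats)  # shape (C, H, W)
print(cost.shape)  # torch.Size([171, 24, 24])
```

CAT-Seg then aggregates this cost volume to produce the final segmentation.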
+ ## Installation
+ Install the required packages.
+
+ ```bash
+ conda create --name catseg python=3.8
+ conda activate catseg
+ conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
+ pip install -r requirements.txt
+ ```
+
+ ## Data Preparation
+
+
+ ## Training
+ ### Preparation
+ you have to blah
+ ### Training script
+ ```bash
+ python train.py --config configs/eval/{a847 | pc459 | a150 | pc59 | pas20 | pas20b}.yaml
+ ```
+
+ ## Evaluation
+ ```bash
+ python eval.py --config configs/eval/{a847 | pc459 | a150 | pc59 | pas20 | pas20b}.yaml
+ ```
+
+ ## Citing CAT-Seg🐱 :pray:
+
+ ```BibTeX
+ @article{liang2022open,
+ title={Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP},
+ author={Liang, Feng and Wu, Bichen and Dai, Xiaoliang and Li, Kunpeng and Zhao, Yinan and Zhang, Hang and Zhang, Peizhao and Vajda, Peter and Marculescu, Diana},
+ journal={arXiv preprint arXiv:2210.04150},
+ year={2022}
+ }
+ ```
app.py ADDED
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
+ import argparse
+ import glob
+ import multiprocessing as mp
+ import os
+ #os.environ["CUDA_VISIBLE_DEVICES"] = ""
+ try:
+     import detectron2
+ except ModuleNotFoundError:
+     os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
+
+ try:
+     import segment_anything
+ except ModuleNotFoundError:
+     os.system('pip install git+https://github.com/facebookresearch/segment-anything.git')
+
+ # fmt: off
+ import sys
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
+ # fmt: on
+
+ import tempfile
+ import time
+ import warnings
+
+ import cv2
+ import numpy as np
+ import tqdm
+
+ from detectron2.config import get_cfg
+ from detectron2.data.detection_utils import read_image
+ from detectron2.projects.deeplab import add_deeplab_config
+ from detectron2.utils.logger import setup_logger
+
+ from cat_seg import add_cat_seg_config
+ from demo.predictor import VisualizationDemo
+ import gradio as gr
+ import torch
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
+
+ # constants
+ WINDOW_NAME = "MaskFormer demo"
+
+
+ def setup_cfg(args):
+     # load config from file and command-line arguments
+     cfg = get_cfg()
+     add_deeplab_config(cfg)
+     add_cat_seg_config(cfg)
+     cfg.merge_from_file(args.config_file)
+     cfg.merge_from_list(args.opts)
+     if torch.cuda.is_available():
+         cfg.MODEL.DEVICE = "cuda"
+     cfg.freeze()
+     return cfg
+
+
+ def get_parser():
+     parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
+     parser.add_argument(
+         "--config-file",
+         default="configs/vitl_swinb_384.yaml",
+         metavar="FILE",
+         help="path to config file",
+     )
+     parser.add_argument(
+         "--input",
+         nargs="+",
+         help="A list of space separated input images; "
+         "or a single glob pattern such as 'directory/*.jpg'",
+     )
+     parser.add_argument(
+         "--opts",
+         help="Modify config options using the command-line 'KEY VALUE' pairs",
+         default=(
+             [
+                 "MODEL.WEIGHTS", "model_final_cls.pth",
+                 "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
+                 "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
+                 "TEST.SLIDING_WINDOW", "True",
+                 "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
+                 "MODEL.PROMPT_ENSEMBLE_TYPE", "single",
+                 "MODEL.DEVICE", "cpu",
+             ]),
+         nargs=argparse.REMAINDER,
+     )
+     return parser
+
+ def save_masks(preds, text):
+     preds = preds['sem_seg'].argmax(dim=0).cpu().numpy() # C H W
+     for i, t in enumerate(text):
+         dir = f"mask_{t}.png"
+         mask = preds == i
+         cv2.imwrite(dir, mask * 255)
+
+ def predict(image, text, model_type):
+     #import pdb; pdb.set_trace()
+     #use_sam = True #
+     use_sam = model_type != "CAT-Seg"
+
+     predictions, visualized_output = demo.run_on_image(image, text, use_sam)
+     #save_masks(predictions, text.split(','))
+     canvas = fc(visualized_output.fig)
+     canvas.draw()
+     out = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(canvas.get_width_height()[::-1] + (3,))
+
+     return out[..., ::-1]
+
+ if __name__ == "__main__":
+     args = get_parser().parse_args()
+     cfg = setup_cfg(args)
+     global demo
+     demo = VisualizationDemo(cfg)
+
+     iface = gr.Interface(
+         fn=predict,
+         inputs=[gr.Image(), gr.Textbox(placeholder='background, cat, person'), ], #gr.Radio(["CAT-Seg", "Segment Anycat"], value="CAT-Seg")],
+         outputs="image",
+         description="""## Segment Anything with CAT-Seg!
+ Welcome to the Segment Anything with CAT-Seg demo!
+
+ In this demo, we combine CAT-Seg, a state-of-the-art open-vocabulary semantic segmentation model, with SAM (Segment Anything) to semantically label the mask predictions from SAM.
+
+ Please note that this is an optimized version of the full model, so its performance may be limited compared to the full model.
+
+ Also, the demo might run on a CPU depending on the demand, so it may take a little time to process your image.
+
+ To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!""")
+     iface.launch()
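For reference, once the `__main__` block above has built `demo`, the `predict` function can also be called outside the Gradio UI. A hypothetical snippet (the image path is a placeholder):

```python
import cv2

# assumes the setup in __main__ above has already run, so `demo` exists
image = cv2.imread("example.jpg")[:, :, ::-1].copy()  # cv2 loads BGR; flip to RGB like the Gradio image input
labels = "background, cat, person"                    # comma-separated category names, as in the Textbox placeholder
out = predict(image, labels, "CAT-Seg")               # visualization as an H x W x 3 uint8 array
print(out.shape, out.dtype)
```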
assets/fig1.png ADDED
cat_seg/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from . import data  # register all new datasets
+ from . import modeling
+
+ # config
+ from .config import add_cat_seg_config
+
+ # dataset loading
+ from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper
+ from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import (
+     MaskFormerPanopticDatasetMapper,
+ )
+ from .data.dataset_mappers.mask_former_semantic_dataset_mapper import (
+     MaskFormerSemanticDatasetMapper,
+ )
+
+ # models
+ from .cat_seg_model import CATSeg
+ from .test_time_augmentation import SemanticSegmentorWithTTA
cat_seg/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (693 Bytes). View file
 
cat_seg/__pycache__/cat_sam_model.cpython-38.pyc ADDED
Binary file (13.7 kB). View file
 
cat_seg/__pycache__/cat_seg_model.cpython-38.pyc ADDED
Binary file (12.6 kB). View file
 
cat_seg/__pycache__/cat_seg_panoptic.cpython-38.pyc ADDED
Binary file (10 kB). View file
 
cat_seg/__pycache__/config.cpython-38.pyc ADDED
Binary file (2.39 kB). View file
 
cat_seg/__pycache__/pancat_model.cpython-38.pyc ADDED
Binary file (11.4 kB). View file
 
cat_seg/__pycache__/test_time_augmentation.cpython-38.pyc ADDED
Binary file (4.41 kB). View file
 
cat_seg/cat_seg_model.py ADDED
@@ -0,0 +1,386 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from detectron2.config import configurable
9
+ from detectron2.data import MetadataCatalog
10
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
11
+ from detectron2.modeling.backbone import Backbone
12
+ from detectron2.modeling.postprocessing import sem_seg_postprocess
13
+ from detectron2.structures import ImageList
14
+ from detectron2.utils.memory import _ignore_torch_cuda_oom
15
+
16
+ import numpy as np
17
+ from einops import rearrange
18
+ from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
19
+
20
+ @META_ARCH_REGISTRY.register()
21
+ class CATSeg(nn.Module):
22
+ @configurable
23
+ def __init__(
24
+ self,
25
+ *,
26
+ backbone: Backbone,
27
+ sem_seg_head: nn.Module,
28
+ size_divisibility: int,
29
+ pixel_mean: Tuple[float],
30
+ pixel_std: Tuple[float],
31
+ clip_pixel_mean: Tuple[float],
32
+ clip_pixel_std: Tuple[float],
33
+ train_class_json: str,
34
+ test_class_json: str,
35
+ sliding_window: bool,
36
+ clip_finetune: str,
37
+ backbone_multiplier: float,
38
+ clip_pretrained: str,
39
+ ):
40
+ """
41
+ Args:
42
+ backbone: a backbone module, must follow detectron2's backbone interface
43
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
44
+ """
45
+ super().__init__()
46
+ self.backbone = backbone
47
+ self.sem_seg_head = sem_seg_head
48
+ if size_divisibility < 0:
49
+ size_divisibility = self.backbone.size_divisibility
50
+ self.size_divisibility = size_divisibility
51
+
52
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
53
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
54
+ self.register_buffer("clip_pixel_mean", torch.Tensor(clip_pixel_mean).view(-1, 1, 1), False)
55
+ self.register_buffer("clip_pixel_std", torch.Tensor(clip_pixel_std).view(-1, 1, 1), False)
56
+
57
+ self.train_class_json = train_class_json
58
+ self.test_class_json = test_class_json
59
+
60
+ self.clip_finetune = clip_finetune
61
+ for name, params in self.sem_seg_head.predictor.clip_model.named_parameters():
62
+ if "visual" in name:
63
+ if clip_finetune == "prompt":
64
+ params.requires_grad = True if "prompt" in name else False
65
+ elif clip_finetune == "attention":
66
+ params.requires_grad = True if "attn" in name or "position" in name else False
67
+ elif clip_finetune == "full":
68
+ params.requires_grad = True
69
+ else:
70
+ params.requires_grad = False
71
+ else:
72
+ params.requires_grad = False
73
+
74
+ finetune_backbone = backbone_multiplier > 0.
75
+ for name, params in self.backbone.named_parameters():
76
+ if "norm0" in name:
77
+ params.requires_grad = False
78
+ else:
79
+ params.requires_grad = finetune_backbone
80
+
81
+ self.sliding_window = sliding_window
82
+ self.clip_resolution = (384, 384) if clip_pretrained == "ViT-B/16" else (336, 336)
83
+ self.sequential = False
84
+
85
+ self.use_sam = False
86
+ self.sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to(self.device)
87
+
88
+ amg_kwargs = {
89
+ "points_per_side": 32,
90
+ "points_per_batch": None,
91
+ #"pred_iou_thresh": 0.0,
92
+ #"stability_score_thresh": 0.0,
93
+ "stability_score_offset": None,
94
+ "box_nms_thresh": None,
95
+ "crop_n_layers": None,
96
+ "crop_nms_thresh": None,
97
+ "crop_overlap_ratio": None,
98
+ "crop_n_points_downscale_factor": None,
99
+ "min_mask_region_area": None,
100
+ }
101
+ amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
102
+ self.mask = SamAutomaticMaskGenerator(self.sam, output_mode="binary_mask", **amg_kwargs)
103
+ self.overlap_threshold = 0.8
104
+ self.panoptic_on = False
105
+
106
+ @classmethod
107
+ def from_config(cls, cfg):
108
+ backbone = build_backbone(cfg)
109
+ sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
110
+
111
+ return {
112
+ "backbone": backbone,
113
+ "sem_seg_head": sem_seg_head,
114
+ "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
115
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
116
+ "pixel_std": cfg.MODEL.PIXEL_STD,
117
+ "clip_pixel_mean": cfg.MODEL.CLIP_PIXEL_MEAN,
118
+ "clip_pixel_std": cfg.MODEL.CLIP_PIXEL_STD,
119
+ "train_class_json": cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON,
120
+ "test_class_json": cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON,
121
+ "sliding_window": cfg.TEST.SLIDING_WINDOW,
122
+ "clip_finetune": cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE,
123
+ "backbone_multiplier": cfg.SOLVER.BACKBONE_MULTIPLIER,
124
+ "clip_pretrained": cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED,
125
+ }
126
+
127
+ @property
128
+ def device(self):
129
+ return self.pixel_mean.device
130
+
131
+ def forward(self, batched_inputs):
132
+ """
133
+ Args:
134
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
135
+ Each item in the list contains the inputs for one image.
136
+ For now, each item in the list is a dict that contains:
137
+ * "image": Tensor, image in (C, H, W) format.
138
+ * "instances": per-region ground truth
139
+ * Other information that's included in the original dicts, such as:
140
+ "height", "width" (int): the output resolution of the model (may be different
141
+ from input resolution), used in inference.
142
+ Returns:
143
+ list[dict]:
144
+ each dict has the results for one image. The dict contains the following keys:
145
+
146
+ * "sem_seg":
147
+ A Tensor that represents the
148
+ per-pixel segmentation prediced by the head.
149
+ The prediction has shape KxHxW that represents the logits of
150
+ each class for each pixel.
151
+ """
152
+ images = [x["image"].to(self.device) for x in batched_inputs]
153
+ sam_images = images
154
+ if not self.training and self.sliding_window:
155
+ if not self.sequential:
156
+ with _ignore_torch_cuda_oom():
157
+ return self.inference_sliding_window(batched_inputs)
158
+ self.sequential = True
159
+ return self.inference_sliding_window(batched_inputs)
160
+
161
+ clip_images = [(x - self.clip_pixel_mean) / self.clip_pixel_std for x in images]
162
+ clip_images = ImageList.from_tensors(clip_images, self.size_divisibility)
163
+
164
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
165
+ images = ImageList.from_tensors(images, self.size_divisibility)
166
+
167
+ clip_images = F.interpolate(clip_images.tensor, size=self.clip_resolution, mode='bilinear', align_corners=False, )
168
+ clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)
169
+
170
+ images_resized = F.interpolate(images.tensor, size=(384, 384), mode='bilinear', align_corners=False,)
171
+ features = self.backbone(images_resized)
172
+
173
+ outputs = self.sem_seg_head(clip_features, features)
174
+
175
+ if self.training:
176
+ targets = torch.stack([x["sem_seg"].to(self.device) for x in batched_inputs], dim=0)
177
+ outputs = F.interpolate(outputs, size=(targets.shape[-2], targets.shape[-1]), mode="bilinear", align_corners=False)
178
+
179
+ num_classes = outputs.shape[1]
180
+ mask = targets != self.sem_seg_head.ignore_value
181
+
182
+ outputs = outputs.permute(0,2,3,1)
183
+ _targets = torch.zeros(outputs.shape, device=self.device)
184
+ _onehot = F.one_hot(targets[mask], num_classes=num_classes).float()
185
+ _targets[mask] = _onehot
186
+
187
+ loss = F.binary_cross_entropy_with_logits(outputs, _targets)
188
+ losses = {"loss_sem_seg" : loss}
189
+ return losses
190
+ else:
191
+ #outputs = outputs.sigmoid()
192
+ image_size = images.image_sizes[0]
193
+ if self.use_sam:
194
+ masks = self.mask.generate(np.uint8(sam_images[0].permute(1, 2, 0).cpu().numpy()))
195
+ outputs, sam_cls = self.discrete_semantic_inference(outputs, masks, image_size)
196
+ #outputs, sam_cls = self.continuous_semantic_inference(outputs, masks, image_size)
197
+ #outputs, sam_cls = self.continuous_semantic_inference2(outputs, masks, image_size, img=img, text=text)
198
+ height = batched_inputs[0].get("height", image_size[0])
199
+ width = batched_inputs[0].get("width", image_size[1])
200
+
201
+ output = sem_seg_postprocess(outputs[0], image_size, height, width)
202
+ processed_results = [{'sem_seg': output}]
203
+ return processed_results
204
+
205
+
206
+ @torch.no_grad()
207
+ def inference_sliding_window(self, batched_inputs, kernel=384, overlap=0.333, out_res=[640, 640]):
208
+
209
+ images = [x["image"].to(self.device, dtype=torch.float32) for x in batched_inputs]
210
+ stride = int(kernel * (1 - overlap))
211
+ unfold = nn.Unfold(kernel_size=kernel, stride=stride)
212
+ fold = nn.Fold(out_res, kernel_size=kernel, stride=stride)
213
+
214
+ image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode='bilinear', align_corners=False).squeeze()
215
+ sam_images = [image]
216
+ image = rearrange(unfold(image), "(C H W) L-> L C H W", C=3, H=kernel)
217
+ global_image = F.interpolate(images[0].unsqueeze(0), size=(kernel, kernel), mode='bilinear', align_corners=False)
218
+ image = torch.cat((image, global_image), dim=0)
219
+
220
+ images = (image - self.pixel_mean) / self.pixel_std
221
+ clip_images = (image - self.clip_pixel_mean) / self.clip_pixel_std
222
+ clip_images = F.interpolate(clip_images, size=self.clip_resolution, mode='bilinear', align_corners=False, )
223
+ clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)
224
+
225
+ if self.sequential:
226
+ outputs = []
227
+ for clip_feat, image in zip(clip_features, images):
228
+ feature = self.backbone(image.unsqueeze(0))
229
+ output = self.sem_seg_head(clip_feat.unsqueeze(0), feature)
230
+ outputs.append(output[0])
231
+ outputs = torch.stack(outputs, dim=0)
232
+ else:
233
+ features = self.backbone(images)
234
+ outputs = self.sem_seg_head(clip_features, features)
235
+
236
+ outputs = F.interpolate(outputs, size=kernel, mode="bilinear", align_corners=False)
237
+ outputs = outputs.sigmoid()
238
+
239
+ global_output = outputs[-1:]
240
+ global_output = F.interpolate(global_output, size=out_res, mode='bilinear', align_corners=False,)
241
+ outputs = outputs[:-1]
242
+ outputs = fold(outputs.flatten(1).T) / fold(unfold(torch.ones([1] + out_res, device=self.device)))
243
+ outputs = (outputs + global_output) / 2.
244
+
245
+ height = batched_inputs[0].get("height", out_res[0])
246
+ width = batched_inputs[0].get("width", out_res[1])
247
+ catseg_outputs = sem_seg_postprocess(outputs[0], out_res, height, width)
248
+ #catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
249
+
250
+ masks = self.mask.generate(np.uint8(sam_images[0].permute(1, 2, 0).cpu().numpy()))
251
+ if self.use_sam:
252
+ outputs, sam_cls = self.discrete_semantic_inference(outputs, masks, out_res)
253
+ #outputs, sam_cls = self.continuous_semantic_inference(outputs, masks, out_res)
254
+
255
+ output = sem_seg_postprocess(outputs[0], out_res, height, width)
256
+
257
+ ret = [{'sem_seg': output}]
258
+ if self.panoptic_on:
259
+ panoptic_r = self.panoptic_inference(catseg_outputs, masks, sam_cls, size=output.shape[-2:])
260
+ ret[0]['panoptic_seg'] = panoptic_r
261
+
262
+ return ret
263
+
264
+ def discrete_semantic_inference(self, outputs, masks, image_size):
265
+ catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True) #.argmax(dim=1)[0].cpu()
266
+ sam_outputs = torch.zeros_like(catseg_outputs).cpu()
267
+ catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
268
+ sam_classes = torch.zeros(len(masks))
269
+ for i in range(len(masks)):
270
+ m = masks[i]['segmentation']
271
+ s = masks[i]['stability_score']
272
+ idx = catseg_outputs[m].bincount().argmax()
273
+ sam_outputs[0, idx][m] = s
274
+ sam_classes[i] = idx
275
+
276
+ return sam_outputs, sam_classes
277
+
278
+ def continuous_semantic_inference(self, outputs, masks, image_size, scale=100/7.):
279
+ #import pdb; pdb.set_trace()
280
+ catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
281
+ sam_outputs = torch.zeros_like(catseg_outputs)
282
+ #catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
283
+ sam_classes = torch.zeros(len(masks))
284
+ #import pdb; pdb.set_trace()
285
+ mask_pred = torch.tensor(np.asarray([x['segmentation'] for x in masks]), dtype=torch.float32) # N H W
286
+ mask_score = torch.tensor(np.asarray([x['predicted_iou'] for x in masks]), dtype=torch.float32) # N
287
+
288
+ mask_cls = torch.einsum("nhw, chw -> nc", mask_pred, catseg_outputs)
289
+ mask_norm = mask_pred.sum(-1).sum(-1)
290
+ mask_cls = mask_cls / mask_norm[:, None]
291
+ mask_cls = mask_cls / mask_cls.norm(p=1, dim=1)[:, None]
292
+
293
+ mask_logits = mask_pred * mask_score[:, None, None]
294
+ output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
295
+
296
+ return output.unsqueeze(0), mask_cls
297
+
298
+ def continuous_semantic_inference2(self, outputs, masks, image_size, scale=100/7., img=None, text=None):
299
+ assert img is not None and text is not None
300
+ import pdb; pdb.set_trace()
301
+ #catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
302
+ img = F.interpolate(img, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
303
+ img = img.permute(1, 2, 0)
304
+
305
+ #sam_outputs = torch.zeros_like(catseg_outputs)
306
+ #catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
307
+ sam_classes = torch.zeros(len(masks))
308
+ #import pdb; pdb.set_trace()
309
+ mask_pred = torch.tensor(np.asarray([x['segmentation'] for x in masks]), dtype=torch.float32) # N H W
310
+ mask_score = torch.tensor(np.asarray([x['predicted_iou'] for x in masks]), dtype=torch.float32) # N
311
+
312
+ mask_pool = torch.einsum("nhw, hwd -> nd ", mask_pred, img)
313
+ mask_pool = mask_pool / mask_pool.norm(dim=1, keepdim=True)
314
+ mask_cls = torch.einsum("nd, cd -> nc", 100 * mask_pool, text.cpu())
315
+ mask_cls = mask_cls.softmax(dim=1)
316
+
317
+ #mask_cls = torch.einsum("nhw, chw -> nc", mask_pred, catseg_outputs)
318
+ mask_norm = mask_pred.sum(-1).sum(-1)
319
+ mask_cls = mask_cls / mask_norm[:, None]
320
+ mask_cls = mask_cls / mask_cls.norm(p=1, dim=1)[:, None]
321
+
322
+ mask_logits = mask_pred * mask_score[:, None, None]
323
+ output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
324
+
325
+ return output.unsqueeze(0), sam_classes
326
+
327
+ def panoptic_inference(self, outputs, masks, sam_classes, size=None):
328
+ #import pdb; pdb.set_trace()
329
+ scores = np.asarray([x['predicted_iou'] for x in masks])
330
+ mask_pred = np.asarray([x['segmentation'] for x in masks])
331
+
332
+ #keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
333
+ cur_scores = torch.tensor(scores)
334
+ cur_masks = torch.tensor(mask_pred)
335
+ cur_masks = F.interpolate(cur_masks.unsqueeze(0).float(), size=outputs.shape[-2:], mode="nearest")[0]
336
+ cur_classes = sam_classes.argmax(dim=-1)
337
+ #cur_mask_cls = mask_cls#[keep]
338
+ #cur_mask_cls = cur_mask_cls[:, :-1]
339
+
340
+ #import pdb; pdb.set_trace()
341
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
342
+
343
+ h, w = cur_masks.shape[-2:]
344
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
345
+ segments_info = []
346
+
347
+ current_segment_id = 0
348
+ if cur_masks.shape[0] == 0:
349
+ # We didn't detect any mask :(
350
+ return panoptic_seg, segments_info
351
+ else:
352
+ # take argmax
353
+ cur_mask_ids = cur_prob_masks.argmax(0)
354
+ stuff_memory_list = {}
355
+ for k in range(cur_classes.shape[0]):
356
+ pred_class = cur_classes[k].item()
357
+ #isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
358
+ isthing = pred_class in [3, 6] #[i for i in range(10)]#self.metadata.thing_dataset_id_to_contiguous_id.values()
359
+ mask = cur_mask_ids == k
360
+ mask_area = mask.sum().item()
361
+ original_area = (cur_masks[k] >= 0.5).sum().item()
362
+
363
+ if mask_area > 0 and original_area > 0:
364
+ if mask_area / original_area < self.overlap_threshold:
365
+ continue
366
+
367
+ # merge stuff regions
368
+ if not isthing:
369
+ if int(pred_class) in stuff_memory_list.keys():
370
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
371
+ continue
372
+ else:
373
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
374
+
375
+ current_segment_id += 1
376
+ panoptic_seg[mask] = current_segment_id
377
+
378
+ segments_info.append(
379
+ {
380
+ "id": current_segment_id,
381
+ "isthing": bool(isthing),
382
+ "category_id": int(pred_class),
383
+ }
384
+ )
385
+
386
+ return panoptic_seg, segments_info
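Since `CATSeg` registers itself in detectron2's `META_ARCH_REGISTRY`, it is normally instantiated through the standard model builder rather than constructed directly. A minimal sketch, assuming the YAML config sets `MODEL.META_ARCHITECTURE` to `CATSeg` and that the SAM checkpoint `sam_vit_h_4b8939.pth` loaded in `__init__` is present in the working directory (the config path is the demo default from app.py):

```python
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.projects.deeplab import add_deeplab_config

from cat_seg import add_cat_seg_config  # importing cat_seg also registers CATSeg

cfg = get_cfg()
add_deeplab_config(cfg)
add_cat_seg_config(cfg)
cfg.merge_from_file("configs/vitl_swinb_384.yaml")
cfg.freeze()

model = build_model(cfg)  # dispatches to CATSeg through META_ARCH_REGISTRY
model.eval()
```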
cat_seg/config.py ADDED
@@ -0,0 +1,93 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ from detectron2.config import CfgNode as CN
4
+
5
+
6
+ def add_cat_seg_config(cfg):
7
+ """
8
+ Add config for MASK_FORMER.
9
+ """
10
+ # data config
11
+ # select the dataset mapper
12
+ cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
13
+
14
+ cfg.DATASETS.VAL_ALL = ("coco_2017_val_all_stuff_sem_seg",)
15
+
16
+ # Color augmentation
17
+ cfg.INPUT.COLOR_AUG_SSD = False
18
+ # We retry random cropping until no single category in semantic segmentation GT occupies more
19
+ # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
20
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
21
+ # Pad image and segmentation GT in dataset mapper.
22
+ cfg.INPUT.SIZE_DIVISIBILITY = -1
23
+
24
+ # solver config
25
+ # weight decay on embedding
26
+ cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
27
+ # optimizer
28
+ cfg.SOLVER.OPTIMIZER = "ADAMW"
29
+ cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
30
+
31
+ # mask_former model config
32
+ cfg.MODEL.MASK_FORMER = CN()
33
+
34
+ # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
35
+ # you can use this config to override
36
+ cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
37
+
38
+ # swin transformer backbone
39
+ cfg.MODEL.SWIN = CN()
40
+ cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
41
+ cfg.MODEL.SWIN.PATCH_SIZE = 4
42
+ cfg.MODEL.SWIN.EMBED_DIM = 96
43
+ cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
44
+ cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
45
+ cfg.MODEL.SWIN.WINDOW_SIZE = 7
46
+ cfg.MODEL.SWIN.MLP_RATIO = 4.0
47
+ cfg.MODEL.SWIN.QKV_BIAS = True
48
+ cfg.MODEL.SWIN.QK_SCALE = None
49
+ cfg.MODEL.SWIN.DROP_RATE = 0.0
50
+ cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
51
+ cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
52
+ cfg.MODEL.SWIN.APE = False
53
+ cfg.MODEL.SWIN.PATCH_NORM = True
54
+ cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
55
+
56
+ # zero shot config
57
+ cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON = "datasets/ADE20K_2021_17_01/ADE20K_847.json"
58
+ cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON = "datasets/ADE20K_2021_17_01/ADE20K_847.json"
59
+ cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_INDEXES = "datasets/coco/coco_stuff/split/seen_indexes.json"
60
+ cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_INDEXES = "datasets/coco/coco_stuff/split/unseen_indexes.json"
61
+
62
+ cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED = "ViT-B/16"
63
+
64
+ cfg.MODEL.PROMPT_ENSEMBLE = False
65
+ cfg.MODEL.PROMPT_ENSEMBLE_TYPE = "single"
66
+
67
+ cfg.MODEL.CLIP_PIXEL_MEAN = [122.7709383, 116.7460125, 104.09373615]
68
+ cfg.MODEL.CLIP_PIXEL_STD = [68.5005327, 66.6321579, 70.3231630]
69
+ # three styles for clip classification, crop, mask, cropmask
70
+
71
+ cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_DIM = 512
72
+ cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_PROJ_DIM = 128
73
+ cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_DIM = 512
74
+ cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_PROJ_DIM = 128
75
+
76
+ cfg.MODEL.SEM_SEG_HEAD.DECODER_DIMS = [64, 32]
77
+ cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_DIMS = [256, 128]
78
+ cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_PROJ_DIMS = [32, 16]
79
+
80
+ cfg.MODEL.SEM_SEG_HEAD.NUM_LAYERS = 4
81
+ cfg.MODEL.SEM_SEG_HEAD.NUM_HEADS = 4
82
+ cfg.MODEL.SEM_SEG_HEAD.HIDDEN_DIMS = 128
83
+ cfg.MODEL.SEM_SEG_HEAD.POOLING_SIZES = [6, 6]
84
+ cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION = [24, 24]
85
+ cfg.MODEL.SEM_SEG_HEAD.WINDOW_SIZES = 12
86
+ cfg.MODEL.SEM_SEG_HEAD.ATTENTION_TYPE = "linear"
87
+
88
+ cfg.MODEL.SEM_SEG_HEAD.PROMPT_DEPTH = 0
89
+ cfg.MODEL.SEM_SEG_HEAD.PROMPT_LENGTH = 0
90
+ cfg.SOLVER.CLIP_MULTIPLIER = 0.01
91
+
92
+ cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE = "attention"
93
+ cfg.TEST.SLIDING_WINDOW = False
cat_seg/data/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from . import datasets
cat_seg/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (184 Bytes). View file
 
cat_seg/data/dataset_mappers/__init__.py ADDED
@@ -0,0 +1 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
cat_seg/data/dataset_mappers/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (167 Bytes). View file
 
cat_seg/data/dataset_mappers/__pycache__/detr_panoptic_dataset_mapper.cpython-38.pyc ADDED
Binary file (4.88 kB). View file
 
cat_seg/data/dataset_mappers/__pycache__/mask_former_panoptic_dataset_mapper.cpython-38.pyc ADDED
Binary file (4.41 kB). View file
 
cat_seg/data/dataset_mappers/__pycache__/mask_former_semantic_dataset_mapper.cpython-38.pyc ADDED
Binary file (5.05 kB). View file
 
cat_seg/data/dataset_mappers/detr_panoptic_dataset_mapper.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py
3
+ import copy
4
+ import logging
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.data import detection_utils as utils
11
+ from detectron2.data import transforms as T
12
+ from detectron2.data.transforms import TransformGen
13
+ from detectron2.structures import BitMasks, Instances
14
+
15
+ __all__ = ["DETRPanopticDatasetMapper"]
16
+
17
+
18
+ def build_transform_gen(cfg, is_train):
19
+ """
20
+ Create a list of :class:`TransformGen` from config.
21
+ Returns:
22
+ list[TransformGen]
23
+ """
24
+ if is_train:
25
+ min_size = cfg.INPUT.MIN_SIZE_TRAIN
26
+ max_size = cfg.INPUT.MAX_SIZE_TRAIN
27
+ sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
28
+ else:
29
+ min_size = cfg.INPUT.MIN_SIZE_TEST
30
+ max_size = cfg.INPUT.MAX_SIZE_TEST
31
+ sample_style = "choice"
32
+ if sample_style == "range":
33
+ assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(
34
+ len(min_size)
35
+ )
36
+
37
+ logger = logging.getLogger(__name__)
38
+ tfm_gens = []
39
+ if is_train:
40
+ tfm_gens.append(T.RandomFlip())
41
+ tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
42
+ if is_train:
43
+ logger.info("TransformGens used in training: " + str(tfm_gens))
44
+ return tfm_gens
45
+
46
+
47
+ # This is specifically designed for the COCO dataset.
48
+ class DETRPanopticDatasetMapper:
49
+ """
50
+ A callable which takes a dataset dict in Detectron2 Dataset format,
51
+ and maps it into a format used by MaskFormer.
52
+
53
+ This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
54
+
55
+ The callable currently does the following:
56
+
57
+ 1. Read the image from "file_name"
58
+ 2. Applies geometric transforms to the image and annotation
59
+ 3. Find and applies suitable cropping to the image and annotation
60
+ 4. Prepare image and annotation to Tensors
61
+ """
62
+
63
+ @configurable
64
+ def __init__(
65
+ self,
66
+ is_train=True,
67
+ *,
68
+ crop_gen,
69
+ tfm_gens,
70
+ image_format,
71
+ ):
72
+ """
73
+ NOTE: this interface is experimental.
74
+ Args:
75
+ is_train: for training or inference
76
+ augmentations: a list of augmentations or deterministic transforms to apply
77
+ crop_gen: crop augmentation
78
+ tfm_gens: data augmentation
79
+ image_format: an image format supported by :func:`detection_utils.read_image`.
80
+ """
81
+ self.crop_gen = crop_gen
82
+ self.tfm_gens = tfm_gens
83
+ logging.getLogger(__name__).info(
84
+ "[DETRPanopticDatasetMapper] Full TransformGens used in training: {}, crop: {}".format(
85
+ str(self.tfm_gens), str(self.crop_gen)
86
+ )
87
+ )
88
+
89
+ self.img_format = image_format
90
+ self.is_train = is_train
91
+
92
+ @classmethod
93
+ def from_config(cls, cfg, is_train=True):
94
+ # Build augmentation
95
+ if cfg.INPUT.CROP.ENABLED and is_train:
96
+ crop_gen = [
97
+ T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
98
+ T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
99
+ ]
100
+ else:
101
+ crop_gen = None
102
+
103
+ tfm_gens = build_transform_gen(cfg, is_train)
104
+
105
+ ret = {
106
+ "is_train": is_train,
107
+ "crop_gen": crop_gen,
108
+ "tfm_gens": tfm_gens,
109
+ "image_format": cfg.INPUT.FORMAT,
110
+ }
111
+ return ret
112
+
113
+ def __call__(self, dataset_dict):
114
+ """
115
+ Args:
116
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
117
+
118
+ Returns:
119
+ dict: a format that builtin models in detectron2 accept
120
+ """
121
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
122
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
123
+ utils.check_image_size(dataset_dict, image)
124
+
125
+ if self.crop_gen is None:
126
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
127
+ else:
128
+ if np.random.rand() > 0.5:
129
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
130
+ else:
131
+ image, transforms = T.apply_transform_gens(
132
+ self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
133
+ )
134
+
135
+ image_shape = image.shape[:2] # h, w
136
+
137
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
138
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
139
+ # Therefore it's important to use torch.Tensor.
140
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
141
+
142
+ if not self.is_train:
143
+ # USER: Modify this if you want to keep them for some reason.
144
+ dataset_dict.pop("annotations", None)
145
+ return dataset_dict
146
+
147
+ if "pan_seg_file_name" in dataset_dict:
148
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
149
+ segments_info = dataset_dict["segments_info"]
150
+
151
+ # apply the same transformation to panoptic segmentation
152
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
153
+
154
+ from panopticapi.utils import rgb2id
155
+
156
+ pan_seg_gt = rgb2id(pan_seg_gt)
157
+
158
+ instances = Instances(image_shape)
159
+ classes = []
160
+ masks = []
161
+ for segment_info in segments_info:
162
+ class_id = segment_info["category_id"]
163
+ if not segment_info["iscrowd"]:
164
+ classes.append(class_id)
165
+ masks.append(pan_seg_gt == segment_info["id"])
166
+
167
+ classes = np.array(classes)
168
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
169
+ if len(masks) == 0:
170
+ # Some image does not have annotation (all ignored)
171
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
172
+ else:
173
+ masks = BitMasks(
174
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
175
+ )
176
+ instances.gt_masks = masks.tensor
177
+
178
+ dataset_dict["instances"] = instances
179
+
180
+ return dataset_dict
cat_seg/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.data import detection_utils as utils
11
+ from detectron2.data import transforms as T
12
+ from detectron2.structures import BitMasks, Instances
13
+
14
+ from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper
15
+
16
+ __all__ = ["MaskFormerPanopticDatasetMapper"]
17
+
18
+
19
+ class MaskFormerPanopticDatasetMapper(MaskFormerSemanticDatasetMapper):
20
+ """
21
+ A callable which takes a dataset dict in Detectron2 Dataset format,
22
+ and maps it into a format used by MaskFormer for panoptic segmentation.
23
+
24
+ The callable currently does the following:
25
+
26
+ 1. Read the image from "file_name"
27
+ 2. Applies geometric transforms to the image and annotation
28
+ 3. Find and applies suitable cropping to the image and annotation
29
+ 4. Prepare image and annotation to Tensors
30
+ """
31
+
32
+ @configurable
33
+ def __init__(
34
+ self,
35
+ is_train=True,
36
+ *,
37
+ augmentations,
38
+ image_format,
39
+ ignore_label,
40
+ size_divisibility,
41
+ ):
42
+ """
43
+ NOTE: this interface is experimental.
44
+ Args:
45
+ is_train: for training or inference
46
+ augmentations: a list of augmentations or deterministic transforms to apply
47
+ image_format: an image format supported by :func:`detection_utils.read_image`.
48
+ ignore_label: the label that is ignored during evaluation
49
+ size_divisibility: pad image size to be divisible by this value
50
+ """
51
+ super().__init__(
52
+ is_train,
53
+ augmentations=augmentations,
54
+ image_format=image_format,
55
+ ignore_label=ignore_label,
56
+ size_divisibility=size_divisibility,
57
+ )
58
+
59
+ def __call__(self, dataset_dict):
60
+ """
61
+ Args:
62
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
63
+
64
+ Returns:
65
+ dict: a format that builtin models in detectron2 accept
66
+ """
67
+ assert self.is_train, "MaskFormerPanopticDatasetMapper should only be used for training!"
68
+
69
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
70
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
71
+ utils.check_image_size(dataset_dict, image)
72
+
73
+ # semantic segmentation
74
+ if "sem_seg_file_name" in dataset_dict:
75
+ # PyTorch transformation not implemented for uint16, so converting it to double first
76
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
77
+ else:
78
+ sem_seg_gt = None
79
+
80
+ # panoptic segmentation
81
+ if "pan_seg_file_name" in dataset_dict:
82
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
83
+ segments_info = dataset_dict["segments_info"]
84
+ else:
85
+ pan_seg_gt = None
86
+ segments_info = None
87
+
88
+ if pan_seg_gt is None:
89
+ raise ValueError(
90
+ "Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format(
91
+ dataset_dict["file_name"]
92
+ )
93
+ )
94
+
95
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
96
+ aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
97
+ image = aug_input.image
98
+ if sem_seg_gt is not None:
99
+ sem_seg_gt = aug_input.sem_seg
100
+
101
+ # apply the same transformation to panoptic segmentation
102
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
103
+
104
+ from panopticapi.utils import rgb2id
105
+
106
+ pan_seg_gt = rgb2id(pan_seg_gt)
107
+
108
+ # Pad image and segmentation label here!
109
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
110
+ if sem_seg_gt is not None:
111
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
112
+ pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long"))
113
+
114
+ if self.size_divisibility > 0:
115
+ image_size = (image.shape[-2], image.shape[-1])
116
+ padding_size = [
117
+ 0,
118
+ self.size_divisibility - image_size[1],
119
+ 0,
120
+ self.size_divisibility - image_size[0],
121
+ ]
122
+ image = F.pad(image, padding_size, value=128).contiguous()
123
+ if sem_seg_gt is not None:
124
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
125
+ pan_seg_gt = F.pad(
126
+ pan_seg_gt, padding_size, value=0
127
+ ).contiguous() # 0 is the VOID panoptic label
128
+
129
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
130
+
131
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
132
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
133
+ # Therefore it's important to use torch.Tensor.
134
+ dataset_dict["image"] = image
135
+ if sem_seg_gt is not None:
136
+ dataset_dict["sem_seg"] = sem_seg_gt.long()
137
+
138
+ if "annotations" in dataset_dict:
139
+ raise ValueError("Panoptic segmentation dataset should not have 'annotations'.")
140
+
141
+ # Prepare per-category binary masks
142
+ pan_seg_gt = pan_seg_gt.numpy()
143
+ instances = Instances(image_shape)
144
+ classes = []
145
+ masks = []
146
+ for segment_info in segments_info:
147
+ class_id = segment_info["category_id"]
148
+ if not segment_info["iscrowd"]:
149
+ classes.append(class_id)
150
+ masks.append(pan_seg_gt == segment_info["id"])
151
+
152
+ classes = np.array(classes)
153
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
154
+ if len(masks) == 0:
155
+ # Some image does not have annotation (all ignored)
156
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
157
+ else:
158
+ masks = BitMasks(
159
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
160
+ )
161
+ instances.gt_masks = masks.tensor
162
+
163
+ dataset_dict["instances"] = instances
164
+
165
+ return dataset_dict
cat_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py ADDED
@@ -0,0 +1,186 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.data import MetadataCatalog
11
+ from detectron2.data import detection_utils as utils
12
+ from detectron2.data import transforms as T
13
+ from detectron2.projects.point_rend import ColorAugSSDTransform
14
+ from detectron2.structures import BitMasks, Instances
15
+
16
+ __all__ = ["MaskFormerSemanticDatasetMapper"]
17
+
18
+
19
+ class MaskFormerSemanticDatasetMapper:
20
+ """
21
+ A callable which takes a dataset dict in Detectron2 Dataset format,
22
+ and map it into a format used by MaskFormer for semantic segmentation.
23
+
24
+ The callable currently does the following:
25
+
26
+ 1. Read the image from "file_name"
27
+ 2. Applies geometric transforms to the image and annotation
28
+ 3. Find and applies suitable cropping to the image and annotation
29
+ 4. Prepare image and annotation to Tensors
30
+ """
31
+
32
+ @configurable
33
+ def __init__(
34
+ self,
35
+ is_train=True,
36
+ *,
37
+ augmentations,
38
+ image_format,
39
+ ignore_label,
40
+ size_divisibility,
41
+ ):
42
+ """
43
+ NOTE: this interface is experimental.
44
+ Args:
45
+ is_train: for training or inference
46
+ augmentations: a list of augmentations or deterministic transforms to apply
47
+ image_format: an image format supported by :func:`detection_utils.read_image`.
48
+ ignore_label: the label that is ignored during evaluation
49
+ size_divisibility: pad image size to be divisible by this value
50
+ """
51
+ self.is_train = is_train
52
+ self.tfm_gens = augmentations
53
+ self.img_format = image_format
54
+ self.ignore_label = ignore_label
55
+ self.size_divisibility = size_divisibility
56
+
57
+ logger = logging.getLogger(__name__)
58
+ mode = "training" if is_train else "inference"
59
+ logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}")
60
+
61
+ @classmethod
62
+ def from_config(cls, cfg, is_train=True):
63
+ # Build augmentation
64
+ augs = [
65
+ T.ResizeShortestEdge(
66
+ cfg.INPUT.MIN_SIZE_TRAIN,
67
+ cfg.INPUT.MAX_SIZE_TRAIN,
68
+ cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
69
+ )
70
+ ]
71
+ if cfg.INPUT.CROP.ENABLED:
72
+ augs.append(
73
+ T.RandomCrop_CategoryAreaConstraint(
74
+ cfg.INPUT.CROP.TYPE,
75
+ cfg.INPUT.CROP.SIZE,
76
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
77
+ cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
78
+ )
79
+ )
80
+ if cfg.INPUT.COLOR_AUG_SSD:
81
+ augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
82
+ augs.append(T.RandomFlip())
83
+
84
+ # Assume always applies to the training set.
85
+ dataset_names = cfg.DATASETS.TRAIN
86
+ meta = MetadataCatalog.get(dataset_names[0])
87
+ ignore_label = meta.ignore_label
88
+
89
+ ret = {
90
+ "is_train": is_train,
91
+ "augmentations": augs,
92
+ "image_format": cfg.INPUT.FORMAT,
93
+ "ignore_label": ignore_label,
94
+ "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY,
95
+ }
96
+ return ret
97
+
98
+ def __call__(self, dataset_dict):
99
+ """
100
+ Args:
101
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
102
+
103
+ Returns:
104
+ dict: a format that builtin models in detectron2 accept
105
+ """
106
+ assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!"
107
+
108
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
109
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
110
+ utils.check_image_size(dataset_dict, image)
111
+
112
+ if "sem_seg_file_name" in dataset_dict:
113
+ # PyTorch transformation not implemented for uint16, so converting it to double first
114
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
115
+ else:
116
+ sem_seg_gt = None
117
+
118
+ if sem_seg_gt is None:
119
+ raise ValueError(
120
+ "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
121
+ dataset_dict["file_name"]
122
+ )
123
+ )
124
+
125
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
126
+ aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
127
+ image = aug_input.image
128
+ sem_seg_gt = aug_input.sem_seg
129
+
130
+ # Pad image and segmentation label here!
131
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
132
+ if sem_seg_gt is not None:
133
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
134
+ # import ipdb; ipdb.set_trace()
135
+ if self.size_divisibility > 0:
136
+ image_size = (image.shape[-2], image.shape[-1])
137
+ # The ori_size is not the real original size, but size before padding
138
+ dataset_dict['ori_size'] = image_size
139
+ padding_size = [
140
+ 0,
141
+ self.size_divisibility - image_size[1], # w: (left, right)
142
+ 0,
143
+ self.size_divisibility - image_size[0], # h: 0,(top, bottom)
144
+ ]
145
+ image = F.pad(image, padding_size, value=128).contiguous()
146
+ if sem_seg_gt is not None:
147
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
148
+
149
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
150
+
151
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
152
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
153
+ # Therefore it's important to use torch.Tensor.
154
+ dataset_dict["image"] = image
155
+ # print('#########################################################################################')
156
+ if sem_seg_gt is not None:
157
+ dataset_dict["sem_seg"] = sem_seg_gt.long()
158
+
159
+ if "annotations" in dataset_dict:
160
+ raise ValueError("Semantic segmentation dataset should not have 'annotations'.")
161
+
162
+ # Prepare per-category binary masks
163
+ if sem_seg_gt is not None:
164
+ sem_seg_gt = sem_seg_gt.numpy()
165
+ instances = Instances(image_shape)
166
+ classes = np.unique(sem_seg_gt)
167
+ # remove ignored region
168
+ classes = classes[classes != self.ignore_label]
169
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
170
+
171
+ masks = []
172
+ for class_id in classes:
173
+ masks.append(sem_seg_gt == class_id)
174
+
175
+ if len(masks) == 0:
176
+ # Some image does not have annotation (all ignored)
177
+ instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
178
+ else:
179
+ masks = BitMasks(
180
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
181
+ )
182
+ instances.gt_masks = masks.tensor
183
+
184
+ dataset_dict["instances"] = instances
185
+
186
+ return dataset_dict
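These dataset mappers are normally plugged into detectron2's data loader rather than called by hand. A minimal sketch under stated assumptions: `DATASETS.TRAIN` must point at a registered semantic-segmentation dataset (the name below is the one registered in register_ade20k_150.py and is used purely for illustration), and the underlying images must exist on disk:

```python
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader
from detectron2.projects.deeplab import add_deeplab_config

from cat_seg import MaskFormerSemanticDatasetMapper, add_cat_seg_config

cfg = get_cfg()
add_deeplab_config(cfg)
add_cat_seg_config(cfg)
cfg.DATASETS.TRAIN = ("ade20k_150_test_sem_seg",)  # illustration only; any registered sem-seg dataset with ignore_label works

mapper = MaskFormerSemanticDatasetMapper(cfg, is_train=True)
train_loader = build_detection_train_loader(cfg, mapper=mapper)

batch = next(iter(train_loader))  # a list of dicts with "image", "sem_seg", "instances"
print(batch[0]["image"].shape, batch[0]["sem_seg"].shape)
```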
cat_seg/data/datasets/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from . import (
+     register_coco_stuff,
+     register_ade20k_150,
+     register_ade20k_847,
+     register_pascal_20,
+     register_pascal_59,
+ )
cat_seg/data/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (322 Bytes). View file
 
cat_seg/data/datasets/__pycache__/register_ade20k_150.cpython-38.pyc ADDED
Binary file (2.88 kB). View file
 
cat_seg/data/datasets/__pycache__/register_ade20k_847.cpython-38.pyc ADDED
Binary file (51.8 kB). View file
 
cat_seg/data/datasets/__pycache__/register_ade_panoptic.cpython-38.pyc ADDED
Binary file (11.6 kB). View file
 
cat_seg/data/datasets/__pycache__/register_coco_panoptic.cpython-38.pyc ADDED
Binary file (4.75 kB). View file
 
cat_seg/data/datasets/__pycache__/register_coco_stuff.cpython-38.pyc ADDED
Binary file (7.85 kB). View file
 
cat_seg/data/datasets/__pycache__/register_pascal_20.cpython-38.pyc ADDED
Binary file (2.47 kB). View file
 
cat_seg/data/datasets/__pycache__/register_pascal_59.cpython-38.pyc ADDED
Binary file (9.57 kB). View file
 
cat_seg/data/datasets/__pycache__/register_pascal_context.cpython-38.pyc ADDED
Binary file (9.56 kB). View file
 
cat_seg/data/datasets/register_ade20k_150.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+
3
+ from detectron2.data import DatasetCatalog, MetadataCatalog
4
+ from detectron2.data.datasets import load_sem_seg
5
+ import copy
6
+
7
+ def _get_ade20k_150_meta():
8
+ ade20k_150_classes = ["wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed ", "windowpane", "grass", "cabinet", "sidewalk", "person", "earth", "door", "table", "mountain", "plant", "curtain", "chair", "car", "water", "painting", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp", "bathtub", "railing", "cushion", "base", "box", "column", "signboard", "chest of drawers", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator", "grandstand", "path", "stairs", "runway", "case", "pool table", "pillow", "screen door", "stairway", "river", "bridge", "bookcase", "blind", "coffee table", "toilet", "flower", "book", "hill", "bench", "countertop", "stove", "palm", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel", "bus", "towel", "light", "truck", "tower", "chandelier", "awning", "streetlight", "booth", "television receiver", "airplane", "dirt track", "apparel", "pole", "land", "bannister", "escalator", "ottoman", "bottle", "buffet", "poster", "stage", "van", "ship", "fountain", "conveyer belt", "canopy", "washer", "plaything", "swimming pool", "stool", "barrel", "basket", "waterfall", "tent", "bag", "minibike", "cradle", "oven", "ball", "food", "step", "tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket", "sculpture", "hood", "sconce", "vase", "traffic light", "tray", "ashcan", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass", "clock", "flag"]
+
+ ret = {
+ "stuff_classes" : ade20k_150_classes,
+ }
+ return ret
+
+ def register_ade20k_150(root):
+ root = os.path.join(root, "ADEChallengeData2016")
+ meta = _get_ade20k_150_meta()
+ for name, image_dirname, sem_seg_dirname in [
+ ("test", "images/validation", "annotations_detectron2/validation"),
+ ]:
+ image_dir = os.path.join(root, image_dirname)
+ gt_dir = os.path.join(root, sem_seg_dirname)
+ name = f"ade20k_150_{name}_sem_seg"
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
+
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
+ register_ade20k_150(_root)
cat_seg/data/datasets/register_ade20k_847.py ADDED
The diff for this file is too large to render. See raw diff
 
cat_seg/data/datasets/register_coco_stuff.py ADDED
@@ -0,0 +1,216 @@
+ import os
+
+ from detectron2.data import DatasetCatalog, MetadataCatalog
+ from detectron2.data.datasets import load_sem_seg
+
6
+ COCO_CATEGORIES = [
7
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
8
+ {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
9
+ {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
10
+ {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
11
+ {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
12
+ {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
13
+ {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
14
+ {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
15
+ {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
16
+ {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
17
+ {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
18
+ {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
19
+ {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
20
+ {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
21
+ {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
22
+ {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
23
+ {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
24
+ {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
25
+ {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
26
+ {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
27
+ {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
28
+ {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
29
+ {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
30
+ {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
31
+ {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
32
+ {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
33
+ {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
34
+ {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
35
+ {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
36
+ {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
37
+ {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
38
+ {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
39
+ {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
40
+ {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
41
+ {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
42
+ {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
43
+ {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
44
+ {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
45
+ {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
46
+ {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
47
+ {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
48
+ {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
49
+ {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
50
+ {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
51
+ {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
52
+ {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
53
+ {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
54
+ {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
55
+ {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
56
+ {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
57
+ {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
58
+ {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
59
+ {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
60
+ {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
61
+ {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
62
+ {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
63
+ {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
64
+ {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
65
+ {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
66
+ {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
67
+ {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
68
+ {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
69
+ {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
70
+ {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
71
+ {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
72
+ {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
73
+ {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
74
+ {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
75
+ {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
76
+ {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
77
+ {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
78
+ {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
79
+ {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
80
+ {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
81
+ {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
82
+ {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
83
+ {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
84
+ {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
85
+ {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
86
+ {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
87
+ {"id": 92, "name": "banner", "supercategory": "textile"},
88
+ {"id": 93, "name": "blanket", "supercategory": "textile"},
89
+ {"id": 94, "name": "branch", "supercategory": "plant"},
90
+ {"id": 95, "name": "bridge", "supercategory": "building"},
91
+ {"id": 96, "name": "building-other", "supercategory": "building"},
92
+ {"id": 97, "name": "bush", "supercategory": "plant"},
93
+ {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
94
+ {"id": 99, "name": "cage", "supercategory": "structural"},
95
+ {"id": 100, "name": "cardboard", "supercategory": "raw-material"},
96
+ {"id": 101, "name": "carpet", "supercategory": "floor"},
97
+ {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
98
+ {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
99
+ {"id": 104, "name": "cloth", "supercategory": "textile"},
100
+ {"id": 105, "name": "clothes", "supercategory": "textile"},
101
+ {"id": 106, "name": "clouds", "supercategory": "sky"},
102
+ {"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
103
+ {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
104
+ {"id": 109, "name": "curtain", "supercategory": "textile"},
105
+ {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
106
+ {"id": 111, "name": "dirt", "supercategory": "ground"},
107
+ {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
108
+ {"id": 113, "name": "fence", "supercategory": "structural"},
109
+ {"id": 114, "name": "floor-marble", "supercategory": "floor"},
110
+ {"id": 115, "name": "floor-other", "supercategory": "floor"},
111
+ {"id": 116, "name": "floor-stone", "supercategory": "floor"},
112
+ {"id": 117, "name": "floor-tile", "supercategory": "floor"},
113
+ {"id": 118, "name": "floor-wood", "supercategory": "floor"},
114
+ {"id": 119, "name": "flower", "supercategory": "plant"},
115
+ {"id": 120, "name": "fog", "supercategory": "water"},
116
+ {"id": 121, "name": "food-other", "supercategory": "food-stuff"},
117
+ {"id": 122, "name": "fruit", "supercategory": "food-stuff"},
118
+ {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
119
+ {"id": 124, "name": "grass", "supercategory": "plant"},
120
+ {"id": 125, "name": "gravel", "supercategory": "ground"},
121
+ {"id": 126, "name": "ground-other", "supercategory": "ground"},
122
+ {"id": 127, "name": "hill", "supercategory": "solid"},
123
+ {"id": 128, "name": "house", "supercategory": "building"},
124
+ {"id": 129, "name": "leaves", "supercategory": "plant"},
125
+ {"id": 130, "name": "light", "supercategory": "furniture-stuff"},
126
+ {"id": 131, "name": "mat", "supercategory": "textile"},
127
+ {"id": 132, "name": "metal", "supercategory": "raw-material"},
128
+ {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
129
+ {"id": 134, "name": "moss", "supercategory": "plant"},
130
+ {"id": 135, "name": "mountain", "supercategory": "solid"},
131
+ {"id": 136, "name": "mud", "supercategory": "ground"},
132
+ {"id": 137, "name": "napkin", "supercategory": "textile"},
133
+ {"id": 138, "name": "net", "supercategory": "structural"},
134
+ {"id": 139, "name": "paper", "supercategory": "raw-material"},
135
+ {"id": 140, "name": "pavement", "supercategory": "ground"},
136
+ {"id": 141, "name": "pillow", "supercategory": "textile"},
137
+ {"id": 142, "name": "plant-other", "supercategory": "plant"},
138
+ {"id": 143, "name": "plastic", "supercategory": "raw-material"},
139
+ {"id": 144, "name": "platform", "supercategory": "ground"},
140
+ {"id": 145, "name": "playingfield", "supercategory": "ground"},
141
+ {"id": 146, "name": "railing", "supercategory": "structural"},
142
+ {"id": 147, "name": "railroad", "supercategory": "ground"},
143
+ {"id": 148, "name": "river", "supercategory": "water"},
144
+ {"id": 149, "name": "road", "supercategory": "ground"},
145
+ {"id": 150, "name": "rock", "supercategory": "solid"},
146
+ {"id": 151, "name": "roof", "supercategory": "building"},
147
+ {"id": 152, "name": "rug", "supercategory": "textile"},
148
+ {"id": 153, "name": "salad", "supercategory": "food-stuff"},
149
+ {"id": 154, "name": "sand", "supercategory": "ground"},
150
+ {"id": 155, "name": "sea", "supercategory": "water"},
151
+ {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
152
+ {"id": 157, "name": "sky-other", "supercategory": "sky"},
153
+ {"id": 158, "name": "skyscraper", "supercategory": "building"},
154
+ {"id": 159, "name": "snow", "supercategory": "ground"},
155
+ {"id": 160, "name": "solid-other", "supercategory": "solid"},
156
+ {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
157
+ {"id": 162, "name": "stone", "supercategory": "solid"},
158
+ {"id": 163, "name": "straw", "supercategory": "plant"},
159
+ {"id": 164, "name": "structural-other", "supercategory": "structural"},
160
+ {"id": 165, "name": "table", "supercategory": "furniture-stuff"},
161
+ {"id": 166, "name": "tent", "supercategory": "building"},
162
+ {"id": 167, "name": "textile-other", "supercategory": "textile"},
163
+ {"id": 168, "name": "towel", "supercategory": "textile"},
164
+ {"id": 169, "name": "tree", "supercategory": "plant"},
165
+ {"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
166
+ {"id": 171, "name": "wall-brick", "supercategory": "wall"},
167
+ {"id": 172, "name": "wall-concrete", "supercategory": "wall"},
168
+ {"id": 173, "name": "wall-other", "supercategory": "wall"},
169
+ {"id": 174, "name": "wall-panel", "supercategory": "wall"},
170
+ {"id": 175, "name": "wall-stone", "supercategory": "wall"},
171
+ {"id": 176, "name": "wall-tile", "supercategory": "wall"},
172
+ {"id": 177, "name": "wall-wood", "supercategory": "wall"},
173
+ {"id": 178, "name": "water-other", "supercategory": "water"},
174
+ {"id": 179, "name": "waterdrops", "supercategory": "water"},
175
+ {"id": 180, "name": "window-blind", "supercategory": "window"},
176
+ {"id": 181, "name": "window-other", "supercategory": "window"},
177
+ {"id": 182, "name": "wood", "supercategory": "solid"},
178
+ ]
+
+
+ def _get_coco_stuff_meta():
+ stuff_ids = [k["id"] for k in COCO_CATEGORIES]
+ assert len(stuff_ids) == 171, len(stuff_ids)
+
+ stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
+ stuff_classes = [k["name"] for k in COCO_CATEGORIES]
+
+ ret = {
+ "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
+ "stuff_classes": stuff_classes,
+ }
+ return ret
+
+ def register_all_coco_stuff_10k(root):
+ root = os.path.join(root, "coco-stuff")
+ meta = _get_coco_stuff_meta()
+ for name, image_dirname, sem_seg_dirname in [
+ ("train", "images/train2017", "annotations_detectron2/train2017"),
+ ("test", "images/val2017", "annotations_detectron2/val2017"),
+ ]:
+ image_dir = os.path.join(root, image_dirname)
+ gt_dir = os.path.join(root, sem_seg_dirname)
+ name = f"coco_2017_{name}_stuff_all_sem_seg"
+ DatasetCatalog.register(
+ name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg")
+ )
+ MetadataCatalog.get(name).set(
+ image_root=image_dir,
+ sem_seg_root=gt_dir,
+ evaluator_type="sem_seg",
+ ignore_label=255,
+ **meta,
+ )
+
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
+ register_all_coco_stuff_10k(_root)
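One detail worth spelling out in `_get_coco_stuff_meta()`: COCO-Stuff category ids are non-contiguous (1..182 with gaps), so the 171 entries above get remapped to contiguous training ids 0..170 in list order. The snippet below restates that mapping for illustration; it is not additional repo code.

```python
from cat_seg.data.datasets.register_coco_stuff import COCO_CATEGORIES

stuff_ids = [k["id"] for k in COCO_CATEGORIES]  # 171 dataset ids: 1, 2, ..., 182 with gaps
id_map = {dataset_id: i for i, dataset_id in enumerate(stuff_ids)}

assert len(id_map) == 171
assert id_map[1] == 0      # "person" keeps the first slot
assert id_map[182] == 170  # "wood", the last category, maps to training id 170
```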
cat_seg/data/datasets/register_pascal_20.py ADDED
@@ -0,0 +1,53 @@
+ import os
+
+ from detectron2.data import DatasetCatalog, MetadataCatalog
+ from detectron2.data.datasets import load_sem_seg
+ import copy
+
+ def _get_pascal_voc_meta():
+ voc_classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
+ voc_colors = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+ ret = {
+ "stuff_classes" : voc_classes,
+ "stuff_colors" : voc_colors,
+ }
+ return ret
+
+ def register_all_pascal_voc(root):
+ root = os.path.join(root, "VOCdevkit/VOC2012")
+ meta = _get_pascal_voc_meta()
+ for name, image_dirname, sem_seg_dirname in [
+ ("test", "JPEGImages", "annotations_detectron2"),
+ ("test_background", "JPEGImages", "annotations_detectron2_bg"),
+ ]:
+ image_dir = os.path.join(root, image_dirname)
+ gt_dir = os.path.join(root, sem_seg_dirname, 'val')
+ name = f"voc_2012_{name}_sem_seg"
+
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
+ if "background" in name:
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg_background", ignore_label=255,
+ stuff_classes=meta["stuff_classes"] + ["background"], stuff_colors=meta["stuff_colors"])
+ else:
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
+
+ def register_all_pascal_voc_background(root):
+ root = os.path.join(root, "VOCdevkit/VOC2012")
+ meta = _get_pascal_voc_meta()
+ meta["stuff_classes"] = meta["stuff_classes"] + ["background"]
+ for name, image_dirname, sem_seg_dirname in [
+ ("test_background", "image", "label_openseg_background20"),
+ ]:
+ image_dir = os.path.join(root, image_dirname, 'validation')
+ gt_dir = os.path.join(root, sem_seg_dirname, 'validation')
+ name = f"voc_2012_{name}_sem_seg"
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg_background", ignore_label=255, **meta,)
+
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
+ register_all_pascal_voc(_root)
+ #register_all_pascal_voc_background(_root)
cat_seg/data/datasets/register_pascal_59.py ADDED
@@ -0,0 +1,81 @@
+ import os
+
+ from detectron2.data import DatasetCatalog, MetadataCatalog
+ from detectron2.data.datasets import load_sem_seg
+ import copy
+
+
+ stuff_colors = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
9
+ [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
10
+ [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
11
+ [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],
12
+ [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],
13
+ [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],
14
+ [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],
15
+ [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
16
+ [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],
17
+ [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],
18
+ [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
19
+ [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
20
+ [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
21
+ [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],
22
+ [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
23
+ [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
24
+ [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
25
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
26
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
27
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
28
+ [0, 0, 230], [119, 11, 32],
29
+ [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],
30
+ [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
31
+ [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
32
+ [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
33
+ [64, 192, 96], [64, 160, 64], [64, 64, 0]]
34
+
35
+ def _get_pascal_context_59_meta():
36
+ #context_classes = ["aeroplane", "bag", "bed", "bedclothes", "bench", "bicycle", "bird", "boat", "book", "bottle", "building", "bus", "cabinet", "car", "cat", "ceiling", "chair", "cloth", "computer", "cow", "cup", "curtain", "dog", "door", "fence", "floor", "flower", "food", "grass", "ground", "horse", "keyboard", "light", "motorbike", "mountain", "mouse", "person", "plate", "platform", "pottedplant", "road", "rock", "sheep", "shelves", "sidewalk", "sign", "sky", "snow", "sofa", "diningtable", "track", "train", "tree", "truck", "tvmonitor", "wall", "water", "window", "wood"]#, "background"]
37
+ context_classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor", "bag", "bed", "bench", "book", "building", "cabinet", "ceiling", "cloth", "computer", "cup", "door", "fence", "floor", "flower", "food", "grass", "ground", "keyboard", "light", "mountain", "mouse", "curtain", "platform", "sign", "plate", "road", "rock", "shelves", "sidewalk", "sky", "snow", "bedclothes", "track", "tree", "truck", "wall", "water", "window", "wood"]
38
+ context_colors = [stuff_colors[i % len(stuff_colors)] for i in range(len(context_classes))]
39
+ ret = {
40
+ "stuff_colors" : context_colors,
41
+ "stuff_classes" : context_classes,
42
+ }
43
+ return ret
44
+
45
+ def register_pascal_context_59(root):
46
+ root = os.path.join(root, "VOCdevkit", "VOC2010")
47
+ meta = _get_pascal_context_59_meta()
48
+ for name, image_dirname, sem_seg_dirname in [
49
+ ("test", "JPEGImages", "annotations_detectron2/pc59_val"),
50
+ ]:
51
+ image_dir = os.path.join(root, image_dirname)
52
+ gt_dir = os.path.join(root, sem_seg_dirname)
53
+ name = f"context_59_{name}_sem_seg"
54
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
55
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
56
+
57
+ def _get_pascal_context_459_meta():
58
+ context_459_classes = ["accordion", "aeroplane", "airconditioner", "antenna", "artillery", "ashtray", "atrium", "babycarriage", "bag", "ball", "balloon", "bambooweaving", "barrel", "baseballbat", "basket", "basketballbackboard", "bathtub", "bed", "bedclothes", "beer", "bell", "bench", "bicycle", "binoculars", "bird", "birdcage", "birdfeeder", "birdnest", "blackboard", "board", "boat", "bone", "book", "bottle", "bottleopener", "bowl", "box", "bracelet", "brick", "bridge", "broom", "brush", "bucket", "building", "bus", "cabinet", "cabinetdoor", "cage", "cake", "calculator", "calendar", "camel", "camera", "cameralens", "can", "candle", "candleholder", "cap", "car", "card", "cart", "case", "casetterecorder", "cashregister", "cat", "cd", "cdplayer", "ceiling", "cellphone", "cello", "chain", "chair", "chessboard", "chicken", "chopstick", "clip", "clippers", "clock", "closet", "cloth", "clothestree", "coffee", "coffeemachine", "comb", "computer", "concrete", "cone", "container", "controlbooth", "controller", "cooker", "copyingmachine", "coral", "cork", "corkscrew", "counter", "court", "cow", "crabstick", "crane", "crate", "cross", "crutch", "cup", "curtain", "cushion", "cuttingboard", "dais", "disc", "disccase", "dishwasher", "dock", "dog", "dolphin", "door", "drainer", "dray", "drinkdispenser", "drinkingmachine", "drop", "drug", "drum", "drumkit", "duck", "dumbbell", "earphone", "earrings", "egg", "electricfan", "electriciron", "electricpot", "electricsaw", "electronickeyboard", "engine", "envelope", "equipment", "escalator", "exhibitionbooth", "extinguisher", "eyeglass", "fan", "faucet", "faxmachine", "fence", "ferriswheel", "fireextinguisher", "firehydrant", "fireplace", "fish", "fishtank", "fishbowl", "fishingnet", "fishingpole", "flag", "flagstaff", "flame", "flashlight", "floor", "flower", "fly", "foam", "food", "footbridge", "forceps", "fork", "forklift", "fountain", "fox", "frame", "fridge", "frog", "fruit", "funnel", "furnace", "gamecontroller", "gamemachine", "gascylinder", "gashood", "gasstove", "giftbox", "glass", "glassmarble", "globe", "glove", "goal", "grandstand", "grass", "gravestone", "ground", "guardrail", "guitar", "gun", "hammer", "handcart", "handle", "handrail", "hanger", "harddiskdrive", "hat", "hay", "headphone", "heater", "helicopter", "helmet", "holder", "hook", "horse", "horse-drawncarriage", "hot-airballoon", "hydrovalve", "ice", "inflatorpump", "ipod", "iron", "ironingboard", "jar", "kart", "kettle", "key", "keyboard", "kitchenrange", "kite", "knife", "knifeblock", "ladder", "laddertruck", "ladle", "laptop", "leaves", "lid", "lifebuoy", "light", "lightbulb", "lighter", "line", "lion", "lobster", "lock", "machine", "mailbox", "mannequin", "map", "mask", "mat", "matchbook", "mattress", "menu", "metal", "meterbox", "microphone", "microwave", "mirror", "missile", "model", "money", "monkey", "mop", "motorbike", "mountain", "mouse", "mousepad", "musicalinstrument", "napkin", "net", "newspaper", "oar", "ornament", "outlet", "oven", "oxygenbottle", "pack", "pan", "paper", "paperbox", "papercutter", "parachute", "parasol", "parterre", "patio", "pelage", "pen", "pencontainer", "pencil", "person", "photo", "piano", "picture", "pig", "pillar", "pillow", "pipe", "pitcher", "plant", "plastic", "plate", "platform", "player", "playground", "pliers", "plume", "poker", "pokerchip", "pole", "pooltable", "postcard", "poster", "pot", "pottedplant", "printer", "projector", "pumpkin", "rabbit", "racket", "radiator", "radio", "rail", "rake", "ramp", "rangehood", "receiver", "recorder", 
"recreationalmachines", "remotecontrol", "road", "robot", "rock", "rocket", "rockinghorse", "rope", "rug", "ruler", "runway", "saddle", "sand", "saw", "scale", "scanner", "scissors", "scoop", "screen", "screwdriver", "sculpture", "scythe", "sewer", "sewingmachine", "shed", "sheep", "shell", "shelves", "shoe", "shoppingcart", "shovel", "sidecar", "sidewalk", "sign", "signallight", "sink", "skateboard", "ski", "sky", "sled", "slippers", "smoke", "snail", "snake", "snow", "snowmobiles", "sofa", "spanner", "spatula", "speaker", "speedbump", "spicecontainer", "spoon", "sprayer", "squirrel", "stage", "stair", "stapler", "stick", "stickynote", "stone", "stool", "stove", "straw", "stretcher", "sun", "sunglass", "sunshade", "surveillancecamera", "swan", "sweeper", "swimring", "swimmingpool", "swing", "switch", "table", "tableware", "tank", "tap", "tape", "tarp", "telephone", "telephonebooth", "tent", "tire", "toaster", "toilet", "tong", "tool", "toothbrush", "towel", "toy", "toycar", "track", "train", "trampoline", "trashbin", "tray", "tree", "tricycle", "tripod", "trophy", "truck", "tube", "turtle", "tvmonitor", "tweezers", "typewriter", "umbrella", "unknown", "vacuumcleaner", "vendingmachine", "videocamera", "videogameconsole", "videoplayer", "videotape", "violin", "wakeboard", "wall", "wallet", "wardrobe", "washingmachine", "watch", "water", "waterdispenser", "waterpipe", "waterskateboard", "watermelon", "whale", "wharf", "wheel", "wheelchair", "window", "windowblinds", "wineglass", "wire", "wood", "wool"]
+ context_colors = [stuff_colors[i % len(stuff_colors)] for i in range(len(context_459_classes))]
+ ret = {
+ "stuff_colors" : context_colors,
+ "stuff_classes" : context_459_classes,
+ }
+ return ret
+
+ def register_pascal_context_459(root):
+ root = os.path.join(root, "VOCdevkit", "VOC2010")
+ meta = _get_pascal_context_459_meta()
+ for name, image_dirname, sem_seg_dirname in [
+ ("test", "JPEGImages", "annotations_detectron2/pc459_val"),
+ ]:
+ image_dir = os.path.join(root, image_dirname)
+ gt_dir = os.path.join(root, sem_seg_dirname)
+ name = f"context_459_{name}_sem_seg"
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='tif', image_ext='jpg'))
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=459, **meta,)
+
+
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
+ register_pascal_context_59(_root)
+ register_pascal_context_459(_root)
cat_seg/modeling/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from .backbone.swin import D2SwinTransformer
+ from .heads.cat_seg_head import CATSegHead
cat_seg/modeling/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (263 Bytes). View file
 
cat_seg/modeling/__pycache__/criterion.cpython-38.pyc ADDED
Binary file (8.26 kB). View file
 
cat_seg/modeling/__pycache__/matcher.cpython-38.pyc ADDED
Binary file (6.94 kB). View file
 
cat_seg/modeling/backbone/__init__.py ADDED
@@ -0,0 +1 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
cat_seg/modeling/backbone/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (164 Bytes). View file
 
cat_seg/modeling/backbone/__pycache__/image_encoder.cpython-38.pyc ADDED
Binary file (20 kB). View file
 
cat_seg/modeling/backbone/__pycache__/swin.cpython-38.pyc ADDED
Binary file (21.5 kB). View file
 
cat_seg/modeling/backbone/swin.py ADDED
@@ -0,0 +1,768 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint as checkpoint
16
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
17
+
18
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
19
+
20
+
21
+ class Mlp(nn.Module):
22
+ """Multilayer perceptron."""
23
+
24
+ def __init__(
25
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
26
+ ):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ def window_partition(x, window_size):
45
+ """
46
+ Args:
47
+ x: (B, H, W, C)
48
+ window_size (int): window size
49
+ Returns:
50
+ windows: (num_windows*B, window_size, window_size, C)
51
+ """
52
+ B, H, W, C = x.shape
53
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
54
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
55
+ return windows
56
+
57
+
58
+ def window_reverse(windows, window_size, H, W):
59
+ """
60
+ Args:
61
+ windows: (num_windows*B, window_size, window_size, C)
62
+ window_size (int): Window size
63
+ H (int): Height of image
64
+ W (int): Width of image
65
+ Returns:
66
+ x: (B, H, W, C)
67
+ """
68
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
69
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
70
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
71
+ return x
72
+
73
+
74
+ class WindowAttention(nn.Module):
75
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
76
+ It supports both of shifted and non-shifted window.
77
+ Args:
78
+ dim (int): Number of input channels.
79
+ window_size (tuple[int]): The height and width of the window.
80
+ num_heads (int): Number of attention heads.
81
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
82
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
83
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
84
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ dim,
90
+ window_size,
91
+ num_heads,
92
+ qkv_bias=True,
93
+ qk_scale=None,
94
+ attn_drop=0.0,
95
+ proj_drop=0.0,
96
+ ):
97
+
98
+ super().__init__()
99
+ self.dim = dim
100
+ self.window_size = window_size # Wh, Ww
101
+ self.num_heads = num_heads
102
+ head_dim = dim // num_heads
103
+ self.scale = qk_scale or head_dim ** -0.5
104
+
105
+ # define a parameter table of relative position bias
106
+ self.relative_position_bias_table = nn.Parameter(
107
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
108
+ ) # 2*Wh-1 * 2*Ww-1, nH
109
+
110
+ # get pair-wise relative position index for each token inside the window
111
+ coords_h = torch.arange(self.window_size[0])
112
+ coords_w = torch.arange(self.window_size[1])
113
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
114
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
115
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
116
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
117
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
118
+ relative_coords[:, :, 1] += self.window_size[1] - 1
119
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
120
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
121
+ self.register_buffer("relative_position_index", relative_position_index)
122
+
123
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
124
+ self.attn_drop = nn.Dropout(attn_drop)
125
+ self.proj = nn.Linear(dim, dim)
126
+ self.proj_drop = nn.Dropout(proj_drop)
127
+
128
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
129
+ self.softmax = nn.Softmax(dim=-1)
130
+
131
+ def forward(self, x, mask=None):
132
+ """Forward function.
133
+ Args:
134
+ x: input features with shape of (num_windows*B, N, C)
135
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
136
+ """
137
+ B_, N, C = x.shape
138
+ qkv = (
139
+ self.qkv(x)
140
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
141
+ .permute(2, 0, 3, 1, 4)
142
+ )
143
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
144
+
145
+ q = q * self.scale
146
+ attn = q @ k.transpose(-2, -1)
147
+
148
+ relative_position_bias = self.relative_position_bias_table[
149
+ self.relative_position_index.view(-1)
150
+ ].view(
151
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
152
+ ) # Wh*Ww,Wh*Ww,nH
153
+ relative_position_bias = relative_position_bias.permute(
154
+ 2, 0, 1
155
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
156
+ attn = attn + relative_position_bias.unsqueeze(0)
157
+
158
+ if mask is not None:
159
+ nW = mask.shape[0]
160
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
161
+ attn = attn.view(-1, self.num_heads, N, N)
162
+ attn = self.softmax(attn)
163
+ else:
164
+ attn = self.softmax(attn)
165
+
166
+ attn = self.attn_drop(attn)
167
+
168
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
169
+ x = self.proj(x)
170
+ x = self.proj_drop(x)
171
+ return x
172
+
173
+
174
+ class SwinTransformerBlock(nn.Module):
175
+ """Swin Transformer Block.
176
+ Args:
177
+ dim (int): Number of input channels.
178
+ num_heads (int): Number of attention heads.
179
+ window_size (int): Window size.
180
+ shift_size (int): Shift size for SW-MSA.
181
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
182
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
183
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
184
+ drop (float, optional): Dropout rate. Default: 0.0
185
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
186
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
187
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
188
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ dim,
194
+ num_heads,
195
+ window_size=7,
196
+ shift_size=0,
197
+ mlp_ratio=4.0,
198
+ qkv_bias=True,
199
+ qk_scale=None,
200
+ drop=0.0,
201
+ attn_drop=0.0,
202
+ drop_path=0.0,
203
+ act_layer=nn.GELU,
204
+ norm_layer=nn.LayerNorm,
205
+ ):
206
+ super().__init__()
207
+ self.dim = dim
208
+ self.num_heads = num_heads
209
+ self.window_size = window_size
210
+ self.shift_size = shift_size
211
+ self.mlp_ratio = mlp_ratio
212
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
213
+
214
+ self.norm1 = norm_layer(dim)
215
+ self.attn = WindowAttention(
216
+ dim,
217
+ window_size=to_2tuple(self.window_size),
218
+ num_heads=num_heads,
219
+ qkv_bias=qkv_bias,
220
+ qk_scale=qk_scale,
221
+ attn_drop=attn_drop,
222
+ proj_drop=drop,
223
+ )
224
+
225
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
226
+ self.norm2 = norm_layer(dim)
227
+ mlp_hidden_dim = int(dim * mlp_ratio)
228
+ self.mlp = Mlp(
229
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
230
+ )
231
+
232
+ self.H = None
233
+ self.W = None
234
+
235
+ def forward(self, x, mask_matrix):
236
+ """Forward function.
237
+ Args:
238
+ x: Input feature, tensor size (B, H*W, C).
239
+ H, W: Spatial resolution of the input feature.
240
+ mask_matrix: Attention mask for cyclic shift.
241
+ """
242
+ B, L, C = x.shape
243
+ H, W = self.H, self.W
244
+ assert L == H * W, "input feature has wrong size"
245
+
246
+ shortcut = x
247
+ x = self.norm1(x)
248
+ x = x.view(B, H, W, C)
249
+
250
+ # pad feature maps to multiples of window size
251
+ pad_l = pad_t = 0
252
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
253
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
254
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
255
+ _, Hp, Wp, _ = x.shape
256
+
257
+ # cyclic shift
258
+ if self.shift_size > 0:
259
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
260
+ attn_mask = mask_matrix
261
+ else:
262
+ shifted_x = x
263
+ attn_mask = None
264
+
265
+ # partition windows
266
+ x_windows = window_partition(
267
+ shifted_x, self.window_size
268
+ ) # nW*B, window_size, window_size, C
269
+ x_windows = x_windows.view(
270
+ -1, self.window_size * self.window_size, C
271
+ ) # nW*B, window_size*window_size, C
272
+
273
+ # W-MSA/SW-MSA
274
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
275
+
276
+ # merge windows
277
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
278
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
279
+
280
+ # reverse cyclic shift
281
+ if self.shift_size > 0:
282
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
283
+ else:
284
+ x = shifted_x
285
+
286
+ if pad_r > 0 or pad_b > 0:
287
+ x = x[:, :H, :W, :].contiguous()
288
+
289
+ x = x.view(B, H * W, C)
290
+
291
+ # FFN
292
+ x = shortcut + self.drop_path(x)
293
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
294
+
295
+ return x
296
+
297
+
298
+ class PatchMerging(nn.Module):
299
+ """Patch Merging Layer
300
+ Args:
301
+ dim (int): Number of input channels.
302
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
303
+ """
304
+
305
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
306
+ super().__init__()
307
+ self.dim = dim
308
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
309
+ self.norm = norm_layer(4 * dim)
310
+
311
+ def forward(self, x, H, W):
312
+ """Forward function.
313
+ Args:
314
+ x: Input feature, tensor size (B, H*W, C).
315
+ H, W: Spatial resolution of the input feature.
316
+ """
317
+ B, L, C = x.shape
318
+ assert L == H * W, "input feature has wrong size"
319
+
320
+ x = x.view(B, H, W, C)
321
+
322
+ # padding
323
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
324
+ if pad_input:
325
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
326
+
327
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
328
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
329
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
330
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
331
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
332
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
333
+
334
+ x = self.norm(x)
335
+ x = self.reduction(x)
336
+
337
+ return x
338
+
339
+
340
+ class BasicLayer(nn.Module):
341
+ """A basic Swin Transformer layer for one stage.
342
+ Args:
343
+ dim (int): Number of feature channels
344
+ depth (int): Depths of this stage.
345
+ num_heads (int): Number of attention head.
346
+ window_size (int): Local window size. Default: 7.
347
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
348
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
349
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
350
+ drop (float, optional): Dropout rate. Default: 0.0
351
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
352
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
353
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
354
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
355
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
356
+ """
357
+
358
+ def __init__(
359
+ self,
360
+ dim,
361
+ depth,
362
+ num_heads,
363
+ window_size=7,
364
+ mlp_ratio=4.0,
365
+ qkv_bias=True,
366
+ qk_scale=None,
367
+ drop=0.0,
368
+ attn_drop=0.0,
369
+ drop_path=0.0,
370
+ norm_layer=nn.LayerNorm,
371
+ downsample=None,
372
+ use_checkpoint=False,
373
+ ):
374
+ super().__init__()
375
+ self.window_size = window_size
376
+ self.shift_size = window_size // 2
377
+ self.depth = depth
378
+ self.use_checkpoint = use_checkpoint
379
+
380
+ # build blocks
381
+ self.blocks = nn.ModuleList(
382
+ [
383
+ SwinTransformerBlock(
384
+ dim=dim,
385
+ num_heads=num_heads,
386
+ window_size=window_size,
387
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
388
+ mlp_ratio=mlp_ratio,
389
+ qkv_bias=qkv_bias,
390
+ qk_scale=qk_scale,
391
+ drop=drop,
392
+ attn_drop=attn_drop,
393
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
394
+ norm_layer=norm_layer,
395
+ )
396
+ for i in range(depth)
397
+ ]
398
+ )
399
+
400
+ # patch merging layer
401
+ if downsample is not None:
402
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
403
+ else:
404
+ self.downsample = None
405
+
406
+ def forward(self, x, H, W):
407
+ """Forward function.
408
+ Args:
409
+ x: Input feature, tensor size (B, H*W, C).
410
+ H, W: Spatial resolution of the input feature.
411
+ """
412
+
413
+ # calculate attention mask for SW-MSA
414
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
415
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
416
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
417
+ h_slices = (
418
+ slice(0, -self.window_size),
419
+ slice(-self.window_size, -self.shift_size),
420
+ slice(-self.shift_size, None),
421
+ )
422
+ w_slices = (
423
+ slice(0, -self.window_size),
424
+ slice(-self.window_size, -self.shift_size),
425
+ slice(-self.shift_size, None),
426
+ )
427
+ cnt = 0
428
+ for h in h_slices:
429
+ for w in w_slices:
430
+ img_mask[:, h, w, :] = cnt
431
+ cnt += 1
432
+
433
+ mask_windows = window_partition(
434
+ img_mask, self.window_size
435
+ ) # nW, window_size, window_size, 1
436
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
437
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
438
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
439
+ attn_mask == 0, float(0.0)
440
+ )
441
+
442
+ for blk in self.blocks:
443
+ blk.H, blk.W = H, W
444
+ if self.use_checkpoint:
445
+ x = checkpoint.checkpoint(blk, x, attn_mask)
446
+ else:
447
+ x = blk(x, attn_mask)
448
+ if self.downsample is not None:
449
+ x_down = self.downsample(x, H, W)
450
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
451
+ return x, H, W, x_down, Wh, Ww
452
+ else:
453
+ return x, H, W, x, H, W
454
+
455
+
456
+ class PatchEmbed(nn.Module):
457
+ """Image to Patch Embedding
458
+ Args:
459
+ patch_size (int): Patch token size. Default: 4.
460
+ in_chans (int): Number of input image channels. Default: 3.
461
+ embed_dim (int): Number of linear projection output channels. Default: 96.
462
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
463
+ """
464
+
465
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
466
+ super().__init__()
467
+ patch_size = to_2tuple(patch_size)
468
+ self.patch_size = patch_size
469
+
470
+ self.in_chans = in_chans
471
+ self.embed_dim = embed_dim
472
+
473
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
474
+ if norm_layer is not None:
475
+ self.norm = norm_layer(embed_dim)
476
+ else:
477
+ self.norm = None
478
+
479
+ def forward(self, x):
480
+ """Forward function."""
481
+ # padding
482
+ _, _, H, W = x.size()
483
+ if W % self.patch_size[1] != 0:
484
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
485
+ if H % self.patch_size[0] != 0:
486
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
487
+
488
+ x = self.proj(x) # B C Wh Ww
489
+ if self.norm is not None:
490
+ Wh, Ww = x.size(2), x.size(3)
491
+ x = x.flatten(2).transpose(1, 2)
492
+ x = self.norm(x)
493
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
494
+
495
+ return x
496
+
497
+
498
+ class SwinTransformer(nn.Module):
499
+ """Swin Transformer backbone.
500
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
501
+ https://arxiv.org/pdf/2103.14030
502
+ Args:
503
+ pretrain_img_size (int): Input image size for training the pretrained model,
504
+ used in absolute postion embedding. Default 224.
505
+ patch_size (int | tuple(int)): Patch size. Default: 4.
506
+ in_chans (int): Number of input image channels. Default: 3.
507
+ embed_dim (int): Number of linear projection output channels. Default: 96.
508
+ depths (tuple[int]): Depths of each Swin Transformer stage.
509
+ num_heads (tuple[int]): Number of attention head of each stage.
510
+ window_size (int): Window size. Default: 7.
511
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
512
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
513
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
514
+ drop_rate (float): Dropout rate.
515
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
516
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
517
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
518
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
519
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
520
+ out_indices (Sequence[int]): Output from which stages.
521
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
522
+ -1 means not freezing any parameters.
523
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
524
+ """
525
+
526
+ def __init__(
527
+ self,
528
+ pretrain_img_size=224,
529
+ patch_size=4,
530
+ in_chans=3,
531
+ embed_dim=96,
532
+ depths=[2, 2, 6, 2],
533
+ num_heads=[3, 6, 12, 24],
534
+ window_size=7,
535
+ mlp_ratio=4.0,
536
+ qkv_bias=True,
537
+ qk_scale=None,
538
+ drop_rate=0.0,
539
+ attn_drop_rate=0.0,
540
+ drop_path_rate=0.2,
541
+ norm_layer=nn.LayerNorm,
542
+ ape=False,
543
+ patch_norm=True,
544
+ out_indices=(0, 1, 2), #3),
545
+ frozen_stages=-1,
546
+ use_checkpoint=False,
547
+ ):
548
+ super().__init__()
549
+
550
+ self.pretrain_img_size = pretrain_img_size
551
+ self.num_layers = len(depths)
552
+ self.embed_dim = embed_dim
553
+ self.ape = ape
554
+ self.patch_norm = patch_norm
555
+ self.out_indices = out_indices
556
+ self.frozen_stages = frozen_stages
557
+
558
+ # split image into non-overlapping patches
559
+ self.patch_embed = PatchEmbed(
560
+ patch_size=patch_size,
561
+ in_chans=in_chans,
562
+ embed_dim=embed_dim,
563
+ norm_layer=norm_layer if self.patch_norm else None,
564
+ )
565
+
566
+ # absolute position embedding
567
+ if self.ape:
568
+ pretrain_img_size = to_2tuple(pretrain_img_size)
569
+ patch_size = to_2tuple(patch_size)
570
+ patches_resolution = [
571
+ pretrain_img_size[0] // patch_size[0],
572
+ pretrain_img_size[1] // patch_size[1],
573
+ ]
574
+
575
+ self.absolute_pos_embed = nn.Parameter(
576
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
577
+ )
578
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
579
+
580
+ self.pos_drop = nn.Dropout(p=drop_rate)
581
+
582
+ # stochastic depth
583
+ dpr = [
584
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
585
+ ] # stochastic depth decay rule
586
+
587
+ # build layers
588
+ self.layers = nn.ModuleList()
589
+ for i_layer in range(self.num_layers):
590
+ layer = BasicLayer(
591
+ dim=int(embed_dim * 2 ** i_layer),
592
+ depth=depths[i_layer],
593
+ num_heads=num_heads[i_layer],
594
+ window_size=window_size,
595
+ mlp_ratio=mlp_ratio,
596
+ qkv_bias=qkv_bias,
597
+ qk_scale=qk_scale,
598
+ drop=drop_rate,
599
+ attn_drop=attn_drop_rate,
600
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
601
+ norm_layer=norm_layer,
602
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
603
+ use_checkpoint=use_checkpoint,
604
+ )
605
+ self.layers.append(layer)
606
+
607
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
608
+ self.num_features = num_features
609
+
610
+ # add a norm layer for each output
611
+ for i_layer in out_indices:
612
+ layer = norm_layer(num_features[i_layer])
613
+ layer_name = f"norm{i_layer}"
614
+ self.add_module(layer_name, layer)
615
+
616
+ self._freeze_stages()
617
+
618
+ def _freeze_stages(self):
619
+ if self.frozen_stages >= 0:
620
+ self.patch_embed.eval()
621
+ for param in self.patch_embed.parameters():
622
+ param.requires_grad = False
623
+
624
+ if self.frozen_stages >= 1 and self.ape:
625
+ self.absolute_pos_embed.requires_grad = False
626
+
627
+ if self.frozen_stages >= 2:
628
+ self.pos_drop.eval()
629
+ for i in range(0, self.frozen_stages - 1):
630
+ m = self.layers[i]
631
+ m.eval()
632
+ for param in m.parameters():
633
+ param.requires_grad = False
634
+
635
+ def init_weights(self, pretrained=None):
636
+ """Initialize the weights in backbone.
637
+ Args:
638
+ pretrained (str, optional): Path to pre-trained weights.
639
+ Defaults to None.
640
+ """
641
+
642
+ def _init_weights(m):
643
+ if isinstance(m, nn.Linear):
644
+ trunc_normal_(m.weight, std=0.02)
645
+ if isinstance(m, nn.Linear) and m.bias is not None:
646
+ nn.init.constant_(m.bias, 0)
647
+ elif isinstance(m, nn.LayerNorm):
648
+ nn.init.constant_(m.bias, 0)
649
+ nn.init.constant_(m.weight, 1.0)
650
+
651
+ def forward(self, x):
652
+ """Forward function."""
653
+ x = self.patch_embed(x)
654
+
655
+ Wh, Ww = x.size(2), x.size(3)
656
+ if self.ape:
657
+ # interpolate the position embedding to the corresponding size
658
+ absolute_pos_embed = F.interpolate(
659
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
660
+ )
661
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
662
+ else:
663
+ x = x.flatten(2).transpose(1, 2)
664
+ x = self.pos_drop(x)
665
+
666
+ outs = {}
667
+ for i in range(self.num_layers):
668
+ layer = self.layers[i]
669
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
670
+
671
+ if i in self.out_indices:
672
+ norm_layer = getattr(self, f"norm{i}")
673
+ x_out = norm_layer(x_out)
674
+
675
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
676
+ outs["res{}".format(i + 2)] = out
677
+
678
+ return outs
679
+
680
+ def train(self, mode=True):
681
+ """Convert the model into training mode while keeping frozen stages frozen."""
682
+ super(SwinTransformer, self).train(mode)
683
+ self._freeze_stages()
684
+
685
+
686
+ @BACKBONE_REGISTRY.register()
687
+ class D2SwinTransformer(SwinTransformer, Backbone):
688
+ def __init__(self, cfg, input_shape):
689
+
690
+ pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
691
+ patch_size = cfg.MODEL.SWIN.PATCH_SIZE
692
+ in_chans = 3
693
+ embed_dim = cfg.MODEL.SWIN.EMBED_DIM
694
+ depths = cfg.MODEL.SWIN.DEPTHS
695
+ num_heads = cfg.MODEL.SWIN.NUM_HEADS
696
+ window_size = cfg.MODEL.SWIN.WINDOW_SIZE
697
+ mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
698
+ qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
699
+ qk_scale = cfg.MODEL.SWIN.QK_SCALE
700
+ drop_rate = cfg.MODEL.SWIN.DROP_RATE
701
+ attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
702
+ drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
703
+ norm_layer = nn.LayerNorm
704
+ ape = cfg.MODEL.SWIN.APE
705
+ patch_norm = cfg.MODEL.SWIN.PATCH_NORM
706
+
707
+ super().__init__(
708
+ pretrain_img_size,
709
+ patch_size,
710
+ in_chans,
711
+ embed_dim,
712
+ depths,
713
+ num_heads,
714
+ window_size,
715
+ mlp_ratio,
716
+ qkv_bias,
717
+ qk_scale,
718
+ drop_rate,
719
+ attn_drop_rate,
720
+ drop_path_rate,
721
+ norm_layer,
722
+ ape,
723
+ patch_norm,
724
+ )
725
+
726
+ self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
727
+
728
+ self._out_feature_strides = {
729
+ "res2": 4,
730
+ "res3": 8,
731
+ "res4": 16,
732
+ #"res5": 32,
733
+ }
734
+ self._out_feature_channels = {
735
+ "res2": self.num_features[0],
736
+ "res3": self.num_features[1],
737
+ "res4": self.num_features[2],
738
+ #"res5": self.num_features[3],
739
+ }
740
+
741
+ def forward(self, x):
742
+ """
743
+ Args:
744
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
745
+ Returns:
746
+ dict[str->Tensor]: names and the corresponding features
747
+ """
748
+ assert (
749
+ x.dim() == 4
750
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
751
+ outputs = {}
752
+ y = super().forward(x)
753
+ for k in y.keys():
754
+ if k in self._out_features:
755
+ outputs[k] = y[k]
756
+ return outputs
757
+
758
+ def output_shape(self):
759
+ return {
760
+ name: ShapeSpec(
761
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
762
+ )
763
+ for name in self._out_features
764
+ }
765
+
766
+ @property
767
+ def size_divisibility(self):
768
+ return 32
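A quick smoke test for the backbone above (a sketch, not part of the commit; it assumes `torch`, `timm`, and `detectron2` are importable). It instantiates the plain `SwinTransformer` with the Swin-T sized defaults from the constructor; the registered `D2SwinTransformer` wrapper builds the same module from the `MODEL.SWIN.*` config keys instead.

```python
import torch
from cat_seg.modeling.backbone.swin import SwinTransformer

# Swin-T sized settings, matching the constructor defaults above.
# out_indices=(0, 1, 2) means features at strides 4, 8 and 16 are returned
# under the keys "res2", "res3" and "res4".
backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2],
                           num_heads=[3, 6, 12, 24], window_size=7)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))

for name, feat in feats.items():
    print(name, tuple(feat.shape))
# Expected: res2 (1, 96, 56, 56), res3 (1, 192, 28, 28), res4 (1, 384, 14, 14)
```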
cat_seg/modeling/heads/__init__.py ADDED
@@ -0,0 +1 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.