hsshin98 committed on
Commit
d617811
1 Parent(s): aeae875

Add application file

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. INSTALL.md +20 -0
  3. README.md +48 -12
  4. app.py +100 -0
  5. assets/fig1.png +0 -0
  6. cat_seg/__init__.py +19 -0
  7. cat_seg/__pycache__/__init__.cpython-38.pyc +0 -0
  8. cat_seg/__pycache__/cat_seg_model.cpython-38.pyc +0 -0
  9. cat_seg/__pycache__/config.cpython-38.pyc +0 -0
  10. cat_seg/__pycache__/test_time_augmentation.cpython-38.pyc +0 -0
  11. cat_seg/cat_seg_model.py +216 -0
  12. cat_seg/config.py +93 -0
  13. cat_seg/data/__init__.py +2 -0
  14. cat_seg/data/__pycache__/__init__.cpython-38.pyc +0 -0
  15. cat_seg/data/dataset_mappers/__init__.py +1 -0
  16. cat_seg/data/dataset_mappers/__pycache__/__init__.cpython-38.pyc +0 -0
  17. cat_seg/data/dataset_mappers/__pycache__/detr_panoptic_dataset_mapper.cpython-38.pyc +0 -0
  18. cat_seg/data/dataset_mappers/__pycache__/mask_former_panoptic_dataset_mapper.cpython-38.pyc +0 -0
  19. cat_seg/data/dataset_mappers/__pycache__/mask_former_semantic_dataset_mapper.cpython-38.pyc +0 -0
  20. cat_seg/data/dataset_mappers/detr_panoptic_dataset_mapper.py +180 -0
  21. cat_seg/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py +165 -0
  22. cat_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py +186 -0
  23. cat_seg/data/datasets/__init__.py +8 -0
  24. cat_seg/data/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  25. cat_seg/data/datasets/__pycache__/register_ade20k_150.cpython-38.pyc +0 -0
  26. cat_seg/data/datasets/__pycache__/register_ade20k_847.cpython-38.pyc +0 -0
  27. cat_seg/data/datasets/__pycache__/register_coco_stuff.cpython-38.pyc +0 -0
  28. cat_seg/data/datasets/__pycache__/register_pascal_20.cpython-38.pyc +0 -0
  29. cat_seg/data/datasets/__pycache__/register_pascal_59.cpython-38.pyc +0 -0
  30. cat_seg/data/datasets/register_ade20k_150.py +28 -0
  31. cat_seg/data/datasets/register_ade20k_847.py +0 -0
  32. cat_seg/data/datasets/register_coco_stuff.py +216 -0
  33. cat_seg/data/datasets/register_pascal_20.py +53 -0
  34. cat_seg/data/datasets/register_pascal_59.py +81 -0
  35. cat_seg/modeling/__init__.py +3 -0
  36. cat_seg/modeling/__pycache__/__init__.cpython-38.pyc +0 -0
  37. cat_seg/modeling/backbone/__init__.py +1 -0
  38. cat_seg/modeling/backbone/__pycache__/__init__.cpython-38.pyc +0 -0
  39. cat_seg/modeling/backbone/__pycache__/swin.cpython-38.pyc +0 -0
  40. cat_seg/modeling/backbone/swin.py +768 -0
  41. cat_seg/modeling/heads/__init__.py +1 -0
  42. cat_seg/modeling/heads/__pycache__/__init__.cpython-38.pyc +0 -0
  43. cat_seg/modeling/heads/__pycache__/cat_seg_head.cpython-38.pyc +0 -0
  44. cat_seg/modeling/heads/cat_seg_head.py +72 -0
  45. cat_seg/modeling/transformer/__init__.py +1 -0
  46. cat_seg/modeling/transformer/__pycache__/__init__.cpython-38.pyc +0 -0
  47. cat_seg/modeling/transformer/__pycache__/cat_seg_predictor.cpython-38.pyc +0 -0
  48. cat_seg/modeling/transformer/__pycache__/model.cpython-38.pyc +0 -0
  49. cat_seg/modeling/transformer/cat_seg_predictor.py +175 -0
  50. cat_seg/modeling/transformer/model.py +650 -0
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model_final.pth filter=lfs diff=lfs merge=lfs -text
+ cat_seg/third_party/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
INSTALL.md ADDED
@@ -0,0 +1,20 @@
+ ## Installation
+
+ ### Requirements
+ - Linux or macOS with Python ≥ 3.6
+ - PyTorch ≥ 1.7 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
+   Install them together at [pytorch.org](https://pytorch.org) to make sure of this. Note that the
+   PyTorch version must match the one required by Detectron2.
+ - Detectron2: follow the [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html).
+ - OpenCV is optional, but it is needed by the demo and visualization.
+ - `pip install -r requirements.txt`
+
+ An example installation is shown below:
+
+ ```
+ git clone https://github.com/~~~/CAT-Seg.git
+ cd CAT-Seg
+ conda create -n catseg python=3.8
+ conda activate catseg
+ pip install -r requirements.txt
+ ```
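After the steps above, a quick import check is a convenient way to confirm that the core dependencies are wired together; the following is a minimal sketch (version numbers will vary with your environment):

```python
# Minimal post-install sanity check: the core packages import and CUDA is visible.
import torch
import torchvision
import detectron2

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("detectron2:", detectron2.__version__)
print("CUDA available:", torch.cuda.is_available())
```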
README.md CHANGED
@@ -1,12 +1,48 @@
- ---
- title: CAT Seg
- emoji: 📊
- colorFrom: indigo
- colorTo: pink
- sdk: gradio
- sdk_version: 3.21.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # CAT-Seg🐱: Cost Aggregation for Open-Vocabulary Semantic Segmentation
+
+ This is our official implementation of CAT-Seg🐱!
+
+ [[arXiv](#)] [[Project](#)]<br>
+ by [Seokju Cho](https://seokju-cho.github.io/)\*, [Heeseong Shin](https://github.com/hsshin98)\*, [Sunghwan Hong](https://sunghwanhong.github.io), Seungjun An, Seungjun Lee, [Anurag Arnab](https://anuragarnab.github.io), [Paul Hongsuck Seo](https://phseo.github.io), [Seungryong Kim](https://cvlab.korea.ac.kr)
+
+
+ ## Introduction
+ ![](assets/fig1.png)
+ We introduce cost aggregation to open-vocabulary semantic segmentation, which jointly aggregates both image and text modalities within the matching cost.
+
+ ## Installation
+ Install the required packages.
+
+ ```bash
+ conda create --name catseg python=3.8
+ conda activate catseg
+ conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
+ pip install -r requirements.txt
+ ```
+
+ ## Data Preparation
+
+
+ ## Training
+ ### Preparation
+ you have to blah
+ ### Training script
+ ```bash
+ python train.py --config configs/eval/{a847 | pc459 | a150 | pc59 | pas20 | pas20b}.yaml
+ ```
+
+ ## Evaluation
+ ```bash
+ python eval.py --config configs/eval/{a847 | pc459 | a150 | pc59 | pas20 | pas20b}.yaml
+ ```
+
+ ## Citing CAT-Seg🐱 :pray:
+
+ ```BibTeX
+ @article{liang2022open,
+   title={Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP},
+   author={Liang, Feng and Wu, Bichen and Dai, Xiaoliang and Li, Kunpeng and Zhao, Yinan and Zhang, Hang and Zhang, Peizhao and Vajda, Peter and Marculescu, Diana},
+   journal={arXiv preprint arXiv:2210.04150},
+   year={2022}
+ }
+ ```
app.py ADDED
@@ -0,0 +1,100 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
+ import argparse
+ import glob
+ import multiprocessing as mp
+ import os
+
+ # fmt: off
+ import sys
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
+ # fmt: on
+
+ import tempfile
+ import time
+ import warnings
+
+ import cv2
+ import numpy as np
+ import tqdm
+
+ from detectron2.config import get_cfg
+ from detectron2.data.detection_utils import read_image
+ from detectron2.projects.deeplab import add_deeplab_config
+ from detectron2.utils.logger import setup_logger
+
+ from cat_seg import add_cat_seg_config
+ from demo.predictor import VisualizationDemo
+ import gradio as gr
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
+
+ # constants
+ WINDOW_NAME = "MaskFormer demo"
+
+
+ def setup_cfg(args):
+     # load config from file and command-line arguments
+     cfg = get_cfg()
+     add_deeplab_config(cfg)
+     add_cat_seg_config(cfg)
+     cfg.merge_from_file(args.config_file)
+     cfg.merge_from_list(args.opts)
+     cfg.freeze()
+     return cfg
+
+
+ def get_parser():
+     parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
+     parser.add_argument(
+         "--config-file",
+         default="configs/vitl_swinb_384.yaml",
+         metavar="FILE",
+         help="path to config file",
+     )
+     parser.add_argument(
+         "--input",
+         nargs="+",
+         help="A list of space separated input images; "
+         "or a single glob pattern such as 'directory/*.jpg'",
+     )
+     parser.add_argument(
+         "--opts",
+         help="Modify config options using the command-line 'KEY VALUE' pairs",
+         default=["MODEL.WEIGHTS", "model_final.pth",
+                  "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
+                  "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
+                  "TEST.SLIDING_WINDOW", "True",
+                  "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]"],
+         nargs=argparse.REMAINDER,
+     )
+     return parser
+
+ def save_masks(preds, text):
+     preds = preds['sem_seg'].argmax(dim=0).cpu().numpy()  # (C, H, W) logits -> (H, W) class indices
+     for i, t in enumerate(text):
+         dir = f"masks/mask_{t}.png"
+         mask = preds == i
+         cv2.imwrite(dir, mask * 255)
+
+ def predict(image, text):
+     args = get_parser().parse_args()
+     cfg = setup_cfg(args)
+     demo = VisualizationDemo(cfg, text=text)
+     predictions, visualized_output = demo.run_on_image(image)
+     save_masks(predictions, text.split(','))
+     canvas = fc(visualized_output.fig)
+     canvas.draw()
+     out = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(canvas.get_width_height()[::-1] + (3,))
+
+     return out[..., ::-1]
+
+ if __name__ == "__main__":
+     args = get_parser().parse_args()
+     cfg = setup_cfg(args)
+
+     iface = gr.Interface(
+         fn=predict,
+         inputs=[gr.Image(), gr.Textbox(placeholder="Classes to segment")],
+         outputs="image",
+     )
+     iface.launch()
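As a rough illustration of how the pieces of `app.py` fit together outside the Gradio UI, `predict` can also be driven from a plain script. This is a sketch under several assumptions: it is run from the repository root, `sample.jpg` is a hypothetical input image, the default config and weights referenced in `get_parser` are present, and a `masks/` directory exists for `save_masks` to write into:

```python
# Sketch: call predict() from app.py directly, bypassing the Gradio interface.
import cv2
from app import predict  # the predict() defined above

image = cv2.imread("sample.jpg")       # hypothetical input image (BGR, H x W x 3)
classes = "cat, grass, sky"            # comma-separated open-vocabulary class names
vis = predict(image, classes)          # BGR visualization; per-class masks go to masks/
cv2.imwrite("visualization.png", vis)
```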
assets/fig1.png ADDED
cat_seg/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from . import data  # register all new datasets
+ from . import modeling
+
+ # config
+ from .config import add_cat_seg_config
+
+ # dataset loading
+ from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper
+ from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import (
+     MaskFormerPanopticDatasetMapper,
+ )
+ from .data.dataset_mappers.mask_former_semantic_dataset_mapper import (
+     MaskFormerSemanticDatasetMapper,
+ )
+
+ # models
+ from .cat_seg_model import CATSeg
+ from .test_time_augmentation import SemanticSegmentorWithTTA
cat_seg/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (705 Bytes).
 
cat_seg/__pycache__/cat_seg_model.cpython-38.pyc ADDED
Binary file (7.49 kB).
 
cat_seg/__pycache__/config.cpython-38.pyc ADDED
Binary file (2.4 kB).
 
cat_seg/__pycache__/test_time_augmentation.cpython-38.pyc ADDED
Binary file (4.42 kB).
 
cat_seg/cat_seg_model.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from detectron2.config import configurable
9
+ from detectron2.data import MetadataCatalog
10
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
11
+ from detectron2.modeling.backbone import Backbone
12
+ from detectron2.modeling.postprocessing import sem_seg_postprocess
13
+ from detectron2.structures import ImageList
14
+ from detectron2.utils.memory import _ignore_torch_cuda_oom
15
+
16
+ from einops import rearrange
17
+
18
+ @META_ARCH_REGISTRY.register()
19
+ class CATSeg(nn.Module):
20
+ @configurable
21
+ def __init__(
22
+ self,
23
+ *,
24
+ backbone: Backbone,
25
+ sem_seg_head: nn.Module,
26
+ size_divisibility: int,
27
+ pixel_mean: Tuple[float],
28
+ pixel_std: Tuple[float],
29
+ clip_pixel_mean: Tuple[float],
30
+ clip_pixel_std: Tuple[float],
31
+ train_class_json: str,
32
+ test_class_json: str,
33
+ sliding_window: bool,
34
+ clip_finetune: str,
35
+ backbone_multiplier: float,
36
+ clip_pretrained: str,
37
+ ):
38
+ """
39
+ Args:
40
+ backbone: a backbone module, must follow detectron2's backbone interface
41
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
42
+ """
43
+ super().__init__()
44
+ self.backbone = backbone
45
+ self.sem_seg_head = sem_seg_head
46
+ if size_divisibility < 0:
47
+ size_divisibility = self.backbone.size_divisibility
48
+ self.size_divisibility = size_divisibility
49
+
50
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
51
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
52
+ self.register_buffer("clip_pixel_mean", torch.Tensor(clip_pixel_mean).view(-1, 1, 1), False)
53
+ self.register_buffer("clip_pixel_std", torch.Tensor(clip_pixel_std).view(-1, 1, 1), False)
54
+
55
+ self.train_class_json = train_class_json
56
+ self.test_class_json = test_class_json
57
+
58
+ self.clip_finetune = clip_finetune
59
+ for name, params in self.sem_seg_head.predictor.clip_model.named_parameters():
60
+ if "visual" in name:
61
+ if clip_finetune == "prompt":
62
+ params.requires_grad = True if "prompt" in name else False
63
+ elif clip_finetune == "attention":
64
+ params.requires_grad = True if "attn" in name or "position" in name else False
65
+ elif clip_finetune == "full":
66
+ params.requires_grad = True
67
+ else:
68
+ params.requires_grad = False
69
+ else:
70
+ params.requires_grad = False
71
+
72
+ finetune_backbone = backbone_multiplier > 0.
73
+ for name, params in self.backbone.named_parameters():
74
+ if "norm0" in name:
75
+ params.requires_grad = False
76
+ else:
77
+ params.requires_grad = finetune_backbone
78
+
79
+ self.sliding_window = sliding_window
80
+ self.clip_resolution = (384, 384) if clip_pretrained == "ViT-B/16" else (336, 336)
81
+ self.sequential = False
82
+
83
+ @classmethod
84
+ def from_config(cls, cfg):
85
+ backbone = build_backbone(cfg)
86
+ sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
87
+
88
+ return {
89
+ "backbone": backbone,
90
+ "sem_seg_head": sem_seg_head,
91
+ "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
92
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
93
+ "pixel_std": cfg.MODEL.PIXEL_STD,
94
+ "clip_pixel_mean": cfg.MODEL.CLIP_PIXEL_MEAN,
95
+ "clip_pixel_std": cfg.MODEL.CLIP_PIXEL_STD,
96
+ "train_class_json": cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON,
97
+ "test_class_json": cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON,
98
+ "sliding_window": cfg.TEST.SLIDING_WINDOW,
99
+ "clip_finetune": cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE,
100
+ "backbone_multiplier": cfg.SOLVER.BACKBONE_MULTIPLIER,
101
+ "clip_pretrained": cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED,
102
+ }
103
+
104
+ @property
105
+ def device(self):
106
+ return self.pixel_mean.device
107
+
108
+ def forward(self, batched_inputs):
109
+ """
110
+ Args:
111
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
112
+ Each item in the list contains the inputs for one image.
113
+ For now, each item in the list is a dict that contains:
114
+ * "image": Tensor, image in (C, H, W) format.
115
+ * "instances": per-region ground truth
116
+ * Other information that's included in the original dicts, such as:
117
+ "height", "width" (int): the output resolution of the model (may be different
118
+ from input resolution), used in inference.
119
+ Returns:
120
+ list[dict]:
121
+ each dict has the results for one image. The dict contains the following keys:
122
+
123
+ * "sem_seg":
124
+ A Tensor that represents the
125
+ per-pixel segmentation prediced by the head.
126
+ The prediction has shape KxHxW that represents the logits of
127
+ each class for each pixel.
128
+ """
129
+ images = [x["image"].to(self.device) for x in batched_inputs]
130
+ if not self.training and self.sliding_window:
131
+ if not self.sequential:
132
+ with _ignore_torch_cuda_oom():
133
+ return self.inference_sliding_window(batched_inputs)
134
+ self.sequential = True
135
+ return self.inference_sliding_window(batched_inputs)
136
+
137
+ clip_images = [(x - self.clip_pixel_mean) / self.clip_pixel_std for x in images]
138
+ clip_images = ImageList.from_tensors(clip_images, self.size_divisibility)
139
+
140
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
141
+ images = ImageList.from_tensors(images, self.size_divisibility)
142
+
143
+ clip_images = F.interpolate(clip_images.tensor, size=self.clip_resolution, mode='bilinear', align_corners=False, )
144
+ clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)
145
+
146
+ images_resized = F.interpolate(images.tensor, size=(384, 384), mode='bilinear', align_corners=False,)
147
+ features = self.backbone(images_resized)
148
+
149
+ outputs = self.sem_seg_head(clip_features, features)
150
+ if self.training:
151
+ targets = torch.stack([x["sem_seg"].to(self.device) for x in batched_inputs], dim=0)
152
+ outputs = F.interpolate(outputs, size=(targets.shape[-2], targets.shape[-1]), mode="bilinear", align_corners=False)
153
+
154
+ num_classes = outputs.shape[1]
155
+ mask = targets != self.sem_seg_head.ignore_value
156
+
157
+ outputs = outputs.permute(0,2,3,1)
158
+ _targets = torch.zeros(outputs.shape, device=self.device)
159
+ _onehot = F.one_hot(targets[mask], num_classes=num_classes).float()
160
+ _targets[mask] = _onehot
161
+
162
+ loss = F.binary_cross_entropy_with_logits(outputs, _targets)
163
+ losses = {"loss_sem_seg" : loss}
164
+ return losses
165
+ else:
166
+ outputs = outputs.sigmoid()
167
+ image_size = images.image_sizes[0]
168
+ height = batched_inputs[0].get("height", image_size[0])
169
+ width = batched_inputs[0].get("width", image_size[1])
170
+
171
+ output = sem_seg_postprocess(outputs[0], image_size, height, width)
172
+ processed_results = [{'sem_seg': output}]
173
+ return processed_results
174
+
175
+
176
+ @torch.no_grad()
177
+ def inference_sliding_window(self, batched_inputs, kernel=384, overlap=0.333, out_res=[640, 640]):
178
+ images = [x["image"].to(self.device, dtype=torch.float32) for x in batched_inputs]
179
+ stride = int(kernel * (1 - overlap))
180
+ unfold = nn.Unfold(kernel_size=kernel, stride=stride)
181
+ fold = nn.Fold(out_res, kernel_size=kernel, stride=stride)
182
+
183
+ image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode='bilinear', align_corners=False).squeeze()
184
+ image = rearrange(unfold(image), "(C H W) L-> L C H W", C=3, H=kernel)
185
+ global_image = F.interpolate(images[0].unsqueeze(0), size=(kernel, kernel), mode='bilinear', align_corners=False)
186
+ image = torch.cat((image, global_image), dim=0)
187
+
188
+ images = (image - self.pixel_mean) / self.pixel_std
189
+ clip_images = (image - self.clip_pixel_mean) / self.clip_pixel_std
190
+ clip_images = F.interpolate(clip_images, size=self.clip_resolution, mode='bilinear', align_corners=False, )
191
+ clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)
192
+
193
+ if self.sequential:
194
+ outputs = []
195
+ for clip_feat, image in zip(clip_features, images):
196
+ feature = self.backbone(image.unsqueeze(0))
197
+ output = self.sem_seg_head(clip_feat.unsqueeze(0), feature)
198
+ outputs.append(output[0])
199
+ outputs = torch.stack(outputs, dim=0)
200
+ else:
201
+ features = self.backbone(images)
202
+ outputs = self.sem_seg_head(clip_features, features)
203
+
204
+ outputs = F.interpolate(outputs, size=kernel, mode="bilinear", align_corners=False)
205
+ outputs = outputs.sigmoid()
206
+
207
+ global_output = outputs[-1:]
208
+ global_output = F.interpolate(global_output, size=out_res, mode='bilinear', align_corners=False,)
209
+ outputs = outputs[:-1]
210
+ outputs = fold(outputs.flatten(1).T) / fold(unfold(torch.ones([1] + out_res, device=self.device)))
211
+ outputs = (outputs + global_output) / 2.
212
+
213
+ height = batched_inputs[0].get("height", out_res[0])
214
+ width = batched_inputs[0].get("width", out_res[1])
215
+ output = sem_seg_postprocess(outputs, out_res, height, width)
216
+ return [{'sem_seg': output}]
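`inference_sliding_window` above stitches overlapping 384×384 crops back onto a 640×640 canvas with an `nn.Unfold`/`nn.Fold` pair, and divides by `fold(unfold(ones))` so that pixels covered by several crops are averaged rather than summed. A small self-contained sketch of that normalization, using the same default kernel/stride arithmetic (overlap 0.333 gives stride 256, i.e. a 2×2 grid of crops):

```python
# Sketch of the overlap-averaging trick used in inference_sliding_window:
# fold(unfold(x)) sums overlapping patches, so dividing by fold(unfold(ones))
# (the per-pixel patch count) turns the sum back into an average.
import torch
from torch import nn

kernel, overlap, out_res = 384, 0.333, [640, 640]
stride = int(kernel * (1 - overlap))                 # 256 -> 2 x 2 = 4 crops

unfold = nn.Unfold(kernel_size=kernel, stride=stride)
fold = nn.Fold(out_res, kernel_size=kernel, stride=stride)

x = torch.rand(1, 1, *out_res)                       # single-channel map for illustration
summed = fold(unfold(x))                             # overlapping regions are summed
counts = fold(unfold(torch.ones(1, 1, *out_res)))    # how many crops cover each pixel
print(torch.allclose(summed / counts, x))            # True: dividing restores the input
```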
cat_seg/config.py ADDED
@@ -0,0 +1,93 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from detectron2.config import CfgNode as CN
+
+
+ def add_cat_seg_config(cfg):
+     """
+     Add config options for CAT-Seg.
+     """
+     # data config
+     # select the dataset mapper
+     cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
+
+     cfg.DATASETS.VAL_ALL = ("coco_2017_val_all_stuff_sem_seg",)
+
+     # Color augmentation
+     cfg.INPUT.COLOR_AUG_SSD = False
+     # We retry random cropping until no single category in semantic segmentation GT occupies more
+     # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
+     cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
+     # Pad image and segmentation GT in dataset mapper.
+     cfg.INPUT.SIZE_DIVISIBILITY = -1
+
+     # solver config
+     # weight decay on embedding
+     cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
+     # optimizer
+     cfg.SOLVER.OPTIMIZER = "ADAMW"
+     cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
+
+     # mask_former model config
+     cfg.MODEL.MASK_FORMER = CN()
+
+     # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
+     # you can use this config to override
+     cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
+
+     # swin transformer backbone
+     cfg.MODEL.SWIN = CN()
+     cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
+     cfg.MODEL.SWIN.PATCH_SIZE = 4
+     cfg.MODEL.SWIN.EMBED_DIM = 96
+     cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+     cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+     cfg.MODEL.SWIN.WINDOW_SIZE = 7
+     cfg.MODEL.SWIN.MLP_RATIO = 4.0
+     cfg.MODEL.SWIN.QKV_BIAS = True
+     cfg.MODEL.SWIN.QK_SCALE = None
+     cfg.MODEL.SWIN.DROP_RATE = 0.0
+     cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
+     cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
+     cfg.MODEL.SWIN.APE = False
+     cfg.MODEL.SWIN.PATCH_NORM = True
+     cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
+
+     # zero shot config
+     cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON = "datasets/ADE20K_2021_17_01/ADE20K_847.json"
+     cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON = "datasets/ADE20K_2021_17_01/ADE20K_847.json"
+     cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_INDEXES = "datasets/coco/coco_stuff/split/seen_indexes.json"
+     cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_INDEXES = "datasets/coco/coco_stuff/split/unseen_indexes.json"
+
+     cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED = "ViT-B/16"
+
+     cfg.MODEL.PROMPT_ENSEMBLE = False
+     cfg.MODEL.PROMPT_ENSEMBLE_TYPE = "single"
+
+     cfg.MODEL.CLIP_PIXEL_MEAN = [122.7709383, 116.7460125, 104.09373615]
+     cfg.MODEL.CLIP_PIXEL_STD = [68.5005327, 66.6321579, 70.3231630]
+     # three styles for clip classification, crop, mask, cropmask
+
+     cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_DIM = 512
+     cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_PROJ_DIM = 128
+     cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_DIM = 512
+     cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_PROJ_DIM = 128
+
+     cfg.MODEL.SEM_SEG_HEAD.DECODER_DIMS = [64, 32]
+     cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_DIMS = [256, 128]
+     cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_PROJ_DIMS = [32, 16]
+
+     cfg.MODEL.SEM_SEG_HEAD.NUM_LAYERS = 4
+     cfg.MODEL.SEM_SEG_HEAD.NUM_HEADS = 4
+     cfg.MODEL.SEM_SEG_HEAD.HIDDEN_DIMS = 128
+     cfg.MODEL.SEM_SEG_HEAD.POOLING_SIZES = [6, 6]
+     cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION = [24, 24]
+     cfg.MODEL.SEM_SEG_HEAD.WINDOW_SIZES = 12
+     cfg.MODEL.SEM_SEG_HEAD.ATTENTION_TYPE = "linear"
+
+     cfg.MODEL.SEM_SEG_HEAD.PROMPT_DEPTH = 0
+     cfg.MODEL.SEM_SEG_HEAD.PROMPT_LENGTH = 0
+     cfg.SOLVER.CLIP_MULTIPLIER = 0.01
+
+     cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE = "attention"
+     cfg.TEST.SLIDING_WINDOW = False
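The keys above only exist on a config object after `add_cat_seg_config` has been called, so it must run before `merge_from_file`, exactly as `setup_cfg` does in `app.py` earlier in this commit. A minimal sketch, assuming a project config such as `configs/vitl_swinb_384.yaml` (the default path used by `app.py`) is present:

```python
# Sketch: extend detectron2's default config with the CAT-Seg keys defined above,
# then load a project config file. Mirrors setup_cfg() in app.py.
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from cat_seg import add_cat_seg_config

cfg = get_cfg()
add_deeplab_config(cfg)
add_cat_seg_config(cfg)              # adds MODEL.SWIN, MODEL.SEM_SEG_HEAD.*, TEST.SLIDING_WINDOW, ...
cfg.merge_from_file("configs/vitl_swinb_384.yaml")   # assumed config path
cfg.merge_from_list(["TEST.SLIDING_WINDOW", "True"])
cfg.freeze()
print(cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED)
```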
cat_seg/data/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from . import datasets
cat_seg/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (196 Bytes).
 
cat_seg/data/dataset_mappers/__init__.py ADDED
@@ -0,0 +1 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
cat_seg/data/dataset_mappers/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (179 Bytes).
 
cat_seg/data/dataset_mappers/__pycache__/detr_panoptic_dataset_mapper.cpython-38.pyc ADDED
Binary file (4.9 kB).
 
cat_seg/data/dataset_mappers/__pycache__/mask_former_panoptic_dataset_mapper.cpython-38.pyc ADDED
Binary file (4.43 kB).
 
cat_seg/data/dataset_mappers/__pycache__/mask_former_semantic_dataset_mapper.cpython-38.pyc ADDED
Binary file (5.07 kB).
 
cat_seg/data/dataset_mappers/detr_panoptic_dataset_mapper.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py
3
+ import copy
4
+ import logging
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.data import detection_utils as utils
11
+ from detectron2.data import transforms as T
12
+ from detectron2.data.transforms import TransformGen
13
+ from detectron2.structures import BitMasks, Instances
14
+
15
+ __all__ = ["DETRPanopticDatasetMapper"]
16
+
17
+
18
+ def build_transform_gen(cfg, is_train):
19
+ """
20
+ Create a list of :class:`TransformGen` from config.
21
+ Returns:
22
+ list[TransformGen]
23
+ """
24
+ if is_train:
25
+ min_size = cfg.INPUT.MIN_SIZE_TRAIN
26
+ max_size = cfg.INPUT.MAX_SIZE_TRAIN
27
+ sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
28
+ else:
29
+ min_size = cfg.INPUT.MIN_SIZE_TEST
30
+ max_size = cfg.INPUT.MAX_SIZE_TEST
31
+ sample_style = "choice"
32
+ if sample_style == "range":
33
+ assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(
34
+ len(min_size)
35
+ )
36
+
37
+ logger = logging.getLogger(__name__)
38
+ tfm_gens = []
39
+ if is_train:
40
+ tfm_gens.append(T.RandomFlip())
41
+ tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
42
+ if is_train:
43
+ logger.info("TransformGens used in training: " + str(tfm_gens))
44
+ return tfm_gens
45
+
46
+
47
+ # This is specifically designed for the COCO dataset.
48
+ class DETRPanopticDatasetMapper:
49
+ """
50
+ A callable which takes a dataset dict in Detectron2 Dataset format,
51
+ and map it into a format used by MaskFormer.
52
+
53
+ This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
54
+
55
+ The callable currently does the following:
56
+
57
+ 1. Read the image from "file_name"
58
+ 2. Applies geometric transforms to the image and annotation
59
+ 3. Find and applies suitable cropping to the image and annotation
60
+ 4. Prepare image and annotation to Tensors
61
+ """
62
+
63
+ @configurable
64
+ def __init__(
65
+ self,
66
+ is_train=True,
67
+ *,
68
+ crop_gen,
69
+ tfm_gens,
70
+ image_format,
71
+ ):
72
+ """
73
+ NOTE: this interface is experimental.
74
+ Args:
75
+ is_train: for training or inference
76
+ augmentations: a list of augmentations or deterministic transforms to apply
77
+ crop_gen: crop augmentation
78
+ tfm_gens: data augmentation
79
+ image_format: an image format supported by :func:`detection_utils.read_image`.
80
+ """
81
+ self.crop_gen = crop_gen
82
+ self.tfm_gens = tfm_gens
83
+ logging.getLogger(__name__).info(
84
+ "[DETRPanopticDatasetMapper] Full TransformGens used in training: {}, crop: {}".format(
85
+ str(self.tfm_gens), str(self.crop_gen)
86
+ )
87
+ )
88
+
89
+ self.img_format = image_format
90
+ self.is_train = is_train
91
+
92
+ @classmethod
93
+ def from_config(cls, cfg, is_train=True):
94
+ # Build augmentation
95
+ if cfg.INPUT.CROP.ENABLED and is_train:
96
+ crop_gen = [
97
+ T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
98
+ T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
99
+ ]
100
+ else:
101
+ crop_gen = None
102
+
103
+ tfm_gens = build_transform_gen(cfg, is_train)
104
+
105
+ ret = {
106
+ "is_train": is_train,
107
+ "crop_gen": crop_gen,
108
+ "tfm_gens": tfm_gens,
109
+ "image_format": cfg.INPUT.FORMAT,
110
+ }
111
+ return ret
112
+
113
+ def __call__(self, dataset_dict):
114
+ """
115
+ Args:
116
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
117
+
118
+ Returns:
119
+ dict: a format that builtin models in detectron2 accept
120
+ """
121
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
122
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
123
+ utils.check_image_size(dataset_dict, image)
124
+
125
+ if self.crop_gen is None:
126
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
127
+ else:
128
+ if np.random.rand() > 0.5:
129
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
130
+ else:
131
+ image, transforms = T.apply_transform_gens(
132
+ self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
133
+ )
134
+
135
+ image_shape = image.shape[:2] # h, w
136
+
137
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
138
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
139
+ # Therefore it's important to use torch.Tensor.
140
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
141
+
142
+ if not self.is_train:
143
+ # USER: Modify this if you want to keep them for some reason.
144
+ dataset_dict.pop("annotations", None)
145
+ return dataset_dict
146
+
147
+ if "pan_seg_file_name" in dataset_dict:
148
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
149
+ segments_info = dataset_dict["segments_info"]
150
+
151
+ # apply the same transformation to panoptic segmentation
152
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
153
+
154
+ from panopticapi.utils import rgb2id
155
+
156
+ pan_seg_gt = rgb2id(pan_seg_gt)
157
+
158
+ instances = Instances(image_shape)
159
+ classes = []
160
+ masks = []
161
+ for segment_info in segments_info:
162
+ class_id = segment_info["category_id"]
163
+ if not segment_info["iscrowd"]:
164
+ classes.append(class_id)
165
+ masks.append(pan_seg_gt == segment_info["id"])
166
+
167
+ classes = np.array(classes)
168
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
169
+ if len(masks) == 0:
170
+ # Some image does not have annotation (all ignored)
171
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
172
+ else:
173
+ masks = BitMasks(
174
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
175
+ )
176
+ instances.gt_masks = masks.tensor
177
+
178
+ dataset_dict["instances"] = instances
179
+
180
+ return dataset_dict
cat_seg/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.data import detection_utils as utils
11
+ from detectron2.data import transforms as T
12
+ from detectron2.structures import BitMasks, Instances
13
+
14
+ from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper
15
+
16
+ __all__ = ["MaskFormerPanopticDatasetMapper"]
17
+
18
+
19
+ class MaskFormerPanopticDatasetMapper(MaskFormerSemanticDatasetMapper):
20
+ """
21
+ A callable which takes a dataset dict in Detectron2 Dataset format,
22
+ and map it into a format used by MaskFormer for panoptic segmentation.
23
+
24
+ The callable currently does the following:
25
+
26
+ 1. Read the image from "file_name"
27
+ 2. Applies geometric transforms to the image and annotation
28
+ 3. Find and applies suitable cropping to the image and annotation
29
+ 4. Prepare image and annotation to Tensors
30
+ """
31
+
32
+ @configurable
33
+ def __init__(
34
+ self,
35
+ is_train=True,
36
+ *,
37
+ augmentations,
38
+ image_format,
39
+ ignore_label,
40
+ size_divisibility,
41
+ ):
42
+ """
43
+ NOTE: this interface is experimental.
44
+ Args:
45
+ is_train: for training or inference
46
+ augmentations: a list of augmentations or deterministic transforms to apply
47
+ image_format: an image format supported by :func:`detection_utils.read_image`.
48
+ ignore_label: the label that is ignored to evaluation
49
+ size_divisibility: pad image size to be divisible by this value
50
+ """
51
+ super().__init__(
52
+ is_train,
53
+ augmentations=augmentations,
54
+ image_format=image_format,
55
+ ignore_label=ignore_label,
56
+ size_divisibility=size_divisibility,
57
+ )
58
+
59
+ def __call__(self, dataset_dict):
60
+ """
61
+ Args:
62
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
63
+
64
+ Returns:
65
+ dict: a format that builtin models in detectron2 accept
66
+ """
67
+ assert self.is_train, "MaskFormerPanopticDatasetMapper should only be used for training!"
68
+
69
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
70
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
71
+ utils.check_image_size(dataset_dict, image)
72
+
73
+ # semantic segmentation
74
+ if "sem_seg_file_name" in dataset_dict:
75
+ # PyTorch transformation not implemented for uint16, so converting it to double first
76
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
77
+ else:
78
+ sem_seg_gt = None
79
+
80
+ # panoptic segmentation
81
+ if "pan_seg_file_name" in dataset_dict:
82
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
83
+ segments_info = dataset_dict["segments_info"]
84
+ else:
85
+ pan_seg_gt = None
86
+ segments_info = None
87
+
88
+ if pan_seg_gt is None:
89
+ raise ValueError(
90
+ "Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format(
91
+ dataset_dict["file_name"]
92
+ )
93
+ )
94
+
95
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
96
+ aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
97
+ image = aug_input.image
98
+ if sem_seg_gt is not None:
99
+ sem_seg_gt = aug_input.sem_seg
100
+
101
+ # apply the same transformation to panoptic segmentation
102
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
103
+
104
+ from panopticapi.utils import rgb2id
105
+
106
+ pan_seg_gt = rgb2id(pan_seg_gt)
107
+
108
+ # Pad image and segmentation label here!
109
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
110
+ if sem_seg_gt is not None:
111
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
112
+ pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long"))
113
+
114
+ if self.size_divisibility > 0:
115
+ image_size = (image.shape[-2], image.shape[-1])
116
+ padding_size = [
117
+ 0,
118
+ self.size_divisibility - image_size[1],
119
+ 0,
120
+ self.size_divisibility - image_size[0],
121
+ ]
122
+ image = F.pad(image, padding_size, value=128).contiguous()
123
+ if sem_seg_gt is not None:
124
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
125
+ pan_seg_gt = F.pad(
126
+ pan_seg_gt, padding_size, value=0
127
+ ).contiguous() # 0 is the VOID panoptic label
128
+
129
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
130
+
131
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
132
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
133
+ # Therefore it's important to use torch.Tensor.
134
+ dataset_dict["image"] = image
135
+ if sem_seg_gt is not None:
136
+ dataset_dict["sem_seg"] = sem_seg_gt.long()
137
+
138
+ if "annotations" in dataset_dict:
139
+ raise ValueError("Pemantic segmentation dataset should not have 'annotations'.")
140
+
141
+ # Prepare per-category binary masks
142
+ pan_seg_gt = pan_seg_gt.numpy()
143
+ instances = Instances(image_shape)
144
+ classes = []
145
+ masks = []
146
+ for segment_info in segments_info:
147
+ class_id = segment_info["category_id"]
148
+ if not segment_info["iscrowd"]:
149
+ classes.append(class_id)
150
+ masks.append(pan_seg_gt == segment_info["id"])
151
+
152
+ classes = np.array(classes)
153
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
154
+ if len(masks) == 0:
155
+ # Some image does not have annotation (all ignored)
156
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
157
+ else:
158
+ masks = BitMasks(
159
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
160
+ )
161
+ instances.gt_masks = masks.tensor
162
+
163
+ dataset_dict["instances"] = instances
164
+
165
+ return dataset_dict
cat_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py ADDED
@@ -0,0 +1,186 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.data import MetadataCatalog
11
+ from detectron2.data import detection_utils as utils
12
+ from detectron2.data import transforms as T
13
+ from detectron2.projects.point_rend import ColorAugSSDTransform
14
+ from detectron2.structures import BitMasks, Instances
15
+
16
+ __all__ = ["MaskFormerSemanticDatasetMapper"]
17
+
18
+
19
+ class MaskFormerSemanticDatasetMapper:
20
+ """
21
+ A callable which takes a dataset dict in Detectron2 Dataset format,
22
+ and map it into a format used by MaskFormer for semantic segmentation.
23
+
24
+ The callable currently does the following:
25
+
26
+ 1. Read the image from "file_name"
27
+ 2. Applies geometric transforms to the image and annotation
28
+ 3. Find and applies suitable cropping to the image and annotation
29
+ 4. Prepare image and annotation to Tensors
30
+ """
31
+
32
+ @configurable
33
+ def __init__(
34
+ self,
35
+ is_train=True,
36
+ *,
37
+ augmentations,
38
+ image_format,
39
+ ignore_label,
40
+ size_divisibility,
41
+ ):
42
+ """
43
+ NOTE: this interface is experimental.
44
+ Args:
45
+ is_train: for training or inference
46
+ augmentations: a list of augmentations or deterministic transforms to apply
47
+ image_format: an image format supported by :func:`detection_utils.read_image`.
48
+ ignore_label: the label that is ignored to evaluation
49
+ size_divisibility: pad image size to be divisible by this value
50
+ """
51
+ self.is_train = is_train
52
+ self.tfm_gens = augmentations
53
+ self.img_format = image_format
54
+ self.ignore_label = ignore_label
55
+ self.size_divisibility = size_divisibility
56
+
57
+ logger = logging.getLogger(__name__)
58
+ mode = "training" if is_train else "inference"
59
+ logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}")
60
+
61
+ @classmethod
62
+ def from_config(cls, cfg, is_train=True):
63
+ # Build augmentation
64
+ augs = [
65
+ T.ResizeShortestEdge(
66
+ cfg.INPUT.MIN_SIZE_TRAIN,
67
+ cfg.INPUT.MAX_SIZE_TRAIN,
68
+ cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
69
+ )
70
+ ]
71
+ if cfg.INPUT.CROP.ENABLED:
72
+ augs.append(
73
+ T.RandomCrop_CategoryAreaConstraint(
74
+ cfg.INPUT.CROP.TYPE,
75
+ cfg.INPUT.CROP.SIZE,
76
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
77
+ cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
78
+ )
79
+ )
80
+ if cfg.INPUT.COLOR_AUG_SSD:
81
+ augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
82
+ augs.append(T.RandomFlip())
83
+
84
+ # Assume always applies to the training set.
85
+ dataset_names = cfg.DATASETS.TRAIN
86
+ meta = MetadataCatalog.get(dataset_names[0])
87
+ ignore_label = meta.ignore_label
88
+
89
+ ret = {
90
+ "is_train": is_train,
91
+ "augmentations": augs,
92
+ "image_format": cfg.INPUT.FORMAT,
93
+ "ignore_label": ignore_label,
94
+ "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY,
95
+ }
96
+ return ret
97
+
98
+ def __call__(self, dataset_dict):
99
+ """
100
+ Args:
101
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
102
+
103
+ Returns:
104
+ dict: a format that builtin models in detectron2 accept
105
+ """
106
+ assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!"
107
+
108
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
109
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
110
+ utils.check_image_size(dataset_dict, image)
111
+
112
+ if "sem_seg_file_name" in dataset_dict:
113
+ # PyTorch transformation not implemented for uint16, so converting it to double first
114
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
115
+ else:
116
+ sem_seg_gt = None
117
+
118
+ if sem_seg_gt is None:
119
+ raise ValueError(
120
+ "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
121
+ dataset_dict["file_name"]
122
+ )
123
+ )
124
+
125
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
126
+ aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
127
+ image = aug_input.image
128
+ sem_seg_gt = aug_input.sem_seg
129
+
130
+ # Pad image and segmentation label here!
131
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
132
+ if sem_seg_gt is not None:
133
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
134
+ # import ipdb; ipdb.set_trace()
135
+ if self.size_divisibility > 0:
136
+ image_size = (image.shape[-2], image.shape[-1])
137
+ # The ori_size is not the real original size, but size before padding
138
+ dataset_dict['ori_size'] = image_size
139
+ padding_size = [
140
+ 0,
141
+ self.size_divisibility - image_size[1], # w: (left, right)
142
+ 0,
143
+ self.size_divisibility - image_size[0], # h: 0,(top, bottom)
144
+ ]
145
+ image = F.pad(image, padding_size, value=128).contiguous()
146
+ if sem_seg_gt is not None:
147
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
148
+
149
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
150
+
151
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
152
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
153
+ # Therefore it's important to use torch.Tensor.
154
+ dataset_dict["image"] = image
155
+ # print('#########################################################################################')
156
+ if sem_seg_gt is not None:
157
+ dataset_dict["sem_seg"] = sem_seg_gt.long()
158
+
159
+ if "annotations" in dataset_dict:
160
+ raise ValueError("Semantic segmentation dataset should not have 'annotations'.")
161
+
162
+ # Prepare per-category binary masks
163
+ if sem_seg_gt is not None:
164
+ sem_seg_gt = sem_seg_gt.numpy()
165
+ instances = Instances(image_shape)
166
+ classes = np.unique(sem_seg_gt)
167
+ # remove ignored region
168
+ classes = classes[classes != self.ignore_label]
169
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
170
+
171
+ masks = []
172
+ for class_id in classes:
173
+ masks.append(sem_seg_gt == class_id)
174
+
175
+ if len(masks) == 0:
176
+ # Some image does not have annotation (all ignored)
177
+ instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
178
+ else:
179
+ masks = BitMasks(
180
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
181
+ )
182
+ instances.gt_masks = masks.tensor
183
+
184
+ dataset_dict["instances"] = instances
185
+
186
+ return dataset_dict
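For context, this mapper is not called directly; it is handed to detectron2's data loader, and because `__init__` is wrapped in `@configurable` it can be built straight from a config. A sketch under the assumption that a CAT-Seg config file is available and the datasets named in `DATASETS.TRAIN` have been registered and prepared:

```python
# Sketch: plug MaskFormerSemanticDatasetMapper into a detectron2 training loader.
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader
from detectron2.projects.deeplab import add_deeplab_config
from cat_seg import MaskFormerSemanticDatasetMapper, add_cat_seg_config

cfg = get_cfg()
add_deeplab_config(cfg)
add_cat_seg_config(cfg)
cfg.merge_from_file("configs/vitl_swinb_384.yaml")            # assumed config path

mapper = MaskFormerSemanticDatasetMapper(cfg, is_train=True)  # @configurable: built from cfg
loader = build_detection_train_loader(cfg, mapper=mapper)

batch = next(iter(loader))
print(batch[0]["image"].shape, batch[0]["sem_seg"].shape)     # padded (3, H, W) and (H, W)
```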
cat_seg/data/datasets/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from . import (
+     register_coco_stuff,
+     register_ade20k_150,
+     register_ade20k_847,
+     register_pascal_20,
+     register_pascal_59,
+ )
cat_seg/data/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (334 Bytes).
 
cat_seg/data/datasets/__pycache__/register_ade20k_150.cpython-38.pyc ADDED
Binary file (2.89 kB).
 
cat_seg/data/datasets/__pycache__/register_ade20k_847.cpython-38.pyc ADDED
Binary file (51.8 kB).
 
cat_seg/data/datasets/__pycache__/register_coco_stuff.cpython-38.pyc ADDED
Binary file (7.87 kB).
 
cat_seg/data/datasets/__pycache__/register_pascal_20.cpython-38.pyc ADDED
Binary file (2.48 kB).
 
cat_seg/data/datasets/__pycache__/register_pascal_59.cpython-38.pyc ADDED
Binary file (9.58 kB).
 
cat_seg/data/datasets/register_ade20k_150.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+
3
+ from detectron2.data import DatasetCatalog, MetadataCatalog
4
+ from detectron2.data.datasets import load_sem_seg
5
+ import copy
6
+
7
+ def _get_ade20k_150_meta():
8
+ ade20k_150_classes = ["wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed ", "windowpane", "grass", "cabinet", "sidewalk", "person", "earth", "door", "table", "mountain", "plant", "curtain", "chair", "car", "water", "painting", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp", "bathtub", "railing", "cushion", "base", "box", "column", "signboard", "chest of drawers", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator", "grandstand", "path", "stairs", "runway", "case", "pool table", "pillow", "screen door", "stairway", "river", "bridge", "bookcase", "blind", "coffee table", "toilet", "flower", "book", "hill", "bench", "countertop", "stove", "palm", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel", "bus", "towel", "light", "truck", "tower", "chandelier", "awning", "streetlight", "booth", "television receiver", "airplane", "dirt track", "apparel", "pole", "land", "bannister", "escalator", "ottoman", "bottle", "buffet", "poster", "stage", "van", "ship", "fountain", "conveyer belt", "canopy", "washer", "plaything", "swimming pool", "stool", "barrel", "basket", "waterfall", "tent", "bag", "minibike", "cradle", "oven", "ball", "food", "step", "tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket", "sculpture", "hood", "sconce", "vase", "traffic light", "tray", "ashcan", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass", "clock", "flag"]
9
+
10
+ ret = {
11
+ "stuff_classes" : ade20k_150_classes,
12
+ }
13
+ return ret
14
+
15
+ def register_ade20k_150(root):
16
+ root = os.path.join(root, "ADEChallengeData2016")
17
+ meta = _get_ade20k_150_meta()
18
+ for name, image_dirname, sem_seg_dirname in [
19
+ ("test", "images/validation", "annotations_detectron2/validation"),
20
+ ]:
21
+ image_dir = os.path.join(root, image_dirname)
22
+ gt_dir = os.path.join(root, sem_seg_dirname)
23
+ name = f"ade20k_150_{name}_sem_seg"
24
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
25
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
26
+
27
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
28
+ register_ade20k_150(_root)
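Because the call above runs at import time (keyed off the `DETECTRON2_DATASETS` environment variable, defaulting to `datasets/`), the split is visible through detectron2's catalogs as soon as `cat_seg` is imported. A small sketch, assuming `ADEChallengeData2016` has been prepared under that root:

```python
# Sketch: look up the ADE20K-150 split registered above.
from detectron2.data import DatasetCatalog, MetadataCatalog

import cat_seg  # importing the package triggers the dataset registration

dicts = DatasetCatalog.get("ade20k_150_test_sem_seg")  # [{"file_name", "sem_seg_file_name", ...}, ...]
meta = MetadataCatalog.get("ade20k_150_test_sem_seg")
print(len(dicts), len(meta.stuff_classes), meta.ignore_label)  # #images, 150, 255
```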
cat_seg/data/datasets/register_ade20k_847.py ADDED
The diff for this file is too large to render.
 
cat_seg/data/datasets/register_coco_stuff.py ADDED
@@ -0,0 +1,216 @@
1
+ import os
2
+
3
+ from detectron2.data import DatasetCatalog, MetadataCatalog
4
+ from detectron2.data.datasets import load_sem_seg
5
+
6
+ COCO_CATEGORIES = [
7
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
8
+ {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
9
+ {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
10
+ {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
11
+ {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
12
+ {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
13
+ {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
14
+ {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
15
+ {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
16
+ {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
17
+ {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
18
+ {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
19
+ {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
20
+ {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
21
+ {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
22
+ {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
23
+ {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
24
+ {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
25
+ {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
26
+ {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
27
+ {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
28
+ {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
29
+ {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
30
+ {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
31
+ {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
32
+ {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
33
+ {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
34
+ {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
35
+ {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
36
+ {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
37
+ {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
38
+ {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
39
+ {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
40
+ {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
41
+ {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
42
+ {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
43
+ {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
44
+ {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
45
+ {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
46
+ {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
47
+ {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
48
+ {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
49
+ {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
50
+ {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
51
+ {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
52
+ {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
53
+ {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
54
+ {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
55
+ {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
56
+ {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
57
+ {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
58
+ {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
59
+ {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
60
+ {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
61
+ {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
62
+ {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
63
+ {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
64
+ {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
65
+ {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
66
+ {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
67
+ {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
68
+ {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
69
+ {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
70
+ {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
71
+ {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
72
+ {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
73
+ {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
74
+ {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
75
+ {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
76
+ {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
77
+ {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
78
+ {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
79
+ {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
80
+ {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
81
+ {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
82
+ {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
83
+ {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
84
+ {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
85
+ {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
86
+ {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
87
+ {"id": 92, "name": "banner", "supercategory": "textile"},
88
+ {"id": 93, "name": "blanket", "supercategory": "textile"},
89
+ {"id": 94, "name": "branch", "supercategory": "plant"},
90
+ {"id": 95, "name": "bridge", "supercategory": "building"},
91
+ {"id": 96, "name": "building-other", "supercategory": "building"},
92
+ {"id": 97, "name": "bush", "supercategory": "plant"},
93
+ {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
94
+ {"id": 99, "name": "cage", "supercategory": "structural"},
95
+ {"id": 100, "name": "cardboard", "supercategory": "raw-material"},
96
+ {"id": 101, "name": "carpet", "supercategory": "floor"},
97
+ {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
98
+ {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
99
+ {"id": 104, "name": "cloth", "supercategory": "textile"},
100
+ {"id": 105, "name": "clothes", "supercategory": "textile"},
101
+ {"id": 106, "name": "clouds", "supercategory": "sky"},
102
+ {"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
103
+ {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
104
+ {"id": 109, "name": "curtain", "supercategory": "textile"},
105
+ {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
106
+ {"id": 111, "name": "dirt", "supercategory": "ground"},
107
+ {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
108
+ {"id": 113, "name": "fence", "supercategory": "structural"},
109
+ {"id": 114, "name": "floor-marble", "supercategory": "floor"},
110
+ {"id": 115, "name": "floor-other", "supercategory": "floor"},
111
+ {"id": 116, "name": "floor-stone", "supercategory": "floor"},
112
+ {"id": 117, "name": "floor-tile", "supercategory": "floor"},
113
+ {"id": 118, "name": "floor-wood", "supercategory": "floor"},
114
+ {"id": 119, "name": "flower", "supercategory": "plant"},
115
+ {"id": 120, "name": "fog", "supercategory": "water"},
116
+ {"id": 121, "name": "food-other", "supercategory": "food-stuff"},
117
+ {"id": 122, "name": "fruit", "supercategory": "food-stuff"},
118
+ {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
119
+ {"id": 124, "name": "grass", "supercategory": "plant"},
120
+ {"id": 125, "name": "gravel", "supercategory": "ground"},
121
+ {"id": 126, "name": "ground-other", "supercategory": "ground"},
122
+ {"id": 127, "name": "hill", "supercategory": "solid"},
123
+ {"id": 128, "name": "house", "supercategory": "building"},
124
+ {"id": 129, "name": "leaves", "supercategory": "plant"},
125
+ {"id": 130, "name": "light", "supercategory": "furniture-stuff"},
126
+ {"id": 131, "name": "mat", "supercategory": "textile"},
127
+ {"id": 132, "name": "metal", "supercategory": "raw-material"},
128
+ {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
129
+ {"id": 134, "name": "moss", "supercategory": "plant"},
130
+ {"id": 135, "name": "mountain", "supercategory": "solid"},
131
+ {"id": 136, "name": "mud", "supercategory": "ground"},
132
+ {"id": 137, "name": "napkin", "supercategory": "textile"},
133
+ {"id": 138, "name": "net", "supercategory": "structural"},
134
+ {"id": 139, "name": "paper", "supercategory": "raw-material"},
135
+ {"id": 140, "name": "pavement", "supercategory": "ground"},
136
+ {"id": 141, "name": "pillow", "supercategory": "textile"},
137
+ {"id": 142, "name": "plant-other", "supercategory": "plant"},
138
+ {"id": 143, "name": "plastic", "supercategory": "raw-material"},
139
+ {"id": 144, "name": "platform", "supercategory": "ground"},
140
+ {"id": 145, "name": "playingfield", "supercategory": "ground"},
141
+ {"id": 146, "name": "railing", "supercategory": "structural"},
142
+ {"id": 147, "name": "railroad", "supercategory": "ground"},
143
+ {"id": 148, "name": "river", "supercategory": "water"},
144
+ {"id": 149, "name": "road", "supercategory": "ground"},
145
+ {"id": 150, "name": "rock", "supercategory": "solid"},
146
+ {"id": 151, "name": "roof", "supercategory": "building"},
147
+ {"id": 152, "name": "rug", "supercategory": "textile"},
148
+ {"id": 153, "name": "salad", "supercategory": "food-stuff"},
149
+ {"id": 154, "name": "sand", "supercategory": "ground"},
150
+ {"id": 155, "name": "sea", "supercategory": "water"},
151
+ {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
152
+ {"id": 157, "name": "sky-other", "supercategory": "sky"},
153
+ {"id": 158, "name": "skyscraper", "supercategory": "building"},
154
+ {"id": 159, "name": "snow", "supercategory": "ground"},
155
+ {"id": 160, "name": "solid-other", "supercategory": "solid"},
156
+ {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
157
+ {"id": 162, "name": "stone", "supercategory": "solid"},
158
+ {"id": 163, "name": "straw", "supercategory": "plant"},
159
+ {"id": 164, "name": "structural-other", "supercategory": "structural"},
160
+ {"id": 165, "name": "table", "supercategory": "furniture-stuff"},
161
+ {"id": 166, "name": "tent", "supercategory": "building"},
162
+ {"id": 167, "name": "textile-other", "supercategory": "textile"},
163
+ {"id": 168, "name": "towel", "supercategory": "textile"},
164
+ {"id": 169, "name": "tree", "supercategory": "plant"},
165
+ {"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
166
+ {"id": 171, "name": "wall-brick", "supercategory": "wall"},
167
+ {"id": 172, "name": "wall-concrete", "supercategory": "wall"},
168
+ {"id": 173, "name": "wall-other", "supercategory": "wall"},
169
+ {"id": 174, "name": "wall-panel", "supercategory": "wall"},
170
+ {"id": 175, "name": "wall-stone", "supercategory": "wall"},
171
+ {"id": 176, "name": "wall-tile", "supercategory": "wall"},
172
+ {"id": 177, "name": "wall-wood", "supercategory": "wall"},
173
+ {"id": 178, "name": "water-other", "supercategory": "water"},
174
+ {"id": 179, "name": "waterdrops", "supercategory": "water"},
175
+ {"id": 180, "name": "window-blind", "supercategory": "window"},
176
+ {"id": 181, "name": "window-other", "supercategory": "window"},
177
+ {"id": 182, "name": "wood", "supercategory": "solid"},
178
+ ]
179
+
180
+
181
+ def _get_coco_stuff_meta():
182
+ stuff_ids = [k["id"] for k in COCO_CATEGORIES]
183
+ assert len(stuff_ids) == 171, len(stuff_ids)
184
+
185
+ stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
186
+ stuff_classes = [k["name"] for k in COCO_CATEGORIES]
187
+
188
+ ret = {
189
+ "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
190
+ "stuff_classes": stuff_classes,
191
+ }
192
+ return ret
193
+
194
+ def register_all_coco_stuff_10k(root):
195
+ root = os.path.join(root, "coco-stuff")
196
+ meta = _get_coco_stuff_meta()
197
+ for name, image_dirname, sem_seg_dirname in [
198
+ ("train", "images/train2017", "annotations_detectron2/train2017"),
199
+ ("test", "images/val2017", "annotations_detectron2/val2017"),
200
+ ]:
201
+ image_dir = os.path.join(root, image_dirname)
202
+ gt_dir = os.path.join(root, sem_seg_dirname)
203
+ name = f"coco_2017_{name}_stuff_all_sem_seg"
204
+ DatasetCatalog.register(
205
+ name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg")
206
+ )
207
+ MetadataCatalog.get(name).set(
208
+ image_root=image_dir,
209
+ sem_seg_root=gt_dir,
210
+ evaluator_type="sem_seg",
211
+ ignore_label=255,
212
+ **meta,
213
+ )
214
+
215
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
216
+ register_all_coco_stuff_10k(_root)
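The registration above only records metadata and a loader callback; nothing is read from disk until the catalog entry is queried. As a rough usage sketch (not part of this commit, and assuming detectron2 is installed and DETECTRON2_DATASETS points at a directory containing coco-stuff/):

# Hedged sketch: fetch the split registered above and inspect its metadata.
from detectron2.data import DatasetCatalog, MetadataCatalog

dataset_dicts = DatasetCatalog.get("coco_2017_test_stuff_all_sem_seg")  # list of {"file_name", "sem_seg_file_name", ...}
metadata = MetadataCatalog.get("coco_2017_test_stuff_all_sem_seg")

print(len(dataset_dicts), "images")
print(len(metadata.stuff_classes), "classes")  # 171 COCO-Stuff classes
print(metadata.ignore_label)                   # 255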
cat_seg/data/datasets/register_pascal_20.py ADDED
@@ -0,0 +1,53 @@
1
+ import os
2
+
3
+ from detectron2.data import DatasetCatalog, MetadataCatalog
4
+ from detectron2.data.datasets import load_sem_seg
5
+ import copy
6
+
7
+ def _get_pascal_voc_meta():
8
+ voc_classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
9
+ voc_colors = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
10
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
11
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
12
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
13
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
14
+ ret = {
15
+ "stuff_classes" : voc_classes,
16
+ "stuff_colors" : voc_colors,
17
+ }
18
+ return ret
19
+
20
+ def register_all_pascal_voc(root):
21
+ root = os.path.join(root, "VOCdevkit/VOC2012")
22
+ meta = _get_pascal_voc_meta()
23
+ for name, image_dirname, sem_seg_dirname in [
24
+ ("test", "JPEGImages", "annotations_detectron2"),
25
+ ("test_background", "JPEGImages", "annotations_detectron2_bg"),
26
+ ]:
27
+ image_dir = os.path.join(root, image_dirname)
28
+ gt_dir = os.path.join(root, sem_seg_dirname, 'val')
29
+ name = f"voc_2012_{name}_sem_seg"
30
+
31
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
32
+ if "background" in name:
33
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg_background", ignore_label=255,
34
+ stuff_classes=meta["stuff_classes"] + ["background"], stuff_colors=meta["stuff_colors"])
35
+ else:
36
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
37
+
38
+ def register_all_pascal_voc_background(root):
39
+ root = os.path.join(root, "VOCdevkit/VOC2012")
40
+ meta = _get_pascal_voc_meta()
41
+ meta["stuff_classes"] = meta["stuff_classes"] + ["background"]
42
+ for name, image_dirname, sem_seg_dirname in [
43
+ ("test_background", "image", "label_openseg_background20"),
44
+ ]:
45
+ image_dir = os.path.join(root, image_dirname, 'validation')
46
+ gt_dir = os.path.join(root, sem_seg_dirname, 'validation')
47
+ name = f"voc_2012_{name}_sem_seg"
48
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
49
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg_background", ignore_label=255, **meta,)
50
+
51
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
52
+ register_all_pascal_voc(_root)
53
+ #register_all_pascal_voc_background(_root)
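A side note on the `lambda x=image_dir, y=gt_dir: ...` pattern used in these registration loops (here and in the other register_* files): the default arguments are what bind the current loop values into each callback. A small stand-alone illustration (not part of the commit) of why that matters:

# Hedged sketch: default arguments capture per-iteration values; a bare closure would not.
makers_late  = [lambda: split for split in ("train", "val")]
makers_bound = [lambda s=split: s for split in ("train", "val")]

print([f() for f in makers_late])   # ['val', 'val']   -- every lambda sees the final value
print([f() for f in makers_bound])  # ['train', 'val'] -- values frozen at definition time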
cat_seg/data/datasets/register_pascal_59.py ADDED
@@ -0,0 +1,81 @@
1
+ import os
2
+
3
+ from detectron2.data import DatasetCatalog, MetadataCatalog
4
+ from detectron2.data.datasets import load_sem_seg
5
+ import copy
6
+
7
+
8
+ stuff_colors = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
9
+ [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
10
+ [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
11
+ [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],
12
+ [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],
13
+ [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],
14
+ [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],
15
+ [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
16
+ [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],
17
+ [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],
18
+ [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
19
+ [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
20
+ [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
21
+ [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],
22
+ [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
23
+ [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
24
+ [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
25
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
26
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
27
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
28
+ [0, 0, 230], [119, 11, 32],
29
+ [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],
30
+ [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
31
+ [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
32
+ [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
33
+ [64, 192, 96], [64, 160, 64], [64, 64, 0]]
34
+
35
+ def _get_pascal_context_59_meta():
36
+ #context_classes = ["aeroplane", "bag", "bed", "bedclothes", "bench", "bicycle", "bird", "boat", "book", "bottle", "building", "bus", "cabinet", "car", "cat", "ceiling", "chair", "cloth", "computer", "cow", "cup", "curtain", "dog", "door", "fence", "floor", "flower", "food", "grass", "ground", "horse", "keyboard", "light", "motorbike", "mountain", "mouse", "person", "plate", "platform", "pottedplant", "road", "rock", "sheep", "shelves", "sidewalk", "sign", "sky", "snow", "sofa", "diningtable", "track", "train", "tree", "truck", "tvmonitor", "wall", "water", "window", "wood"]#, "background"]
37
+ context_classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor", "bag", "bed", "bench", "book", "building", "cabinet", "ceiling", "cloth", "computer", "cup", "door", "fence", "floor", "flower", "food", "grass", "ground", "keyboard", "light", "mountain", "mouse", "curtain", "platform", "sign", "plate", "road", "rock", "shelves", "sidewalk", "sky", "snow", "bedclothes", "track", "tree", "truck", "wall", "water", "window", "wood"]
38
+ context_colors = [stuff_colors[i % len(stuff_colors)] for i in range(len(context_classes))]
39
+ ret = {
40
+ "stuff_colors" : context_colors,
41
+ "stuff_classes" : context_classes,
42
+ }
43
+ return ret
44
+
45
+ def register_pascal_context_59(root):
46
+ root = os.path.join(root, "VOCdevkit", "VOC2010")
47
+ meta = _get_pascal_context_59_meta()
48
+ for name, image_dirname, sem_seg_dirname in [
49
+ ("test", "JPEGImages", "annotations_detectron2/pc59_val"),
50
+ ]:
51
+ image_dir = os.path.join(root, image_dirname)
52
+ gt_dir = os.path.join(root, sem_seg_dirname)
53
+ name = f"context_59_{name}_sem_seg"
54
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
55
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
56
+
57
+ def _get_pascal_context_459_meta():
58
+ context_459_classes = ["accordion", "aeroplane", "airconditioner", "antenna", "artillery", "ashtray", "atrium", "babycarriage", "bag", "ball", "balloon", "bambooweaving", "barrel", "baseballbat", "basket", "basketballbackboard", "bathtub", "bed", "bedclothes", "beer", "bell", "bench", "bicycle", "binoculars", "bird", "birdcage", "birdfeeder", "birdnest", "blackboard", "board", "boat", "bone", "book", "bottle", "bottleopener", "bowl", "box", "bracelet", "brick", "bridge", "broom", "brush", "bucket", "building", "bus", "cabinet", "cabinetdoor", "cage", "cake", "calculator", "calendar", "camel", "camera", "cameralens", "can", "candle", "candleholder", "cap", "car", "card", "cart", "case", "casetterecorder", "cashregister", "cat", "cd", "cdplayer", "ceiling", "cellphone", "cello", "chain", "chair", "chessboard", "chicken", "chopstick", "clip", "clippers", "clock", "closet", "cloth", "clothestree", "coffee", "coffeemachine", "comb", "computer", "concrete", "cone", "container", "controlbooth", "controller", "cooker", "copyingmachine", "coral", "cork", "corkscrew", "counter", "court", "cow", "crabstick", "crane", "crate", "cross", "crutch", "cup", "curtain", "cushion", "cuttingboard", "dais", "disc", "disccase", "dishwasher", "dock", "dog", "dolphin", "door", "drainer", "dray", "drinkdispenser", "drinkingmachine", "drop", "drug", "drum", "drumkit", "duck", "dumbbell", "earphone", "earrings", "egg", "electricfan", "electriciron", "electricpot", "electricsaw", "electronickeyboard", "engine", "envelope", "equipment", "escalator", "exhibitionbooth", "extinguisher", "eyeglass", "fan", "faucet", "faxmachine", "fence", "ferriswheel", "fireextinguisher", "firehydrant", "fireplace", "fish", "fishtank", "fishbowl", "fishingnet", "fishingpole", "flag", "flagstaff", "flame", "flashlight", "floor", "flower", "fly", "foam", "food", "footbridge", "forceps", "fork", "forklift", "fountain", "fox", "frame", "fridge", "frog", "fruit", "funnel", "furnace", "gamecontroller", "gamemachine", "gascylinder", "gashood", "gasstove", "giftbox", "glass", "glassmarble", "globe", "glove", "goal", "grandstand", "grass", "gravestone", "ground", "guardrail", "guitar", "gun", "hammer", "handcart", "handle", "handrail", "hanger", "harddiskdrive", "hat", "hay", "headphone", "heater", "helicopter", "helmet", "holder", "hook", "horse", "horse-drawncarriage", "hot-airballoon", "hydrovalve", "ice", "inflatorpump", "ipod", "iron", "ironingboard", "jar", "kart", "kettle", "key", "keyboard", "kitchenrange", "kite", "knife", "knifeblock", "ladder", "laddertruck", "ladle", "laptop", "leaves", "lid", "lifebuoy", "light", "lightbulb", "lighter", "line", "lion", "lobster", "lock", "machine", "mailbox", "mannequin", "map", "mask", "mat", "matchbook", "mattress", "menu", "metal", "meterbox", "microphone", "microwave", "mirror", "missile", "model", "money", "monkey", "mop", "motorbike", "mountain", "mouse", "mousepad", "musicalinstrument", "napkin", "net", "newspaper", "oar", "ornament", "outlet", "oven", "oxygenbottle", "pack", "pan", "paper", "paperbox", "papercutter", "parachute", "parasol", "parterre", "patio", "pelage", "pen", "pencontainer", "pencil", "person", "photo", "piano", "picture", "pig", "pillar", "pillow", "pipe", "pitcher", "plant", "plastic", "plate", "platform", "player", "playground", "pliers", "plume", "poker", "pokerchip", "pole", "pooltable", "postcard", "poster", "pot", "pottedplant", "printer", "projector", "pumpkin", "rabbit", "racket", "radiator", "radio", "rail", "rake", "ramp", "rangehood", "receiver", "recorder", 
"recreationalmachines", "remotecontrol", "road", "robot", "rock", "rocket", "rockinghorse", "rope", "rug", "ruler", "runway", "saddle", "sand", "saw", "scale", "scanner", "scissors", "scoop", "screen", "screwdriver", "sculpture", "scythe", "sewer", "sewingmachine", "shed", "sheep", "shell", "shelves", "shoe", "shoppingcart", "shovel", "sidecar", "sidewalk", "sign", "signallight", "sink", "skateboard", "ski", "sky", "sled", "slippers", "smoke", "snail", "snake", "snow", "snowmobiles", "sofa", "spanner", "spatula", "speaker", "speedbump", "spicecontainer", "spoon", "sprayer", "squirrel", "stage", "stair", "stapler", "stick", "stickynote", "stone", "stool", "stove", "straw", "stretcher", "sun", "sunglass", "sunshade", "surveillancecamera", "swan", "sweeper", "swimring", "swimmingpool", "swing", "switch", "table", "tableware", "tank", "tap", "tape", "tarp", "telephone", "telephonebooth", "tent", "tire", "toaster", "toilet", "tong", "tool", "toothbrush", "towel", "toy", "toycar", "track", "train", "trampoline", "trashbin", "tray", "tree", "tricycle", "tripod", "trophy", "truck", "tube", "turtle", "tvmonitor", "tweezers", "typewriter", "umbrella", "unknown", "vacuumcleaner", "vendingmachine", "videocamera", "videogameconsole", "videoplayer", "videotape", "violin", "wakeboard", "wall", "wallet", "wardrobe", "washingmachine", "watch", "water", "waterdispenser", "waterpipe", "waterskateboard", "watermelon", "whale", "wharf", "wheel", "wheelchair", "window", "windowblinds", "wineglass", "wire", "wood", "wool"]
59
+ context_colors = [stuff_colors[i % len(stuff_colors)] for i in range(len(context_459_classes))]
60
+ ret = {
61
+ "stuff_colors" : context_colors,
62
+ "stuff_classes" : context_459_classes,
63
+ }
64
+ return ret
65
+
66
+ def register_pascal_context_459(root):
67
+ root = os.path.join(root, "VOCdevkit", "VOC2010")
68
+ meta = _get_pascal_context_459_meta()
69
+ for name, image_dirname, sem_seg_dirname in [
70
+ ("test", "JPEGImages", "annotations_detectron2/pc459_val"),
71
+ ]:
72
+ image_dir = os.path.join(root, image_dirname)
73
+ gt_dir = os.path.join(root, sem_seg_dirname)
74
+ name = f"context_459_{name}_sem_seg"
75
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='tif', image_ext='jpg'))
76
+ MetadataCatalog.get(name).set(image_root=image_dir, sem_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=459, **meta,)
77
+
78
+
79
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
80
+ register_pascal_context_59(_root)
81
+ register_pascal_context_459(_root)
cat_seg/modeling/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .backbone.swin import D2SwinTransformer
3
+ from .heads.cat_seg_head import CATSegHead
cat_seg/modeling/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (275 Bytes). View file
 
cat_seg/modeling/backbone/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
cat_seg/modeling/backbone/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (176 Bytes). View file
 
cat_seg/modeling/backbone/__pycache__/swin.cpython-38.pyc ADDED
Binary file (21.5 kB). View file
 
cat_seg/modeling/backbone/swin.py ADDED
@@ -0,0 +1,768 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint as checkpoint
16
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
17
+
18
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
19
+
20
+
21
+ class Mlp(nn.Module):
22
+ """Multilayer perceptron."""
23
+
24
+ def __init__(
25
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
26
+ ):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ def window_partition(x, window_size):
45
+ """
46
+ Args:
47
+ x: (B, H, W, C)
48
+ window_size (int): window size
49
+ Returns:
50
+ windows: (num_windows*B, window_size, window_size, C)
51
+ """
52
+ B, H, W, C = x.shape
53
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
54
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
55
+ return windows
56
+
57
+
58
+ def window_reverse(windows, window_size, H, W):
59
+ """
60
+ Args:
61
+ windows: (num_windows*B, window_size, window_size, C)
62
+ window_size (int): Window size
63
+ H (int): Height of image
64
+ W (int): Width of image
65
+ Returns:
66
+ x: (B, H, W, C)
67
+ """
68
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
69
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
70
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
71
+ return x
72
+
73
+
74
+ class WindowAttention(nn.Module):
75
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
76
+ It supports both of shifted and non-shifted window.
77
+ Args:
78
+ dim (int): Number of input channels.
79
+ window_size (tuple[int]): The height and width of the window.
80
+ num_heads (int): Number of attention heads.
81
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
82
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
83
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
84
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ dim,
90
+ window_size,
91
+ num_heads,
92
+ qkv_bias=True,
93
+ qk_scale=None,
94
+ attn_drop=0.0,
95
+ proj_drop=0.0,
96
+ ):
97
+
98
+ super().__init__()
99
+ self.dim = dim
100
+ self.window_size = window_size # Wh, Ww
101
+ self.num_heads = num_heads
102
+ head_dim = dim // num_heads
103
+ self.scale = qk_scale or head_dim ** -0.5
104
+
105
+ # define a parameter table of relative position bias
106
+ self.relative_position_bias_table = nn.Parameter(
107
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
108
+ ) # 2*Wh-1 * 2*Ww-1, nH
109
+
110
+ # get pair-wise relative position index for each token inside the window
111
+ coords_h = torch.arange(self.window_size[0])
112
+ coords_w = torch.arange(self.window_size[1])
113
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
114
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
115
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
116
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
117
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
118
+ relative_coords[:, :, 1] += self.window_size[1] - 1
119
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
120
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
121
+ self.register_buffer("relative_position_index", relative_position_index)
122
+
123
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
124
+ self.attn_drop = nn.Dropout(attn_drop)
125
+ self.proj = nn.Linear(dim, dim)
126
+ self.proj_drop = nn.Dropout(proj_drop)
127
+
128
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
129
+ self.softmax = nn.Softmax(dim=-1)
130
+
131
+ def forward(self, x, mask=None):
132
+ """Forward function.
133
+ Args:
134
+ x: input features with shape of (num_windows*B, N, C)
135
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
136
+ """
137
+ B_, N, C = x.shape
138
+ qkv = (
139
+ self.qkv(x)
140
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
141
+ .permute(2, 0, 3, 1, 4)
142
+ )
143
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
144
+
145
+ q = q * self.scale
146
+ attn = q @ k.transpose(-2, -1)
147
+
148
+ relative_position_bias = self.relative_position_bias_table[
149
+ self.relative_position_index.view(-1)
150
+ ].view(
151
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
152
+ ) # Wh*Ww,Wh*Ww,nH
153
+ relative_position_bias = relative_position_bias.permute(
154
+ 2, 0, 1
155
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
156
+ attn = attn + relative_position_bias.unsqueeze(0)
157
+
158
+ if mask is not None:
159
+ nW = mask.shape[0]
160
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
161
+ attn = attn.view(-1, self.num_heads, N, N)
162
+ attn = self.softmax(attn)
163
+ else:
164
+ attn = self.softmax(attn)
165
+
166
+ attn = self.attn_drop(attn)
167
+
168
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
169
+ x = self.proj(x)
170
+ x = self.proj_drop(x)
171
+ return x
172
+
173
+
174
+ class SwinTransformerBlock(nn.Module):
175
+ """Swin Transformer Block.
176
+ Args:
177
+ dim (int): Number of input channels.
178
+ num_heads (int): Number of attention heads.
179
+ window_size (int): Window size.
180
+ shift_size (int): Shift size for SW-MSA.
181
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
182
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
183
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
184
+ drop (float, optional): Dropout rate. Default: 0.0
185
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
186
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
187
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
188
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ dim,
194
+ num_heads,
195
+ window_size=7,
196
+ shift_size=0,
197
+ mlp_ratio=4.0,
198
+ qkv_bias=True,
199
+ qk_scale=None,
200
+ drop=0.0,
201
+ attn_drop=0.0,
202
+ drop_path=0.0,
203
+ act_layer=nn.GELU,
204
+ norm_layer=nn.LayerNorm,
205
+ ):
206
+ super().__init__()
207
+ self.dim = dim
208
+ self.num_heads = num_heads
209
+ self.window_size = window_size
210
+ self.shift_size = shift_size
211
+ self.mlp_ratio = mlp_ratio
212
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
213
+
214
+ self.norm1 = norm_layer(dim)
215
+ self.attn = WindowAttention(
216
+ dim,
217
+ window_size=to_2tuple(self.window_size),
218
+ num_heads=num_heads,
219
+ qkv_bias=qkv_bias,
220
+ qk_scale=qk_scale,
221
+ attn_drop=attn_drop,
222
+ proj_drop=drop,
223
+ )
224
+
225
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
226
+ self.norm2 = norm_layer(dim)
227
+ mlp_hidden_dim = int(dim * mlp_ratio)
228
+ self.mlp = Mlp(
229
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
230
+ )
231
+
232
+ self.H = None
233
+ self.W = None
234
+
235
+ def forward(self, x, mask_matrix):
236
+ """Forward function.
237
+ Args:
238
+ x: Input feature, tensor size (B, H*W, C).
239
+ H, W: Spatial resolution of the input feature.
240
+ mask_matrix: Attention mask for cyclic shift.
241
+ """
242
+ B, L, C = x.shape
243
+ H, W = self.H, self.W
244
+ assert L == H * W, "input feature has wrong size"
245
+
246
+ shortcut = x
247
+ x = self.norm1(x)
248
+ x = x.view(B, H, W, C)
249
+
250
+ # pad feature maps to multiples of window size
251
+ pad_l = pad_t = 0
252
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
253
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
254
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
255
+ _, Hp, Wp, _ = x.shape
256
+
257
+ # cyclic shift
258
+ if self.shift_size > 0:
259
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
260
+ attn_mask = mask_matrix
261
+ else:
262
+ shifted_x = x
263
+ attn_mask = None
264
+
265
+ # partition windows
266
+ x_windows = window_partition(
267
+ shifted_x, self.window_size
268
+ ) # nW*B, window_size, window_size, C
269
+ x_windows = x_windows.view(
270
+ -1, self.window_size * self.window_size, C
271
+ ) # nW*B, window_size*window_size, C
272
+
273
+ # W-MSA/SW-MSA
274
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
275
+
276
+ # merge windows
277
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
278
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
279
+
280
+ # reverse cyclic shift
281
+ if self.shift_size > 0:
282
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
283
+ else:
284
+ x = shifted_x
285
+
286
+ if pad_r > 0 or pad_b > 0:
287
+ x = x[:, :H, :W, :].contiguous()
288
+
289
+ x = x.view(B, H * W, C)
290
+
291
+ # FFN
292
+ x = shortcut + self.drop_path(x)
293
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
294
+
295
+ return x
296
+
297
+
298
+ class PatchMerging(nn.Module):
299
+ """Patch Merging Layer
300
+ Args:
301
+ dim (int): Number of input channels.
302
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
303
+ """
304
+
305
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
306
+ super().__init__()
307
+ self.dim = dim
308
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
309
+ self.norm = norm_layer(4 * dim)
310
+
311
+ def forward(self, x, H, W):
312
+ """Forward function.
313
+ Args:
314
+ x: Input feature, tensor size (B, H*W, C).
315
+ H, W: Spatial resolution of the input feature.
316
+ """
317
+ B, L, C = x.shape
318
+ assert L == H * W, "input feature has wrong size"
319
+
320
+ x = x.view(B, H, W, C)
321
+
322
+ # padding
323
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
324
+ if pad_input:
325
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
326
+
327
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
328
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
329
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
330
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
331
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
332
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
333
+
334
+ x = self.norm(x)
335
+ x = self.reduction(x)
336
+
337
+ return x
338
+
339
+
340
+ class BasicLayer(nn.Module):
341
+ """A basic Swin Transformer layer for one stage.
342
+ Args:
343
+ dim (int): Number of feature channels
344
+ depth (int): Depths of this stage.
345
+ num_heads (int): Number of attention head.
346
+ window_size (int): Local window size. Default: 7.
347
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
348
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
349
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
350
+ drop (float, optional): Dropout rate. Default: 0.0
351
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
352
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
353
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
354
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
355
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
356
+ """
357
+
358
+ def __init__(
359
+ self,
360
+ dim,
361
+ depth,
362
+ num_heads,
363
+ window_size=7,
364
+ mlp_ratio=4.0,
365
+ qkv_bias=True,
366
+ qk_scale=None,
367
+ drop=0.0,
368
+ attn_drop=0.0,
369
+ drop_path=0.0,
370
+ norm_layer=nn.LayerNorm,
371
+ downsample=None,
372
+ use_checkpoint=False,
373
+ ):
374
+ super().__init__()
375
+ self.window_size = window_size
376
+ self.shift_size = window_size // 2
377
+ self.depth = depth
378
+ self.use_checkpoint = use_checkpoint
379
+
380
+ # build blocks
381
+ self.blocks = nn.ModuleList(
382
+ [
383
+ SwinTransformerBlock(
384
+ dim=dim,
385
+ num_heads=num_heads,
386
+ window_size=window_size,
387
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
388
+ mlp_ratio=mlp_ratio,
389
+ qkv_bias=qkv_bias,
390
+ qk_scale=qk_scale,
391
+ drop=drop,
392
+ attn_drop=attn_drop,
393
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
394
+ norm_layer=norm_layer,
395
+ )
396
+ for i in range(depth)
397
+ ]
398
+ )
399
+
400
+ # patch merging layer
401
+ if downsample is not None:
402
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
403
+ else:
404
+ self.downsample = None
405
+
406
+ def forward(self, x, H, W):
407
+ """Forward function.
408
+ Args:
409
+ x: Input feature, tensor size (B, H*W, C).
410
+ H, W: Spatial resolution of the input feature.
411
+ """
412
+
413
+ # calculate attention mask for SW-MSA
414
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
415
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
416
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
417
+ h_slices = (
418
+ slice(0, -self.window_size),
419
+ slice(-self.window_size, -self.shift_size),
420
+ slice(-self.shift_size, None),
421
+ )
422
+ w_slices = (
423
+ slice(0, -self.window_size),
424
+ slice(-self.window_size, -self.shift_size),
425
+ slice(-self.shift_size, None),
426
+ )
427
+ cnt = 0
428
+ for h in h_slices:
429
+ for w in w_slices:
430
+ img_mask[:, h, w, :] = cnt
431
+ cnt += 1
432
+
433
+ mask_windows = window_partition(
434
+ img_mask, self.window_size
435
+ ) # nW, window_size, window_size, 1
436
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
437
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
438
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
439
+ attn_mask == 0, float(0.0)
440
+ )
441
+
442
+ for blk in self.blocks:
443
+ blk.H, blk.W = H, W
444
+ if self.use_checkpoint:
445
+ x = checkpoint.checkpoint(blk, x, attn_mask)
446
+ else:
447
+ x = blk(x, attn_mask)
448
+ if self.downsample is not None:
449
+ x_down = self.downsample(x, H, W)
450
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
451
+ return x, H, W, x_down, Wh, Ww
452
+ else:
453
+ return x, H, W, x, H, W
454
+
455
+
456
+ class PatchEmbed(nn.Module):
457
+ """Image to Patch Embedding
458
+ Args:
459
+ patch_size (int): Patch token size. Default: 4.
460
+ in_chans (int): Number of input image channels. Default: 3.
461
+ embed_dim (int): Number of linear projection output channels. Default: 96.
462
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
463
+ """
464
+
465
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
466
+ super().__init__()
467
+ patch_size = to_2tuple(patch_size)
468
+ self.patch_size = patch_size
469
+
470
+ self.in_chans = in_chans
471
+ self.embed_dim = embed_dim
472
+
473
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
474
+ if norm_layer is not None:
475
+ self.norm = norm_layer(embed_dim)
476
+ else:
477
+ self.norm = None
478
+
479
+ def forward(self, x):
480
+ """Forward function."""
481
+ # padding
482
+ _, _, H, W = x.size()
483
+ if W % self.patch_size[1] != 0:
484
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
485
+ if H % self.patch_size[0] != 0:
486
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
487
+
488
+ x = self.proj(x) # B C Wh Ww
489
+ if self.norm is not None:
490
+ Wh, Ww = x.size(2), x.size(3)
491
+ x = x.flatten(2).transpose(1, 2)
492
+ x = self.norm(x)
493
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
494
+
495
+ return x
496
+
497
+
498
+ class SwinTransformer(nn.Module):
499
+ """Swin Transformer backbone.
500
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
501
+ https://arxiv.org/pdf/2103.14030
502
+ Args:
503
+ pretrain_img_size (int): Input image size for training the pretrained model,
504
+ used in absolute postion embedding. Default 224.
505
+ patch_size (int | tuple(int)): Patch size. Default: 4.
506
+ in_chans (int): Number of input image channels. Default: 3.
507
+ embed_dim (int): Number of linear projection output channels. Default: 96.
508
+ depths (tuple[int]): Depths of each Swin Transformer stage.
509
+ num_heads (tuple[int]): Number of attention head of each stage.
510
+ window_size (int): Window size. Default: 7.
511
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
512
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
513
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
514
+ drop_rate (float): Dropout rate.
515
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
516
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
517
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
518
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
519
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
520
+ out_indices (Sequence[int]): Output from which stages.
521
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
522
+ -1 means not freezing any parameters.
523
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
524
+ """
525
+
526
+ def __init__(
527
+ self,
528
+ pretrain_img_size=224,
529
+ patch_size=4,
530
+ in_chans=3,
531
+ embed_dim=96,
532
+ depths=[2, 2, 6, 2],
533
+ num_heads=[3, 6, 12, 24],
534
+ window_size=7,
535
+ mlp_ratio=4.0,
536
+ qkv_bias=True,
537
+ qk_scale=None,
538
+ drop_rate=0.0,
539
+ attn_drop_rate=0.0,
540
+ drop_path_rate=0.2,
541
+ norm_layer=nn.LayerNorm,
542
+ ape=False,
543
+ patch_norm=True,
544
+ out_indices=(0, 1, 2), #3),
545
+ frozen_stages=-1,
546
+ use_checkpoint=False,
547
+ ):
548
+ super().__init__()
549
+
550
+ self.pretrain_img_size = pretrain_img_size
551
+ self.num_layers = len(depths)
552
+ self.embed_dim = embed_dim
553
+ self.ape = ape
554
+ self.patch_norm = patch_norm
555
+ self.out_indices = out_indices
556
+ self.frozen_stages = frozen_stages
557
+
558
+ # split image into non-overlapping patches
559
+ self.patch_embed = PatchEmbed(
560
+ patch_size=patch_size,
561
+ in_chans=in_chans,
562
+ embed_dim=embed_dim,
563
+ norm_layer=norm_layer if self.patch_norm else None,
564
+ )
565
+
566
+ # absolute position embedding
567
+ if self.ape:
568
+ pretrain_img_size = to_2tuple(pretrain_img_size)
569
+ patch_size = to_2tuple(patch_size)
570
+ patches_resolution = [
571
+ pretrain_img_size[0] // patch_size[0],
572
+ pretrain_img_size[1] // patch_size[1],
573
+ ]
574
+
575
+ self.absolute_pos_embed = nn.Parameter(
576
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
577
+ )
578
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
579
+
580
+ self.pos_drop = nn.Dropout(p=drop_rate)
581
+
582
+ # stochastic depth
583
+ dpr = [
584
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
585
+ ] # stochastic depth decay rule
586
+
587
+ # build layers
588
+ self.layers = nn.ModuleList()
589
+ for i_layer in range(self.num_layers):
590
+ layer = BasicLayer(
591
+ dim=int(embed_dim * 2 ** i_layer),
592
+ depth=depths[i_layer],
593
+ num_heads=num_heads[i_layer],
594
+ window_size=window_size,
595
+ mlp_ratio=mlp_ratio,
596
+ qkv_bias=qkv_bias,
597
+ qk_scale=qk_scale,
598
+ drop=drop_rate,
599
+ attn_drop=attn_drop_rate,
600
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
601
+ norm_layer=norm_layer,
602
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
603
+ use_checkpoint=use_checkpoint,
604
+ )
605
+ self.layers.append(layer)
606
+
607
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
608
+ self.num_features = num_features
609
+
610
+ # add a norm layer for each output
611
+ for i_layer in out_indices:
612
+ layer = norm_layer(num_features[i_layer])
613
+ layer_name = f"norm{i_layer}"
614
+ self.add_module(layer_name, layer)
615
+
616
+ self._freeze_stages()
617
+
618
+ def _freeze_stages(self):
619
+ if self.frozen_stages >= 0:
620
+ self.patch_embed.eval()
621
+ for param in self.patch_embed.parameters():
622
+ param.requires_grad = False
623
+
624
+ if self.frozen_stages >= 1 and self.ape:
625
+ self.absolute_pos_embed.requires_grad = False
626
+
627
+ if self.frozen_stages >= 2:
628
+ self.pos_drop.eval()
629
+ for i in range(0, self.frozen_stages - 1):
630
+ m = self.layers[i]
631
+ m.eval()
632
+ for param in m.parameters():
633
+ param.requires_grad = False
634
+
635
+ def init_weights(self, pretrained=None):
636
+ """Initialize the weights in backbone.
637
+ Args:
638
+ pretrained (str, optional): Path to pre-trained weights.
639
+ Defaults to None.
640
+ """
641
+
642
+ def _init_weights(m):
643
+ if isinstance(m, nn.Linear):
644
+ trunc_normal_(m.weight, std=0.02)
645
+ if isinstance(m, nn.Linear) and m.bias is not None:
646
+ nn.init.constant_(m.bias, 0)
647
+ elif isinstance(m, nn.LayerNorm):
648
+ nn.init.constant_(m.bias, 0)
649
+ nn.init.constant_(m.weight, 1.0)
650
+
651
+ def forward(self, x):
652
+ """Forward function."""
653
+ x = self.patch_embed(x)
654
+
655
+ Wh, Ww = x.size(2), x.size(3)
656
+ if self.ape:
657
+ # interpolate the position embedding to the corresponding size
658
+ absolute_pos_embed = F.interpolate(
659
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
660
+ )
661
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
662
+ else:
663
+ x = x.flatten(2).transpose(1, 2)
664
+ x = self.pos_drop(x)
665
+
666
+ outs = {}
667
+ for i in range(self.num_layers):
668
+ layer = self.layers[i]
669
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
670
+
671
+ if i in self.out_indices:
672
+ norm_layer = getattr(self, f"norm{i}")
673
+ x_out = norm_layer(x_out)
674
+
675
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
676
+ outs["res{}".format(i + 2)] = out
677
+
678
+ return outs
679
+
680
+ def train(self, mode=True):
681
+ """Convert the model into training mode while keep layers freezed."""
682
+ super(SwinTransformer, self).train(mode)
683
+ self._freeze_stages()
684
+
685
+
686
+ @BACKBONE_REGISTRY.register()
687
+ class D2SwinTransformer(SwinTransformer, Backbone):
688
+ def __init__(self, cfg, input_shape):
689
+
690
+ pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
691
+ patch_size = cfg.MODEL.SWIN.PATCH_SIZE
692
+ in_chans = 3
693
+ embed_dim = cfg.MODEL.SWIN.EMBED_DIM
694
+ depths = cfg.MODEL.SWIN.DEPTHS
695
+ num_heads = cfg.MODEL.SWIN.NUM_HEADS
696
+ window_size = cfg.MODEL.SWIN.WINDOW_SIZE
697
+ mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
698
+ qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
699
+ qk_scale = cfg.MODEL.SWIN.QK_SCALE
700
+ drop_rate = cfg.MODEL.SWIN.DROP_RATE
701
+ attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
702
+ drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
703
+ norm_layer = nn.LayerNorm
704
+ ape = cfg.MODEL.SWIN.APE
705
+ patch_norm = cfg.MODEL.SWIN.PATCH_NORM
706
+
707
+ super().__init__(
708
+ pretrain_img_size,
709
+ patch_size,
710
+ in_chans,
711
+ embed_dim,
712
+ depths,
713
+ num_heads,
714
+ window_size,
715
+ mlp_ratio,
716
+ qkv_bias,
717
+ qk_scale,
718
+ drop_rate,
719
+ attn_drop_rate,
720
+ drop_path_rate,
721
+ norm_layer,
722
+ ape,
723
+ patch_norm,
724
+ )
725
+
726
+ self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
727
+
728
+ self._out_feature_strides = {
729
+ "res2": 4,
730
+ "res3": 8,
731
+ "res4": 16,
732
+ #"res5": 32,
733
+ }
734
+ self._out_feature_channels = {
735
+ "res2": self.num_features[0],
736
+ "res3": self.num_features[1],
737
+ "res4": self.num_features[2],
738
+ #"res5": self.num_features[3],
739
+ }
740
+
741
+ def forward(self, x):
742
+ """
743
+ Args:
744
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
745
+ Returns:
746
+ dict[str->Tensor]: names and the corresponding features
747
+ """
748
+ assert (
749
+ x.dim() == 4
750
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
751
+ outputs = {}
752
+ y = super().forward(x)
753
+ for k in y.keys():
754
+ if k in self._out_features:
755
+ outputs[k] = y[k]
756
+ return outputs
757
+
758
+ def output_shape(self):
759
+ return {
760
+ name: ShapeSpec(
761
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
762
+ )
763
+ for name in self._out_features
764
+ }
765
+
766
+ @property
767
+ def size_divisibility(self):
768
+ return 32
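As a quick sanity check on the window bookkeeping above (a stand-alone sketch, not part of this commit, assuming the file is importable as cat_seg.modeling.backbone.swin), window_partition and window_reverse are exact inverses whenever H and W are multiples of the window size; the Swin block itself pads to that multiple before partitioning:

# Hedged sketch: round-trip a feature map through window_partition / window_reverse.
import torch
from cat_seg.modeling.backbone.swin import window_partition, window_reverse

B, H, W, C, win = 2, 14, 14, 96, 7
x = torch.randn(B, H, W, C)

windows = window_partition(x, win)             # (B * (H//win) * (W//win), win, win, C) -> (8, 7, 7, 96)
restored = window_reverse(windows, win, H, W)  # back to (2, 14, 14, 96)

assert restored.shape == x.shape
assert torch.allclose(restored, x)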
cat_seg/modeling/heads/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
cat_seg/modeling/heads/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (173 Bytes). View file
 
cat_seg/modeling/heads/__pycache__/cat_seg_head.cpython-38.pyc ADDED
Binary file (3.24 kB). View file
 
cat_seg/modeling/heads/cat_seg_head.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ from copy import deepcopy
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+ from einops import rearrange
6
+
7
+ import fvcore.nn.weight_init as weight_init
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from detectron2.config import configurable
12
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
13
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
14
+
15
+ from ..transformer.cat_seg_predictor import CATSegPredictor
16
+
17
+
18
+ @SEM_SEG_HEADS_REGISTRY.register()
19
+ class CATSegHead(nn.Module):
20
+
21
+ @configurable
22
+ def __init__(
23
+ self,
24
+ input_shape: Dict[str, ShapeSpec],
25
+ *,
26
+ num_classes: int,
27
+ ignore_value: int = -1,
28
+ # extra parameters
29
+ feature_resolution: list,
30
+ transformer_predictor: nn.Module,
31
+ ):
32
+ """
33
+ NOTE: this interface is experimental.
34
+ Args:
35
+ input_shape: shapes (channels and stride) of the input features
36
+ num_classes: number of classes to predict
37
+ pixel_decoder: the pixel decoder module
38
+ loss_weight: loss weight
39
+ ignore_value: category id to be ignored during training.
40
+ transformer_predictor: the transformer decoder that makes prediction
41
+ transformer_in_feature: input feature name to the transformer_predictor
42
+ """
43
+ super().__init__()
44
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
45
+ self.in_features = [k for k, v in input_shape]
46
+ self.ignore_value = ignore_value
47
+ self.predictor = transformer_predictor
48
+ self.num_classes = num_classes
49
+ self.feature_resolution = feature_resolution
50
+
51
+ @classmethod
52
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
53
+ return {
54
+ "input_shape": {
55
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
56
+ },
57
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
58
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
59
+ "feature_resolution": cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION,
60
+ "transformer_predictor": CATSegPredictor(
61
+ cfg,
62
+ ),
63
+ }
64
+
65
+ def forward(self, features, guidance_features):
66
+ """
67
+ Arguments:
68
+ img_feats: (B, C, HW)
69
+ affinity_features: (B, C, )
70
+ """
71
+ img_feat = rearrange(features[:, 1:, :], "b (h w) c->b c h w", h=self.feature_resolution[0], w=self.feature_resolution[1])
72
+ return self.predictor(img_feat, guidance_features)
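The only tensor manipulation CATSegHead performs is dropping the CLIP class token and folding the remaining patch tokens back into a 2D grid before handing them to the predictor. A minimal sketch of that reshape with a dummy tensor (not part of the commit; the 24x24 grid is just an example value, the real one comes from cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION):

# Hedged sketch: the token-to-grid reshape done in CATSegHead.forward.
import torch
from einops import rearrange

B, C, h, w = 2, 512, 24, 24
tokens = torch.randn(B, 1 + h * w, C)  # [CLS] token followed by h*w patch tokens

img_feat = rearrange(tokens[:, 1:, :], "b (h w) c -> b c h w", h=h, w=w)
print(img_feat.shape)                   # torch.Size([2, 512, 24, 24])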
cat_seg/modeling/transformer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
cat_seg/modeling/transformer/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (179 Bytes). View file
 
cat_seg/modeling/transformer/__pycache__/cat_seg_predictor.cpython-38.pyc ADDED
Binary file (5.18 kB). View file
 
cat_seg/modeling/transformer/__pycache__/model.cpython-38.pyc ADDED
Binary file (21.8 kB). View file
 
cat_seg/modeling/transformer/cat_seg_predictor.py ADDED
@@ -0,0 +1,175 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ # Modified by Jian Ding from: https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
4
+ import fvcore.nn.weight_init as weight_init
5
+ import torch
6
+
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.layers import Conv2d
12
+
13
+ from .model import Aggregator
14
+ from cat_seg.third_party import clip
15
+ from cat_seg.third_party import imagenet_templates
16
+
17
+ import numpy as np
18
+ import open_clip
19
+ class CATSegPredictor(nn.Module):
20
+ @configurable
21
+ def __init__(
22
+ self,
23
+ *,
24
+ train_class_json: str,
25
+ test_class_json: str,
26
+ clip_pretrained: str,
27
+ prompt_ensemble_type: str,
28
+ text_guidance_dim: int,
29
+ text_guidance_proj_dim: int,
30
+ appearance_guidance_dim: int,
31
+ appearance_guidance_proj_dim: int,
32
+ prompt_depth: int,
33
+ prompt_length: int,
34
+ decoder_dims: list,
35
+ decoder_guidance_dims: list,
36
+ decoder_guidance_proj_dims: list,
37
+ num_heads: int,
38
+ num_layers: tuple,
39
+ hidden_dims: tuple,
40
+ pooling_sizes: tuple,
41
+ feature_resolution: tuple,
42
+ window_sizes: tuple,
43
+ attention_type: str,
44
+ ):
45
+ """
46
+ Args:
47
+
48
+ """
49
+ super().__init__()
50
+
51
+ import json
52
+ # use class_texts in train_forward, and test_class_texts in test_forward
53
+ with open(train_class_json, 'r') as f_in:
54
+ self.class_texts = json.load(f_in)
55
+ with open(test_class_json, 'r') as f_in:
56
+ self.test_class_texts = json.load(f_in)
57
+ assert self.class_texts is not None
58
+ if self.test_class_texts is None:
59
+ self.test_class_texts = self.class_texts
60
+ device = "cuda" if torch.cuda.is_available() else "cpu"
61
+
62
+ self.tokenizer = None
63
+ if clip_pretrained == "ViT-G" or clip_pretrained == "ViT-H":
64
+ # for OpenCLIP models
65
+ name, pretrain = ('ViT-H-14', 'laion2b_s32b_b79k') if clip_pretrained == 'ViT-H' else ('ViT-bigG-14', 'laion2b_s39b_b160k')
66
+ clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
67
+ name,
68
+ pretrained=pretrain,
69
+ device=device,
70
+ force_image_size=336,)
71
+
72
+ self.tokenizer = open_clip.get_tokenizer(name)
73
+ else:
74
+ # for OpenAI models
75
+ clip_model, clip_preprocess = clip.load(clip_pretrained, device=device, jit=False, prompt_depth=prompt_depth, prompt_length=prompt_length)
76
+
77
+ self.prompt_ensemble_type = prompt_ensemble_type
78
+
79
+ if self.prompt_ensemble_type == "imagenet_select":
80
+ prompt_templates = imagenet_templates.IMAGENET_TEMPLATES_SELECT
81
+ elif self.prompt_ensemble_type == "imagenet":
82
+ prompt_templates = imagenet_templates.IMAGENET_TEMPLATES
83
+ elif self.prompt_ensemble_type == "single":
84
+ prompt_templates = ['A photo of a {} in the scene',]
85
+ else:
86
+ raise NotImplementedError
87
+
88
+ self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
89
+ self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
90
+
91
+ self.clip_model = clip_model.float()
92
+ self.clip_preprocess = clip_preprocess
93
+
94
+ transformer = Aggregator(
95
+ text_guidance_dim=text_guidance_dim,
96
+ text_guidance_proj_dim=text_guidance_proj_dim,
97
+ appearance_guidance_dim=appearance_guidance_dim,
98
+ appearance_guidance_proj_dim=appearance_guidance_proj_dim,
99
+ decoder_dims=decoder_dims,
100
+ decoder_guidance_dims=decoder_guidance_dims,
101
+ decoder_guidance_proj_dims=decoder_guidance_proj_dims,
102
+ num_layers=num_layers,
103
+ nheads=num_heads,
104
+ hidden_dim=hidden_dims,
105
+ pooling_size=pooling_sizes,
106
+ feature_resolution=feature_resolution,
107
+ window_size=window_sizes,
108
+ attention_type=attention_type
109
+ )
110
+ self.transformer = transformer
111
+
112
+ @classmethod
113
+ def from_config(cls, cfg):#, in_channels, mask_classification):
114
+ ret = {}
115
+
116
+ ret["train_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON
117
+ ret["test_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON
118
+ ret["clip_pretrained"] = cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED
119
+ ret["prompt_ensemble_type"] = cfg.MODEL.PROMPT_ENSEMBLE_TYPE
120
+
121
+ # Aggregator parameters:
122
+ ret["text_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_DIM
123
+ ret["text_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_PROJ_DIM
124
+ ret["appearance_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_DIM
125
+ ret["appearance_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_PROJ_DIM
126
+
127
+ ret["decoder_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_DIMS
128
+ ret["decoder_guidance_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_DIMS
129
+ ret["decoder_guidance_proj_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_PROJ_DIMS
130
+
131
+ ret["prompt_depth"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_DEPTH
132
+ ret["prompt_length"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_LENGTH
133
+
134
+ ret["num_layers"] = cfg.MODEL.SEM_SEG_HEAD.NUM_LAYERS
135
+ ret["num_heads"] = cfg.MODEL.SEM_SEG_HEAD.NUM_HEADS
136
+ ret["hidden_dims"] = cfg.MODEL.SEM_SEG_HEAD.HIDDEN_DIMS
137
+ ret["pooling_sizes"] = cfg.MODEL.SEM_SEG_HEAD.POOLING_SIZES
138
+ ret["feature_resolution"] = cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION
139
+ ret["window_sizes"] = cfg.MODEL.SEM_SEG_HEAD.WINDOW_SIZES
140
+ ret["attention_type"] = cfg.MODEL.SEM_SEG_HEAD.ATTENTION_TYPE
141
+
142
+ return ret
143
+
144
+ def forward(self, x, vis_affinity):
145
+ vis = [vis_affinity[k] for k in vis_affinity.keys()][::-1]
146
+ text = self.text_features if self.training else self.text_features_test
147
+ text = text.repeat(x.shape[0], 1, 1, 1)
148
+ out = self.transformer(x, text, vis)
149
+ return out
150
+
151
+ @torch.no_grad()
152
+ def class_embeddings(self, classnames, templates, clip_model):
153
+ zeroshot_weights = []
154
+ for classname in classnames:
155
+ if ', ' in classname:
156
+ classname_splits = classname.split(', ')
157
+ texts = []
158
+ for template in templates:
159
+ for cls_split in classname_splits:
160
+ texts.append(template.format(cls_split))
161
+ else:
162
+ texts = [template.format(classname) for template in templates] # format with class
163
+ if self.tokenizer is not None:
164
+ texts = self.tokenizer(texts).cuda()
165
+ else:
166
+ texts = clip.tokenize(texts).cuda()
167
+ class_embeddings = clip_model.encode_text(texts)
168
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
169
+ if len(templates) != class_embeddings.shape[0]:
170
+ class_embeddings = class_embeddings.reshape(len(templates), -1, class_embeddings.shape[-1]).mean(dim=1)
171
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
172
+ class_embedding = class_embeddings
173
+ zeroshot_weights.append(class_embedding)
174
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
175
+ return zeroshot_weights
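For reference, `class_embeddings` builds one text embedding per class by formatting every template (and every comma-separated synonym), normalizing, and averaging. A minimal standalone sketch of that averaging logic, with a hypothetical `dummy_encode` standing in for `clip_model.encode_text` so it runs without CLIP or a GPU:

```python
import torch

def dummy_encode(texts):
    # hypothetical stand-in for clip_model.encode_text: one 512-d feature per prompt
    return torch.randn(len(texts), 512)

templates = ['A photo of a {} in the scene']
classnames = ['cat', 'sofa, couch']             # the second class has two synonyms

weights = []
for name in classnames:
    splits = name.split(', ')
    prompts = [t.format(s) for t in templates for s in splits]
    emb = dummy_encode(prompts)
    emb = emb / emb.norm(dim=-1, keepdim=True)
    if emb.shape[0] != len(templates):          # average the synonym prompts per template
        emb = emb.reshape(len(templates), -1, emb.shape[-1]).mean(dim=1)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    weights.append(emb)

zeroshot = torch.stack(weights, dim=1)          # (num_templates, num_classes, 512)
print(zeroshot.shape)                            # torch.Size([1, 2, 512])
```

In the module above the result is additionally permuted to (num_classes, num_templates, dim) in `__init__` and repeated over the batch at forward time before being handed to the Aggregator.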
cat_seg/modeling/transformer/model.py ADDED
@@ -0,0 +1,650 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from einops import rearrange, repeat
6
+ from einops.layers.torch import Rearrange
7
+
8
+ from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
9
+
10
+ def window_partition(x, window_size: int):
11
+ """
12
+ Args:
13
+ x: (B, H, W, C)
14
+ window_size (int): window size
15
+
16
+ Returns:
17
+ windows: (num_windows*B, window_size, window_size, C)
18
+ """
19
+ B, H, W, C = x.shape
20
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
21
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
22
+ return windows
23
+
24
+
25
+ def window_reverse(windows, window_size: int, H: int, W: int):
26
+ """
27
+ Args:
28
+ windows: (num_windows*B, window_size, window_size, C)
29
+ window_size (int): Window size
30
+ H (int): Height of image
31
+ W (int): Width of image
32
+
33
+ Returns:
34
+ x: (B, H, W, C)
35
+ """
36
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
37
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
38
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
39
+ return x
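As a quick sanity check (my own sketch, not part of the diff): `window_partition` and `window_reverse` are exact inverses whenever H and W are divisible by the window size.

```python
import torch

x = torch.randn(2, 28, 28, 64)                    # B, H, W, C with H and W divisible by window_size
windows = window_partition(x, window_size=7)      # (2 * 4 * 4, 7, 7, 64) = 32 windows
assert windows.shape == (32, 7, 7, 64)
assert torch.equal(window_reverse(windows, 7, 28, 28), x)   # partition/reverse round-trips losslessly
```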
40
+
41
+
42
+
43
+ class WindowAttention(nn.Module):
44
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
45
+ It supports both of shifted and non-shifted window.
46
+
47
+ Args:
48
+ dim (int): Number of input channels.
49
+ num_heads (int): Number of attention heads.
50
+ head_dim (int): Number of channels per head (dim // num_heads if not set)
51
+ window_size (tuple[int]): The height and width of the window.
52
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
53
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
54
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
55
+ """
56
+
57
+ def __init__(self, dim, appearance_guidance_dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
58
+
59
+ super().__init__()
60
+ self.dim = dim
61
+ self.window_size = to_2tuple(window_size) # Wh, Ww
62
+ win_h, win_w = self.window_size
63
+ self.window_area = win_h * win_w
64
+ self.num_heads = num_heads
65
+ head_dim = head_dim or dim // num_heads
66
+ attn_dim = head_dim * num_heads
67
+ self.scale = head_dim ** -0.5
68
+
69
+ self.q = nn.Linear(dim + appearance_guidance_dim, attn_dim, bias=qkv_bias)
70
+ self.k = nn.Linear(dim + appearance_guidance_dim, attn_dim, bias=qkv_bias)
71
+ self.v = nn.Linear(dim, attn_dim, bias=qkv_bias)
72
+ self.attn_drop = nn.Dropout(attn_drop)
73
+ self.proj = nn.Linear(attn_dim, dim)
74
+ self.proj_drop = nn.Dropout(proj_drop)
75
+
76
+ self.softmax = nn.Softmax(dim=-1)
77
+
78
+ def forward(self, x, mask=None):
79
+ """
80
+ Args:
81
+ x: input features with shape of (num_windows*B, N, C)
82
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
83
+ """
84
+ B_, N, C = x.shape
85
+
86
+ q = self.q(x).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
87
+ k = self.k(x).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
88
+ v = self.v(x[:, :, :self.dim]).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
89
+
90
+ q = q * self.scale
91
+ attn = (q @ k.transpose(-2, -1))
92
+
93
+ if mask is not None:
94
+ num_win = mask.shape[0]
95
+ attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
96
+ attn = attn.view(-1, self.num_heads, N, N)
97
+ attn = self.softmax(attn)
98
+ else:
99
+ attn = self.softmax(attn)
100
+
101
+ attn = self.attn_drop(attn)
102
+
103
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
104
+ x = self.proj(x)
105
+ x = self.proj_drop(x)
106
+ return x
107
+
108
+
109
+ class SwinTransformerBlock(nn.Module):
110
+ r""" Swin Transformer Block.
111
+
112
+ Args:
113
+ dim (int): Number of input channels.
114
+ input_resolution (tuple[int]): Input resulotion.
115
+ window_size (int): Window size.
116
+ num_heads (int): Number of attention heads.
117
+ head_dim (int): Enforce the number of channels per head
118
+ shift_size (int): Shift size for SW-MSA.
119
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
120
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
121
+ drop (float, optional): Dropout rate. Default: 0.0
122
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
123
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
124
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
125
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
126
+ """
127
+
128
+ def __init__(
129
+ self, dim, appearance_guidance_dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
130
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
131
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
132
+ super().__init__()
133
+ self.dim = dim
134
+ self.input_resolution = input_resolution
135
+ self.window_size = window_size
136
+ self.shift_size = shift_size
137
+ self.mlp_ratio = mlp_ratio
138
+ if min(self.input_resolution) <= self.window_size:
139
+ # if window size is larger than input resolution, we don't partition windows
140
+ self.shift_size = 0
141
+ self.window_size = min(self.input_resolution)
142
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
143
+
144
+ self.norm1 = norm_layer(dim)
145
+ self.attn = WindowAttention(
146
+ dim, appearance_guidance_dim=appearance_guidance_dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size),
147
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
148
+
149
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
150
+ self.norm2 = norm_layer(dim)
151
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
152
+
153
+ if self.shift_size > 0:
154
+ # calculate attention mask for SW-MSA
155
+ H, W = self.input_resolution
156
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
157
+ cnt = 0
158
+ for h in (
159
+ slice(0, -self.window_size),
160
+ slice(-self.window_size, -self.shift_size),
161
+ slice(-self.shift_size, None)):
162
+ for w in (
163
+ slice(0, -self.window_size),
164
+ slice(-self.window_size, -self.shift_size),
165
+ slice(-self.shift_size, None)):
166
+ img_mask[:, h, w, :] = cnt
167
+ cnt += 1
168
+ mask_windows = window_partition(img_mask, self.window_size) # num_win, window_size, window_size, 1
169
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
170
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
171
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
172
+ else:
173
+ attn_mask = None
174
+
175
+ self.register_buffer("attn_mask", attn_mask)
176
+
177
+ def forward(self, x, appearance_guidance):
178
+ H, W = self.input_resolution
179
+ B, L, C = x.shape
180
+ assert L == H * W, "input feature has wrong size"
181
+
182
+ shortcut = x
183
+ x = self.norm1(x)
184
+ x = x.view(B, H, W, C)
185
+ if appearance_guidance is not None:
186
+ appearance_guidance = appearance_guidance.view(B, H, W, -1)
187
+ x = torch.cat([x, appearance_guidance], dim=-1)
188
+
189
+ # cyclic shift
190
+ if self.shift_size > 0:
191
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
192
+ else:
193
+ shifted_x = x
194
+
195
+ # partition windows
196
+ x_windows = window_partition(shifted_x, self.window_size) # num_win*B, window_size, window_size, C
197
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, x_windows.shape[-1]) # num_win*B, window_size*window_size, C
198
+
199
+ # W-MSA/SW-MSA
200
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_win*B, window_size*window_size, C
201
+
202
+ # merge windows
203
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
204
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
205
+
206
+ # reverse cyclic shift
207
+ if self.shift_size > 0:
208
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
209
+ else:
210
+ x = shifted_x
211
+ x = x.view(B, H * W, C)
212
+
213
+ # FFN
214
+ x = shortcut + self.drop_path(x)
215
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
216
+
217
+ return x
218
+
219
+
220
+ class SwinTransformerBlockWrapper(nn.Module):
221
+ def __init__(self, dim, appearance_guidance_dim, input_resolution, nheads=4, window_size=5):
222
+ super().__init__()
223
+ self.block_1 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, num_heads=nheads, head_dim=None, window_size=window_size, shift_size=0)
224
+ self.block_2 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, num_heads=nheads, head_dim=None, window_size=window_size, shift_size=window_size // 2)
225
+ self.guidance_norm = nn.LayerNorm(appearance_guidance_dim) if appearance_guidance_dim > 0 else None
226
+
227
+ def forward(self, x, appearance_guidance):
228
+ """
229
+ Arguments:
230
+ x: B C T H W
231
+ appearance_guidance: B C H W
232
+ """
233
+ B, C, T, H, W = x.shape
234
+ x = rearrange(x, 'B C T H W -> (B T) (H W) C')
235
+ if appearance_guidance is not None:
236
+ appearance_guidance = self.guidance_norm(repeat(appearance_guidance, 'B C H W -> (B T) (H W) C', T=T))
237
+ x = self.block_1(x, appearance_guidance)
238
+ x = self.block_2(x, appearance_guidance)
239
+ x = rearrange(x, '(B T) (H W) C -> B C T H W', B=B, T=T, H=H, W=W)
240
+ return x
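The wrapper folds the class dimension T into the batch, so the two Swin blocks (unshifted, then shifted by half a window) attend only over spatial positions while the appearance guidance is broadcast to every class slice. A rough shape check under assumed sizes:

```python
import torch

wrapper = SwinTransformerBlockWrapper(dim=128, appearance_guidance_dim=128,
                                      input_resolution=(24, 24), nheads=4, window_size=12)
corr = torch.randn(1, 128, 20, 24, 24)   # B C T H W cost-volume embedding (20 candidate classes)
guide = torch.randn(1, 128, 24, 24)      # B C H W projected appearance guidance
print(wrapper(corr, guide).shape)        # torch.Size([1, 128, 20, 24, 24])
```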
241
+
242
+
243
+ def elu_feature_map(x):
244
+ return torch.nn.functional.elu(x) + 1
245
+
246
+
247
+ class LinearAttention(nn.Module):
248
+ def __init__(self, eps=1e-6):
249
+ super().__init__()
250
+ self.feature_map = elu_feature_map
251
+ self.eps = eps
252
+
253
+ def forward(self, queries, keys, values):
254
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
255
+ Args:
256
+ queries: [N, L, H, D]
257
+ keys: [N, S, H, D]
258
+ values: [N, S, H, D]
259
+ q_mask: [N, L]
260
+ kv_mask: [N, S]
261
+ Returns:
262
+ queried_values: (N, L, H, D)
263
+ """
264
+ Q = self.feature_map(queries)
265
+ K = self.feature_map(keys)
266
+
267
+ v_length = values.size(1)
268
+ values = values / v_length # prevent fp16 overflow
269
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
270
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
271
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
272
+
273
+ return queried_values.contiguous()
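This is the softmax-free linear attention from "Transformers are RNNs" cited in the docstring: keys and values are first contracted into a per-head D x D summary, so the cost grows linearly with the key/value length rather than quadratically. A small shape sketch with illustrative sizes:

```python
import torch

attn = LinearAttention()
q = torch.randn(2, 100, 4, 32)    # N, L, H, D
k = torch.randn(2, 576, 4, 32)    # N, S, H, D
v = torch.randn(2, 576, 4, 32)    # N, S, H, D
out = attn(q, k, v)               # K/V are reduced to an (N, H, D, D) summary, linear in S
print(out.shape)                  # torch.Size([2, 100, 4, 32])
```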
274
+
275
+
276
+ class FullAttention(nn.Module):
277
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
278
+ super().__init__()
279
+ self.use_dropout = use_dropout
280
+ self.dropout = nn.Dropout(attention_dropout)
281
+
282
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
283
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
284
+ Args:
285
+ queries: [N, L, H, D]
286
+ keys: [N, S, H, D]
287
+ values: [N, S, H, D]
288
+ q_mask: [N, L]
289
+ kv_mask: [N, S]
290
+ Returns:
291
+ queried_values: (N, L, H, D)
292
+ """
293
+
294
+ # Compute the unnormalized attention and apply the masks
295
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
296
+ if kv_mask is not None:
297
+ QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
298
+
299
+ # Compute the attention and the weighted average
300
+ softmax_temp = 1. / queries.size(3)**.5 # 1 / sqrt(D)
301
+ A = torch.softmax(softmax_temp * QK, dim=2)
302
+ if self.use_dropout:
303
+ A = self.dropout(A)
304
+
305
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
306
+
307
+ return queried_values.contiguous()
308
+
309
+
310
+ class AttentionLayer(nn.Module):
311
+ def __init__(self, hidden_dim, guidance_dim, nheads=8, attention_type='linear'):
312
+ super().__init__()
313
+ self.nheads = nheads
314
+ self.q = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
315
+ self.k = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
316
+ self.v = nn.Linear(hidden_dim, hidden_dim)
317
+
318
+ if attention_type == 'linear':
319
+ self.attention = LinearAttention()
320
+ elif attention_type == 'full':
321
+ self.attention = FullAttention()
322
+ else:
323
+ raise NotImplementedError
324
+
325
+ def forward(self, x, guidance):
326
+ """
327
+ Arguments:
328
+ x: B, L, C
329
+ guidance: B, L, C
330
+ """
331
+ q = self.q(torch.cat([x, guidance], dim=-1)) if guidance is not None else self.q(x)
332
+ k = self.k(torch.cat([x, guidance], dim=-1)) if guidance is not None else self.k(x)
333
+ v = self.v(x)
334
+
335
+ q = rearrange(q, 'B L (H D) -> B L H D', H=self.nheads)
336
+ k = rearrange(k, 'B S (H D) -> B S H D', H=self.nheads)
337
+ v = rearrange(v, 'B S (H D) -> B S H D', H=self.nheads)
338
+
339
+ out = self.attention(q, k, v)
340
+ out = rearrange(out, 'B L H D -> B L (H D)')
341
+ return out
342
+
343
+
344
+ class ClassTransformerLayer(nn.Module):
345
+ def __init__(self, hidden_dim=64, guidance_dim=64, nheads=8, attention_type='linear', pooling_size=(4, 4)) -> None:
346
+ super().__init__()
347
+ self.pool = nn.AvgPool2d(pooling_size)
348
+ self.attention = AttentionLayer(hidden_dim, guidance_dim, nheads=nheads, attention_type=attention_type)
349
+ self.MLP = nn.Sequential(
350
+ nn.Linear(hidden_dim, hidden_dim * 4),
351
+ nn.ReLU(),
352
+ nn.Linear(hidden_dim * 4, hidden_dim)
353
+ )
354
+
355
+ self.norm1 = nn.LayerNorm(hidden_dim)
356
+ self.norm2 = nn.LayerNorm(hidden_dim)
357
+
358
+ def pool_features(self, x):
359
+ """
360
+ Intermediate pooling layer for computational efficiency.
361
+ Arguments:
362
+ x: B, C, T, H, W
363
+ """
364
+ B = x.size(0)
365
+ x = rearrange(x, 'B C T H W -> (B T) C H W')
366
+ x = self.pool(x)
367
+ x = rearrange(x, '(B T) C H W -> B C T H W', B=B)
368
+ return x
369
+
370
+ def forward(self, x, guidance):
371
+ """
372
+ Arguments:
373
+ x: B, C, T, H, W
374
+ guidance: B, T, C
375
+ """
376
+ B, _, _, H, W = x.size()
377
+ x_pool = self.pool_features(x)
378
+ *_, H_pool, W_pool = x_pool.size()
379
+
380
+ x_pool = rearrange(x_pool, 'B C T H W -> (B H W) T C')
381
+ if guidance is not None:
382
+ guidance = repeat(guidance, 'B T C -> (B H W) T C', H=H_pool, W=W_pool)
383
+
384
+ x_pool = x_pool + self.attention(self.norm1(x_pool), guidance) # Attention
385
+ x_pool = x_pool + self.MLP(self.norm2(x_pool)) # MLP
386
+
387
+ x_pool = rearrange(x_pool, '(B H W) T C -> (B T) C H W', H=H_pool, W=W_pool)
388
+ x_pool = F.interpolate(x_pool, size=(H, W), mode='bilinear', align_corners=True)
389
+ x_pool = rearrange(x_pool, '(B T) C H W -> B C T H W', B=B)
390
+
391
+ x = x + x_pool # Residual
392
+ return x
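Here the spatial positions are average-pooled and folded into the batch, so attention mixes information across the T candidate classes at each pooled location, guided by the projected text embeddings; the result is upsampled back and added as a residual. A quick shape check (guidance_dim=0 is assumed here only to keep the sketch small):

```python
import torch

layer = ClassTransformerLayer(hidden_dim=64, guidance_dim=0, nheads=4,
                              attention_type='linear', pooling_size=(2, 2))
x = torch.randn(1, 64, 10, 24, 24)       # B, C, T, H, W with 10 candidate classes
out = layer(x, guidance=None)            # attention runs over T at each pooled location
print(out.shape)                         # torch.Size([1, 64, 10, 24, 24])
```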
393
+
394
+
395
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
396
+ """3x3 convolution with padding"""
397
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
398
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
399
+
400
+
401
+ def conv1x1(in_planes, out_planes, stride=1):
402
+ """1x1 convolution"""
403
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
404
+
405
+
406
+ class Bottleneck(nn.Module):
407
+ expansion = 4
408
+ __constants__ = ['downsample']
409
+
410
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
411
+ base_width=64, dilation=1, norm_layer=None):
412
+ super(Bottleneck, self).__init__()
413
+ if norm_layer is None:
414
+ norm_layer = nn.BatchNorm2d
415
+ width = int(planes * (base_width / 64.)) * groups
416
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
417
+ self.conv1 = conv1x1(inplanes, width)
418
+ self.bn1 = norm_layer(width)
419
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
420
+ self.bn2 = norm_layer(width)
421
+ self.conv3 = conv1x1(width, planes * self.expansion)
422
+ self.bn3 = norm_layer(planes * self.expansion)
423
+ self.relu = nn.ReLU(inplace=True)
424
+ self.downsample = downsample
425
+ self.stride = stride
426
+
427
+ def forward(self, x):
428
+ identity = x
429
+
430
+ out = self.conv1(x)
431
+ out = self.bn1(out)
432
+ out = self.relu(out)
433
+
434
+ out = self.conv2(out)
435
+ out = self.bn2(out)
436
+ out = self.relu(out)
437
+
438
+ out = self.conv3(out)
439
+ out = self.bn3(out)
440
+
441
+ if self.downsample is not None:
442
+ identity = self.downsample(x)
443
+
444
+ out += identity
445
+ out = self.relu(out)
446
+
447
+ return out
448
+
449
+
450
+ class AggregatorLayer(nn.Module):
451
+ def __init__(self, hidden_dim=64, text_guidance_dim=512, appearance_guidance=512, nheads=4, input_resolution=(20, 20), pooling_size=(5, 5), window_size=(10, 10), attention_type='linear') -> None:
452
+ super().__init__()
453
+ self.swin_block = SwinTransformerBlockWrapper(hidden_dim, appearance_guidance, input_resolution, nheads, window_size)
454
+ self.attention = ClassTransformerLayer(hidden_dim, text_guidance_dim, nheads=nheads, attention_type=attention_type, pooling_size=pooling_size)
455
+
456
+
457
+ def forward(self, x, appearance_guidance, text_guidance):
458
+ """
459
+ Arguments:
460
+ x: B C T H W
461
+ """
462
+ x = self.swin_block(x, appearance_guidance)
463
+ x = self.attention(x, text_guidance)
464
+ return x
465
+
466
+
467
+ class AggregatorResNetLayer(nn.Module):
468
+ def __init__(self, hidden_dim=64, appearance_guidance=512) -> None:
469
+ super().__init__()
470
+ self.conv_linear = nn.Conv2d(hidden_dim + appearance_guidance, hidden_dim, kernel_size=1, stride=1)
471
+ self.conv_layer = Bottleneck(hidden_dim, hidden_dim // 4)
472
+
473
+
474
+ def forward(self, x, appearance_guidance):
475
+ """
476
+ Arguments:
477
+ x: B C T H W
478
+ """
479
+ B, T = x.size(0), x.size(2)
480
+ x = rearrange(x, 'B C T H W -> (B T) C H W')
481
+ appearance_guidance = repeat(appearance_guidance, 'B C H W -> (B T) C H W', T=T)
482
+
483
+ x = self.conv_linear(torch.cat([x, appearance_guidance], dim=1))
484
+ x = self.conv_layer(x)
485
+ x = rearrange(x, '(B T) C H W -> B C T H W', B=B)
486
+ return x
487
+
488
+
489
+ class DoubleConv(nn.Module):
490
+ """(convolution => [GN] => ReLU) * 2"""
491
+
492
+ def __init__(self, in_channels, out_channels, mid_channels=None):
493
+ super().__init__()
494
+ if not mid_channels:
495
+ mid_channels = out_channels
496
+ self.double_conv = nn.Sequential(
497
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
498
+ nn.GroupNorm(mid_channels // 16, mid_channels),
499
+ nn.ReLU(inplace=True),
500
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
501
+ nn.GroupNorm(mid_channels // 16, mid_channels),
502
+ nn.ReLU(inplace=True)
503
+ )
504
+
505
+ def forward(self, x):
506
+ return self.double_conv(x)
507
+
508
+
509
+ class Up(nn.Module):
510
+ """Upscaling then double conv"""
511
+
512
+ def __init__(self, in_channels, out_channels, guidance_channels):
513
+ super().__init__()
514
+
515
+ self.up = nn.ConvTranspose2d(in_channels, in_channels - guidance_channels, kernel_size=2, stride=2)
516
+ self.conv = DoubleConv(in_channels, out_channels)
517
+
518
+ def forward(self, x, guidance=None):
519
+ x = self.up(x)
520
+ if guidance is not None:
521
+ T = x.size(0) // guidance.size(0)
522
+ guidance = repeat(guidance, "B C H W -> (B T) C H W", T=T)
523
+ x = torch.cat([x, guidance], dim=1)
524
+ return self.conv(x)
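Each `Up` stage doubles the spatial resolution with a transposed convolution, concatenates the (projected) decoder guidance, and refines with `DoubleConv`. A shape sketch under assumed channel counts, where the leading B*T dimension comes from folding classes into the batch:

```python
import torch

up = Up(in_channels=128, out_channels=64, guidance_channels=32)
x = torch.randn(8, 128, 24, 24)          # (B*T) C H W, classes folded into the batch
guide = torch.randn(2, 32, 48, 48)       # B C 2H 2W projected decoder guidance
print(up(x, guide).shape)                # torch.Size([8, 64, 48, 48])
```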
525
+
526
+
527
+ class Aggregator(nn.Module):
528
+ def __init__(self,
529
+ text_guidance_dim=512,
530
+ text_guidance_proj_dim=128,
531
+ appearance_guidance_dim=512,
532
+ appearance_guidance_proj_dim=128,
533
+ decoder_dims = (64, 32),
534
+ decoder_guidance_dims=(256, 128),
535
+ decoder_guidance_proj_dims=(32, 16),
536
+ num_layers=4,
537
+ nheads=4,
538
+ hidden_dim=128,
539
+ pooling_size=(6, 6),
540
+ feature_resolution=(24, 24),
541
+ window_size=12,
542
+ attention_type='linear',
543
+ prompt_channel=80,
544
+ ) -> None:
545
+ super().__init__()
546
+ self.num_layers = num_layers
547
+ self.hidden_dim = hidden_dim
548
+
549
+ self.layers = nn.ModuleList([
550
+ AggregatorLayer(
551
+ hidden_dim=hidden_dim, text_guidance_dim=text_guidance_proj_dim, appearance_guidance=appearance_guidance_proj_dim,
552
+ nheads=nheads, input_resolution=feature_resolution, pooling_size=pooling_size, window_size=window_size, attention_type=attention_type
553
+ ) for _ in range(num_layers)
554
+ ])
555
+
556
+ self.conv1 = nn.Conv2d(prompt_channel, hidden_dim, kernel_size=7, stride=1, padding=3)
557
+
558
+ self.guidance_projection = nn.Sequential(
559
+ nn.Conv2d(appearance_guidance_dim, appearance_guidance_proj_dim, kernel_size=3, stride=1, padding=1),
560
+ nn.ReLU(),
561
+ ) if appearance_guidance_dim > 0 else None
562
+
563
+ self.text_guidance_projection = nn.Sequential(
564
+ nn.Linear(text_guidance_dim, text_guidance_proj_dim),
565
+ nn.ReLU(),
566
+ ) if text_guidance_dim > 0 else None
567
+
568
+ self.decoder_guidance_projection = nn.ModuleList([
569
+ nn.Sequential(
570
+ nn.Conv2d(d, dp, kernel_size=3, stride=1, padding=1),
571
+ nn.ReLU(),
572
+ ) for d, dp in zip(decoder_guidance_dims, decoder_guidance_proj_dims)
573
+ ]) if decoder_guidance_dims[0] > 0 else None
574
+
575
+ self.decoder1 = Up(hidden_dim, decoder_dims[0], decoder_guidance_proj_dims[0])
576
+ self.decoder2 = Up(decoder_dims[0], decoder_dims[1], decoder_guidance_proj_dims[1])
577
+ self.head = nn.Conv2d(decoder_dims[1], 1, kernel_size=3, stride=1, padding=1)
578
+
579
+ def feature_map(self, img_feats, text_feats):
580
+ img_feats = F.normalize(img_feats, dim=1) # B C H W
581
+ img_feats = repeat(img_feats, "B C H W -> B C T H W", T=text_feats.shape[1])
582
+ text_feats = F.normalize(text_feats, dim=-1) # B T P C
583
+ text_feats = text_feats.mean(dim=-2)
584
+ text_feats = F.normalize(text_feats, dim=-1) # B T C
585
+ text_feats = repeat(text_feats, "B T C -> B C T H W", H=img_feats.shape[-2], W=img_feats.shape[-1])
586
+ return torch.cat((img_feats, text_feats), dim=1) # B 2C T H W
587
+
588
+ def correlation(self, img_feats, text_feats):
589
+ img_feats = F.normalize(img_feats, dim=1) # B C H W
590
+ text_feats = F.normalize(text_feats, dim=-1) # B T P C
591
+ corr = torch.einsum('bchw, btpc -> bpthw', img_feats, text_feats)
592
+ return corr
593
+
594
+ def corr_embed(self, x):
595
+ B = x.shape[0]
596
+ corr_embed = rearrange(x, 'B P T H W -> (B T) P H W')
597
+ corr_embed = self.conv1(corr_embed)
598
+ corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
599
+ return corr_embed
600
+
601
+ def corr_projection(self, x, proj):
602
+ corr_embed = rearrange(x, 'B C T H W -> B T H W C')
603
+ corr_embed = proj(corr_embed)
604
+ corr_embed = rearrange(corr_embed, 'B T H W C -> B C T H W')
605
+ return corr_embed
606
+
607
+ def upsample(self, x):
608
+ B = x.shape[0]
609
+ corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
610
+ corr_embed = F.interpolate(corr_embed, scale_factor=2, mode='bilinear', align_corners=True)
611
+ corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
612
+ return corr_embed
613
+
614
+ def conv_decoder(self, x, guidance):
615
+ B = x.shape[0]
616
+ corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
617
+ corr_embed = self.decoder1(corr_embed, guidance[0])
618
+ corr_embed = self.decoder2(corr_embed, guidance[1])
619
+ corr_embed = self.head(corr_embed)
620
+ corr_embed = rearrange(corr_embed, '(B T) () H W -> B T H W', B=B)
621
+ return corr_embed
622
+
623
+ def forward(self, img_feats, text_feats, appearance_guidance):
624
+ """
625
+ Arguments:
626
+ img_feats: (B, C, H, W)
627
+ text_feats: (B, T, P, C)
628
+ apperance_guidance: tuple of (B, C, H, W)
629
+ """
630
+ corr = self.correlation(img_feats, text_feats)
631
+ #corr = self.feature_map(img_feats, text_feats)
632
+ corr_embed = self.corr_embed(corr)
633
+
634
+ projected_guidance, projected_text_guidance, projected_decoder_guidance = None, None, [None, None]
635
+ if self.guidance_projection is not None:
636
+ projected_guidance = self.guidance_projection(appearance_guidance[0])
637
+ if self.decoder_guidance_projection is not None:
638
+ projected_decoder_guidance = [proj(g) for proj, g in zip(self.decoder_guidance_projection, appearance_guidance[1:])]
639
+
640
+ if self.text_guidance_projection is not None:
641
+ text_feats = text_feats.mean(dim=-2)
642
+ text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
643
+ projected_text_guidance = self.text_guidance_projection(text_feats)
644
+
645
+ for layer in self.layers:
646
+ corr_embed = layer(corr_embed, projected_guidance, projected_text_guidance)
647
+
648
+ logit = self.conv_decoder(corr_embed, projected_decoder_guidance)
649
+
650
+ return logit
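Putting it together, the Aggregator turns the cosine-similarity cost volume between CLIP image and text features into per-class logits at four times the input feature resolution. A rough end-to-end shape walk-through; the constructor values below are illustrative, not the released configs, and `prompt_channel=1` assumes a single text prompt per class:

```python
import torch

agg = Aggregator(
    text_guidance_dim=512, text_guidance_proj_dim=128,
    appearance_guidance_dim=512, appearance_guidance_proj_dim=128,
    decoder_dims=(64, 32), decoder_guidance_dims=(256, 128), decoder_guidance_proj_dims=(32, 16),
    num_layers=2, nheads=4, hidden_dim=128, pooling_size=(6, 6),
    feature_resolution=(24, 24), window_size=12, attention_type='linear',
    prompt_channel=1,                      # one text prompt per class in this sketch
)
img_feats = torch.randn(1, 512, 24, 24)   # dense CLIP image features
text_feats = torch.randn(1, 20, 1, 512)   # B, T (classes), P (prompts), C
guidance = (torch.randn(1, 512, 24, 24),  # appearance guidance for the aggregator layers
            torch.randn(1, 256, 48, 48),  # decoder guidance, first Up stage
            torch.randn(1, 128, 96, 96))  # decoder guidance, second Up stage
print(agg(img_feats, text_feats, guidance).shape)   # torch.Size([1, 20, 96, 96])
```

This mirrors the call `self.transformer(x, text, vis)` in the predictor's forward above, where `text` is the prompt-ensembled class embedding tensor and `vis` holds the appearance-guidance feature maps.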