toto10 commited on
Commit
dfc0d4a
1 Parent(s): 9f522c4

cc83b6fc637fff12860d62fb420bfbdd9ec6115c2cff879b0a0e4f6eaddc4cd5

Browse files
Files changed (50) hide show
  1. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/__init__.py +2 -0
  2. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/bpe_simple_vocab_16e6.txt.gz +3 -0
  3. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/build.py +117 -0
  4. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/dataset_mappers/__init__.py +1 -0
  5. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/dataset_mappers/coco_unified_new_baseline_dataset_mapper.py +341 -0
  6. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/dataset_mappers/dataset_mapper.py +203 -0
  7. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/dataset_mappers/oneformer_unified_dataset_mapper.py +375 -0
  8. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/__init__.py +7 -0
  9. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/register_ade20k_instance.py +56 -0
  10. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/register_ade20k_panoptic.py +394 -0
  11. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/register_cityscapes_panoptic.py +199 -0
  12. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/register_coco_panoptic2instance.py +44 -0
  13. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/register_coco_panoptic_annos_semseg.py +367 -0
  14. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/tokenizer.py +192 -0
  15. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/demo/colormap.py +170 -0
  16. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/demo/defaults.py +77 -0
  17. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/demo/predictor.py +190 -0
  18. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/demo/visualizer.py +1350 -0
  19. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/__init__.py +3 -0
  20. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/cityscapes_evaluation.py +201 -0
  21. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/coco_evaluator.py +563 -0
  22. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/detection_coco_evaluator.py +723 -0
  23. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/evaluator.py +228 -0
  24. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/instance_evaluation.py +110 -0
  25. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/__init__.py +5 -0
  26. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/backbone/__init__.py +1 -0
  27. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/backbone/dinat.py +324 -0
  28. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/backbone/swin.py +771 -0
  29. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/matcher.py +212 -0
  30. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/meta_arch/__init__.py +1 -0
  31. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/meta_arch/oneformer_head.py +135 -0
  32. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/__init__.py +1 -0
  33. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/fpn.py +312 -0
  34. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/msdeformattn.py +358 -0
  35. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/functions/__init__.py +13 -0
  36. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/functions/ms_deform_attn_func.py +77 -0
  37. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/make.sh +13 -0
  38. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/modules/__init__.py +12 -0
  39. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/modules/ms_deform_attn.py +120 -0
  40. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/setup.py +78 -0
  41. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.cpp +46 -0
  42. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.h +38 -0
  43. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.cu +158 -0
  44. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.h +35 -0
  45. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/cuda/ms_deform_im2col_cuda.cuh +1332 -0
  46. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/ms_deform_attn.h +67 -0
  47. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/vision.cpp +21 -0
  48. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/test.py +92 -0
  49. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/transformer_decoder/__init__.py +2 -0
  50. extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/transformer_decoder/oneformer_transformer_decoder.py +528 -0
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from . import datasets
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/build.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+ import torch.utils.data as torchdata
4
+
5
+ from annotator.oneformer.detectron2.config import configurable
6
+
7
+
8
+ from annotator.oneformer.detectron2.data.common import DatasetFromList, MapDataset
9
+ from annotator.oneformer.detectron2.data.dataset_mapper import DatasetMapper
10
+ from annotator.oneformer.detectron2.data.samplers import (
11
+ InferenceSampler,
12
+ )
13
+ from annotator.oneformer.detectron2.data.build import (
14
+ get_detection_dataset_dicts,
15
+ trivial_batch_collator
16
+ )
17
+ """
18
+ This file contains the default logic to build a dataloader for training or testing.
19
+ """
20
+
21
+ __all__ = [
22
+ "build_detection_test_loader",
23
+ ]
24
+
25
+
26
+ def _test_loader_from_config(cfg, dataset_name, mapper=None):
27
+ """
28
+ Uses the given `dataset_name` argument (instead of the names in cfg), because the
29
+ standard practice is to evaluate each test set individually (not combining them).
30
+ """
31
+ if isinstance(dataset_name, str):
32
+ dataset_name = [dataset_name]
33
+
34
+ dataset = get_detection_dataset_dicts(
35
+ dataset_name,
36
+ filter_empty=False,
37
+ proposal_files=[
38
+ cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name
39
+ ]
40
+ if cfg.MODEL.LOAD_PROPOSALS
41
+ else None,
42
+ )
43
+ if mapper is None:
44
+ mapper = DatasetMapper(cfg, False)
45
+ return {
46
+ "dataset": dataset,
47
+ "mapper": mapper,
48
+ "num_workers": cfg.DATALOADER.NUM_WORKERS,
49
+ "sampler": InferenceSampler(len(dataset))
50
+ if not isinstance(dataset, torchdata.IterableDataset)
51
+ else None,
52
+ }
53
+
54
+
55
+ @configurable(from_config=_test_loader_from_config)
56
+ def build_detection_test_loader(
57
+ dataset: Union[List[Any], torchdata.Dataset],
58
+ *,
59
+ mapper: Callable[[Dict[str, Any]], Any],
60
+ sampler: Optional[torchdata.Sampler] = None,
61
+ batch_size: int = 1,
62
+ num_workers: int = 0,
63
+ collate_fn: Optional[Callable[[List[Any]], Any]] = None,
64
+ ) -> torchdata.DataLoader:
65
+ """
66
+ Similar to `build_detection_train_loader`, with default batch size = 1,
67
+ and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
68
+ to produce the exact set of all samples.
69
+
70
+ Args:
71
+ dataset: a list of dataset dicts,
72
+ or a pytorch dataset (either map-style or iterable). They can be obtained
73
+ by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
74
+ mapper: a callable which takes a sample (dict) from dataset
75
+ and returns the format to be consumed by the model.
76
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
77
+ sampler: a sampler that produces
78
+ indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
79
+ which splits the dataset across all workers. Sampler must be None
80
+ if `dataset` is iterable.
81
+ batch_size: the batch size of the data loader to be created.
82
+ Default to 1 image per worker since this is the standard when reporting
83
+ inference time in papers.
84
+ num_workers: number of parallel data loading workers
85
+ collate_fn: same as the argument of `torch.utils.data.DataLoader`.
86
+ Defaults to do no collation and return a list of data.
87
+
88
+ Returns:
89
+ DataLoader: a torch DataLoader, that loads the given detection
90
+ dataset, with test-time transformation and batching.
91
+
92
+ Examples:
93
+ ::
94
+ data_loader = build_detection_test_loader(
95
+ DatasetRegistry.get("my_test"),
96
+ mapper=DatasetMapper(...))
97
+
98
+ # or, instantiate with a CfgNode:
99
+ data_loader = build_detection_test_loader(cfg, "my_test")
100
+ """
101
+ if isinstance(dataset, list):
102
+ dataset = DatasetFromList(dataset, copy=False)
103
+ if mapper is not None:
104
+ dataset = MapDataset(dataset, mapper)
105
+ if isinstance(dataset, torchdata.IterableDataset):
106
+ assert sampler is None, "sampler must be None if dataset is IterableDataset"
107
+ else:
108
+ if sampler is None:
109
+ sampler = InferenceSampler(len(dataset))
110
+ return torchdata.DataLoader(
111
+ dataset,
112
+ batch_size=batch_size,
113
+ sampler=sampler,
114
+ drop_last=False,
115
+ num_workers=num_workers,
116
+ collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
117
+ )
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/dataset_mappers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/dataset_mappers/coco_unified_new_baseline_dataset_mapper.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/dataset_mappers/coco_panoptic_new_baseline_dataset_mapper.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import copy
7
+ import logging
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from annotator.oneformer.detectron2.data import MetadataCatalog
13
+ from annotator.oneformer.detectron2.config import configurable
14
+ from annotator.oneformer.detectron2.data import detection_utils as utils
15
+ from annotator.oneformer.detectron2.data import transforms as T
16
+ from annotator.oneformer.detectron2.structures import BitMasks, Instances
17
+ from annotator.oneformer.oneformer.utils.box_ops import masks_to_boxes
18
+ from annotator.oneformer.oneformer.data.tokenizer import SimpleTokenizer, Tokenize
19
+
20
+ __all__ = ["COCOUnifiedNewBaselineDatasetMapper"]
21
+
22
+
23
+ def build_transform_gen(cfg, is_train):
24
+ """
25
+ Create a list of default :class:`Augmentation` from config.
26
+ Now it includes resizing and flipping.
27
+ Returns:
28
+ list[Augmentation]
29
+ """
30
+ assert is_train, "Only support training augmentation"
31
+ image_size = cfg.INPUT.IMAGE_SIZE
32
+ min_scale = cfg.INPUT.MIN_SCALE
33
+ max_scale = cfg.INPUT.MAX_SCALE
34
+
35
+ augmentation = []
36
+
37
+ if cfg.INPUT.RANDOM_FLIP != "none":
38
+ augmentation.append(
39
+ T.RandomFlip(
40
+ horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
41
+ vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
42
+ )
43
+ )
44
+
45
+ augmentation.extend([
46
+ T.ResizeScale(
47
+ min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
48
+ ),
49
+ T.FixedSizeCrop(crop_size=(image_size, image_size)),
50
+ ])
51
+
52
+ return augmentation
53
+
54
+
55
+ # This is specifically designed for the COCO dataset.
56
+ class COCOUnifiedNewBaselineDatasetMapper:
57
+ """
58
+ A callable which takes a dataset dict in Detectron2 Dataset format,
59
+ and map it into a format used by OneFormer.
60
+
61
+ This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
62
+
63
+ The callable currently does the following:
64
+
65
+ 1. Read the image from "file_name"
66
+ 2. Applies geometric transforms to the image and annotation
67
+ 3. Find and applies suitable cropping to the image and annotation
68
+ 4. Prepare image and annotation to Tensors
69
+ """
70
+
71
+ @configurable
72
+ def __init__(
73
+ self,
74
+ is_train=True,
75
+ *,
76
+ num_queries,
77
+ tfm_gens,
78
+ meta,
79
+ image_format,
80
+ max_seq_len,
81
+ task_seq_len,
82
+ semantic_prob,
83
+ instance_prob,
84
+ ):
85
+ """
86
+ NOTE: this interface is experimental.
87
+ Args:
88
+ is_train: for training or inference
89
+ augmentations: a list of augmentations or deterministic transforms to apply
90
+ crop_gen: crop augmentation
91
+ tfm_gens: data augmentation
92
+ image_format: an image format supported by :func:`detection_utils.read_image`.
93
+ """
94
+ self.tfm_gens = tfm_gens
95
+ logging.getLogger(__name__).info(
96
+ "[COCOUnifiedNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
97
+ str(self.tfm_gens)
98
+ )
99
+ )
100
+
101
+ self.img_format = image_format
102
+ self.is_train = is_train
103
+ self.meta = meta
104
+ self.ignore_label = self.meta.ignore_label
105
+ self.num_queries = num_queries
106
+
107
+ self.things = []
108
+ for k,v in self.meta.thing_dataset_id_to_contiguous_id.items():
109
+ self.things.append(v)
110
+ self.class_names = self.meta.stuff_classes
111
+ self.text_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=max_seq_len)
112
+ self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)
113
+ self.semantic_prob = semantic_prob
114
+ self.instance_prob = instance_prob
115
+
116
+ @classmethod
117
+ def from_config(cls, cfg, is_train=True):
118
+ # Build augmentation
119
+ tfm_gens = build_transform_gen(cfg, is_train)
120
+ dataset_names = cfg.DATASETS.TRAIN
121
+ meta = MetadataCatalog.get(dataset_names[0])
122
+
123
+ ret = {
124
+ "is_train": is_train,
125
+ "meta": meta,
126
+ "tfm_gens": tfm_gens,
127
+ "image_format": cfg.INPUT.FORMAT,
128
+ "num_queries": cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES - cfg.MODEL.TEXT_ENCODER.N_CTX,
129
+ "task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
130
+ "max_seq_len": cfg.INPUT.MAX_SEQ_LEN,
131
+ "semantic_prob": cfg.INPUT.TASK_PROB.SEMANTIC,
132
+ "instance_prob": cfg.INPUT.TASK_PROB.INSTANCE,
133
+ }
134
+ return ret
135
+
136
+ def _get_semantic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
137
+ instances = Instances(image_shape)
138
+
139
+ classes = []
140
+ texts = ["a semantic photo"] * self.num_queries
141
+ masks = []
142
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
143
+
144
+ for segment_info in segments_info:
145
+ class_id = segment_info["category_id"]
146
+ if not segment_info["iscrowd"]:
147
+ mask = pan_seg_gt == segment_info["id"]
148
+ if not np.all(mask == False):
149
+ if class_id not in classes:
150
+ cls_name = self.class_names[class_id]
151
+ classes.append(class_id)
152
+ masks.append(mask)
153
+ num_class_obj[cls_name] += 1
154
+ else:
155
+ idx = classes.index(class_id)
156
+ masks[idx] += mask
157
+ masks[idx] = np.clip(masks[idx], 0, 1).astype(np.bool)
158
+ label[mask] = class_id
159
+
160
+ num = 0
161
+ for i, cls_name in enumerate(self.class_names):
162
+ if num_class_obj[cls_name] > 0:
163
+ for _ in range(num_class_obj[cls_name]):
164
+ if num >= len(texts):
165
+ break
166
+ texts[num] = f"a photo with a {cls_name}"
167
+ num += 1
168
+
169
+ classes = np.array(classes)
170
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
171
+ if len(masks) == 0:
172
+ # Some image does not have annotation (all ignored)
173
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
174
+ instances.gt_bboxes = torch.zeros((0, 4))
175
+ else:
176
+ masks = BitMasks(
177
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
178
+ )
179
+ instances.gt_masks = masks.tensor
180
+ # Placeholder bounding boxes for stuff regions. Note that these are not used during training.
181
+ instances.gt_bboxes = torch.stack([torch.tensor([0., 0., 1., 1.])] * instances.gt_masks.shape[0])
182
+ return instances, texts, label
183
+
184
+ def _get_instance_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
185
+ instances = Instances(image_shape)
186
+
187
+ classes = []
188
+ texts = ["an instance photo"] * self.num_queries
189
+ masks = []
190
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
191
+
192
+ for segment_info in segments_info:
193
+ class_id = segment_info["category_id"]
194
+ if class_id in self.things:
195
+ if not segment_info["iscrowd"]:
196
+ mask = pan_seg_gt == segment_info["id"]
197
+ if not np.all(mask == False):
198
+ cls_name = self.class_names[class_id]
199
+ classes.append(class_id)
200
+ masks.append(mask)
201
+ num_class_obj[cls_name] += 1
202
+ label[mask] = class_id
203
+
204
+ num = 0
205
+ for i, cls_name in enumerate(self.class_names):
206
+ if num_class_obj[cls_name] > 0:
207
+ for _ in range(num_class_obj[cls_name]):
208
+ if num >= len(texts):
209
+ break
210
+ texts[num] = f"a photo with a {cls_name}"
211
+ num += 1
212
+
213
+ classes = np.array(classes)
214
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
215
+ if len(masks) == 0:
216
+ # Some image does not have annotation (all ignored)
217
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
218
+ instances.gt_bboxes = torch.zeros((0, 4))
219
+ else:
220
+ masks = BitMasks(
221
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
222
+ )
223
+ instances.gt_masks = masks.tensor
224
+ instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
225
+ return instances, texts, label
226
+
227
+ def _get_panoptic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
228
+ instances = Instances(image_shape)
229
+
230
+ classes = []
231
+ texts = ["a panoptic photo"] * self.num_queries
232
+ masks = []
233
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
234
+
235
+ for segment_info in segments_info:
236
+ class_id = segment_info["category_id"]
237
+ if not segment_info["iscrowd"]:
238
+ mask = pan_seg_gt == segment_info["id"]
239
+ if not np.all(mask == False):
240
+ cls_name = self.class_names[class_id]
241
+ classes.append(class_id)
242
+ masks.append(mask)
243
+ num_class_obj[cls_name] += 1
244
+ label[mask] = class_id
245
+
246
+ num = 0
247
+ for i, cls_name in enumerate(self.class_names):
248
+ if num_class_obj[cls_name] > 0:
249
+ for _ in range(num_class_obj[cls_name]):
250
+ if num >= len(texts):
251
+ break
252
+ texts[num] = f"a photo with a {cls_name}"
253
+ num += 1
254
+
255
+ classes = np.array(classes)
256
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
257
+ if len(masks) == 0:
258
+ # Some image does not have annotation (all ignored)
259
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
260
+ instances.gt_bboxes = torch.zeros((0, 4))
261
+ else:
262
+ masks = BitMasks(
263
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
264
+ )
265
+ instances.gt_masks = masks.tensor
266
+ instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
267
+ for i in range(instances.gt_classes.shape[0]):
268
+ # Placeholder bounding boxes for stuff regions. Note that these are not used during training.
269
+ if instances.gt_classes[i].item() not in self.things:
270
+ instances.gt_bboxes[i] = torch.tensor([0., 0., 1., 1.])
271
+ return instances, texts, label
272
+
273
+ def __call__(self, dataset_dict):
274
+ """
275
+ Args:
276
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
277
+
278
+ Returns:
279
+ dict: a format that builtin models in detectron2 accept
280
+ """
281
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
282
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
283
+ utils.check_image_size(dataset_dict, image)
284
+
285
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
286
+ image_shape = image.shape[:2] # h, w
287
+
288
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
289
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
290
+ # Therefore it's important to use torch.Tensor.
291
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
292
+
293
+ if not self.is_train:
294
+ # USER: Modify this if you want to keep them for some reason.
295
+ dataset_dict.pop("annotations", None)
296
+ return dataset_dict
297
+
298
+ # semantic segmentation
299
+ if "sem_seg_file_name" in dataset_dict:
300
+ # PyTorch transformation not implemented for uint16, so converting it to double first
301
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
302
+ sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
303
+ else:
304
+ sem_seg_gt = None
305
+
306
+ if "pan_seg_file_name" in dataset_dict:
307
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
308
+ segments_info = dataset_dict["segments_info"]
309
+
310
+ # apply the same transformation to panoptic segmentation
311
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
312
+
313
+ from panopticapi.utils import rgb2id
314
+ pan_seg_gt = rgb2id(pan_seg_gt)
315
+
316
+ prob_task = np.random.uniform(0,1.)
317
+
318
+ num_class_obj = {}
319
+
320
+ for name in self.class_names:
321
+ num_class_obj[name] = 0
322
+
323
+ if prob_task < self.semantic_prob:
324
+ task = "The task is semantic"
325
+ instances, text, sem_seg = self._get_semantic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
326
+ elif prob_task < self.instance_prob:
327
+ task = "The task is instance"
328
+ instances, text, sem_seg = self._get_instance_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
329
+ else:
330
+ task = "The task is panoptic"
331
+ instances, text, sem_seg = self._get_panoptic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
332
+
333
+
334
+ dataset_dict["sem_seg"] = torch.from_numpy(sem_seg).long()
335
+ dataset_dict["instances"] = instances
336
+ dataset_dict["orig_shape"] = image_shape
337
+ dataset_dict["task"] = task
338
+ dataset_dict["text"] = text
339
+ dataset_dict["thing_ids"] = self.things
340
+
341
+ return dataset_dict
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/dataset_mappers/dataset_mapper.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/dataset_mapper.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import copy
7
+ import logging
8
+ import numpy as np
9
+ from typing import List, Optional, Union
10
+ import torch
11
+
12
+ from annotator.oneformer.detectron2.config import configurable
13
+
14
+ from annotator.oneformer.detectron2.data import detection_utils as utils
15
+ from annotator.oneformer.detectron2.data import transforms as T
16
+ from annotator.oneformer.oneformer.data.tokenizer import SimpleTokenizer, Tokenize
17
+
18
+ __all__ = ["DatasetMapper"]
19
+
20
+
21
+ class DatasetMapper:
22
+ """
23
+ A callable which takes a dataset dict in Detectron2 Dataset format,
24
+ and map it into a format used by the model.
25
+
26
+ This is the default callable to be used to map your dataset dict into training data.
27
+ You may need to follow it to implement your own one for customized logic,
28
+ such as a different way to read or transform images.
29
+ See :doc:`/tutorials/data_loading` for details.
30
+
31
+ The callable currently does the following:
32
+
33
+ 1. Read the image from "file_name"
34
+ 2. Applies cropping/geometric transforms to the image and annotations
35
+ 3. Prepare data and annotations to Tensor and :class:`Instances`
36
+ """
37
+
38
+ @configurable
39
+ def __init__(
40
+ self,
41
+ is_train: bool,
42
+ *,
43
+ augmentations: List[Union[T.Augmentation, T.Transform]],
44
+ image_format: str,
45
+ task_seq_len: int,
46
+ task: str = "panoptic",
47
+ use_instance_mask: bool = False,
48
+ use_keypoint: bool = False,
49
+ instance_mask_format: str = "polygon",
50
+ keypoint_hflip_indices: Optional[np.ndarray] = None,
51
+ precomputed_proposal_topk: Optional[int] = None,
52
+ recompute_boxes: bool = False,
53
+ ):
54
+ """
55
+ NOTE: this interface is experimental.
56
+
57
+ Args:
58
+ is_train: whether it's used in training or inference
59
+ augmentations: a list of augmentations or deterministic transforms to apply
60
+ image_format: an image format supported by :func:`detection_utils.read_image`.
61
+ use_instance_mask: whether to process instance segmentation annotations, if available
62
+ use_keypoint: whether to process keypoint annotations if available
63
+ instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
64
+ masks into this format.
65
+ keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
66
+ precomputed_proposal_topk: if given, will load pre-computed
67
+ proposals from dataset_dict and keep the top k proposals for each image.
68
+ recompute_boxes: whether to overwrite bounding box annotations
69
+ by computing tight bounding boxes from instance mask annotations.
70
+ """
71
+ if recompute_boxes:
72
+ assert use_instance_mask, "recompute_boxes requires instance masks"
73
+ # fmt: off
74
+ self.is_train = is_train
75
+ self.augmentations = T.AugmentationList(augmentations)
76
+ self.image_format = image_format
77
+ self.use_instance_mask = use_instance_mask
78
+ self.instance_mask_format = instance_mask_format
79
+ self.use_keypoint = use_keypoint
80
+ self.keypoint_hflip_indices = keypoint_hflip_indices
81
+ self.proposal_topk = precomputed_proposal_topk
82
+ self.recompute_boxes = recompute_boxes
83
+ self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)
84
+ self.task = task
85
+ assert self.task in ["panoptic", "semantic", "instance"]
86
+
87
+ # fmt: on
88
+ logger = logging.getLogger(__name__)
89
+ mode = "training" if is_train else "inference"
90
+ logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
91
+
92
+ @classmethod
93
+ def from_config(cls, cfg, is_train: bool = True):
94
+ augs = utils.build_augmentation(cfg, is_train)
95
+ if cfg.INPUT.CROP.ENABLED and is_train:
96
+ augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
97
+ recompute_boxes = cfg.MODEL.MASK_ON
98
+ else:
99
+ recompute_boxes = False
100
+
101
+ ret = {
102
+ "is_train": is_train,
103
+ "augmentations": augs,
104
+ "image_format": cfg.INPUT.FORMAT,
105
+ "use_instance_mask": cfg.MODEL.MASK_ON,
106
+ "instance_mask_format": cfg.INPUT.MASK_FORMAT,
107
+ "use_keypoint": cfg.MODEL.KEYPOINT_ON,
108
+ "task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
109
+ "recompute_boxes": recompute_boxes,
110
+ "task": cfg.MODEL.TEST.TASK,
111
+ }
112
+
113
+ if cfg.MODEL.KEYPOINT_ON:
114
+ ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
115
+
116
+ if cfg.MODEL.LOAD_PROPOSALS:
117
+ ret["precomputed_proposal_topk"] = (
118
+ cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
119
+ if is_train
120
+ else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
121
+ )
122
+ return ret
123
+
124
+ def _transform_annotations(self, dataset_dict, transforms, image_shape):
125
+ # USER: Modify this if you want to keep them for some reason.
126
+ for anno in dataset_dict["annotations"]:
127
+ if not self.use_instance_mask:
128
+ anno.pop("segmentation", None)
129
+ if not self.use_keypoint:
130
+ anno.pop("keypoints", None)
131
+
132
+ # USER: Implement additional transformations if you have other types of data
133
+ annos = [
134
+ utils.transform_instance_annotations(
135
+ obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
136
+ )
137
+ for obj in dataset_dict.pop("annotations")
138
+ if obj.get("iscrowd", 0) == 0
139
+ ]
140
+ instances = utils.annotations_to_instances(
141
+ annos, image_shape, mask_format=self.instance_mask_format
142
+ )
143
+
144
+ # After transforms such as cropping are applied, the bounding box may no longer
145
+ # tightly bound the object. As an example, imagine a triangle object
146
+ # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
147
+ # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
148
+ # the intersection of original bounding box and the cropping box.
149
+ if self.recompute_boxes:
150
+ instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
151
+ dataset_dict["instances"] = utils.filter_empty_instances(instances)
152
+
153
+ def __call__(self, dataset_dict):
154
+ """
155
+ Args:
156
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
157
+
158
+ Returns:
159
+ dict: a format that builtin models in detectron2 accept
160
+ """
161
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
162
+ # USER: Write your own image loading if it's not from a file
163
+ image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
164
+ utils.check_image_size(dataset_dict, image)
165
+
166
+ task = f"The task is {self.task}"
167
+ dataset_dict["task"] = task
168
+
169
+ # USER: Remove if you don't do semantic/panoptic segmentation.
170
+ if "sem_seg_file_name" in dataset_dict:
171
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
172
+ else:
173
+ sem_seg_gt = None
174
+
175
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
176
+ transforms = self.augmentations(aug_input)
177
+ image, sem_seg_gt = aug_input.image, aug_input.sem_seg
178
+
179
+ image_shape = image.shape[:2] # h, w
180
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
181
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
182
+ # Therefore it's important to use torch.Tensor.
183
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
184
+ if sem_seg_gt is not None:
185
+ dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
186
+
187
+ # USER: Remove if you don't use pre-computed proposals.
188
+ # Most users would not need this feature.
189
+ if self.proposal_topk is not None:
190
+ utils.transform_proposals(
191
+ dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
192
+ )
193
+
194
+ if not self.is_train:
195
+ # USER: Modify this if you want to keep them for some reason.
196
+ dataset_dict.pop("annotations", None)
197
+ dataset_dict.pop("sem_seg_file_name", None)
198
+ return dataset_dict
199
+
200
+ if "annotations" in dataset_dict:
201
+ self._transform_annotations(dataset_dict, transforms, image_shape)
202
+
203
+ return dataset_dict
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/dataset_mappers/oneformer_unified_dataset_mapper.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import copy
7
+ import logging
8
+ import os
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch.nn import functional as F
13
+
14
+ from annotator.oneformer.detectron2.config import configurable
15
+ from annotator.oneformer.detectron2.data import detection_utils as utils
16
+ from annotator.oneformer.detectron2.data import transforms as T
17
+ from annotator.oneformer.detectron2.structures import BitMasks, Instances
18
+ from annotator.oneformer.detectron2.data import MetadataCatalog
19
+ from annotator.oneformer.detectron2.projects.point_rend import ColorAugSSDTransform
20
+ from annotator.oneformer.oneformer.utils.box_ops import masks_to_boxes
21
+ from annotator.oneformer.oneformer.data.tokenizer import SimpleTokenizer, Tokenize
22
+
23
+ __all__ = ["OneFormerUnifiedDatasetMapper"]
24
+
25
+
26
+ class OneFormerUnifiedDatasetMapper:
27
+ """
28
+ A callable which takes a dataset dict in Detectron2 Dataset format,
29
+ and map it into a format used by OneFormer for universal segmentation.
30
+
31
+ The callable currently does the following:
32
+
33
+ 1. Read the image from "file_name"
34
+ 2. Applies geometric transforms to the image and annotation
35
+ 3. Find and applies suitable cropping to the image and annotation
36
+ 4. Prepare image and annotation to Tensors
37
+ """
38
+
39
+ @configurable
40
+ def __init__(
41
+ self,
42
+ is_train=True,
43
+ *,
44
+ name,
45
+ num_queries,
46
+ meta,
47
+ augmentations,
48
+ image_format,
49
+ ignore_label,
50
+ size_divisibility,
51
+ task_seq_len,
52
+ max_seq_len,
53
+ semantic_prob,
54
+ instance_prob,
55
+ ):
56
+ """
57
+ NOTE: this interface is experimental.
58
+ Args:
59
+ is_train: for training or inference
60
+ augmentations: a list of augmentations or deterministic transforms to apply
61
+ image_format: an image format supported by :func:`detection_utils.read_image`.
62
+ ignore_label: the label that is ignored to evaluation
63
+ size_divisibility: pad image size to be divisible by this value
64
+ """
65
+ self.is_train = is_train
66
+ self.meta = meta
67
+ self.name = name
68
+ self.tfm_gens = augmentations
69
+ self.img_format = image_format
70
+ self.ignore_label = ignore_label
71
+ self.size_divisibility = size_divisibility
72
+ self.num_queries = num_queries
73
+
74
+ logger = logging.getLogger(__name__)
75
+ mode = "training" if is_train else "inference"
76
+ logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}")
77
+
78
+ self.things = []
79
+ for k,v in self.meta.thing_dataset_id_to_contiguous_id.items():
80
+ self.things.append(v)
81
+ self.class_names = self.meta.stuff_classes
82
+ self.text_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=max_seq_len)
83
+ self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)
84
+ self.semantic_prob = semantic_prob
85
+ self.instance_prob = instance_prob
86
+
87
+ @classmethod
88
+ def from_config(cls, cfg, is_train=True):
89
+ # Build augmentation
90
+ augs = [
91
+ T.ResizeShortestEdge(
92
+ cfg.INPUT.MIN_SIZE_TRAIN,
93
+ cfg.INPUT.MAX_SIZE_TRAIN,
94
+ cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
95
+ )
96
+ ]
97
+ if cfg.INPUT.CROP.ENABLED:
98
+ augs.append(
99
+ T.RandomCrop_CategoryAreaConstraint(
100
+ cfg.INPUT.CROP.TYPE,
101
+ cfg.INPUT.CROP.SIZE,
102
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
103
+ cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
104
+ )
105
+ )
106
+ if cfg.INPUT.COLOR_AUG_SSD:
107
+ augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
108
+ augs.append(T.RandomFlip())
109
+
110
+ # Assume always applies to the training set.
111
+ dataset_names = cfg.DATASETS.TRAIN
112
+ meta = MetadataCatalog.get(dataset_names[0])
113
+ ignore_label = meta.ignore_label
114
+
115
+ ret = {
116
+ "is_train": is_train,
117
+ "meta": meta,
118
+ "name": dataset_names[0],
119
+ "num_queries": cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES - cfg.MODEL.TEXT_ENCODER.N_CTX,
120
+ "task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
121
+ "max_seq_len": cfg.INPUT.MAX_SEQ_LEN,
122
+ "augmentations": augs,
123
+ "image_format": cfg.INPUT.FORMAT,
124
+ "ignore_label": ignore_label,
125
+ "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY,
126
+ "semantic_prob": cfg.INPUT.TASK_PROB.SEMANTIC,
127
+ "instance_prob": cfg.INPUT.TASK_PROB.INSTANCE,
128
+ }
129
+ return ret
130
+
131
+ def _get_semantic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
132
+ pan_seg_gt = pan_seg_gt.numpy()
133
+ instances = Instances(image_shape)
134
+
135
+ classes = []
136
+ texts = ["a semantic photo"] * self.num_queries
137
+ masks = []
138
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
139
+
140
+ for segment_info in segments_info:
141
+ class_id = segment_info["category_id"]
142
+ if not segment_info["iscrowd"]:
143
+ mask = pan_seg_gt == segment_info["id"]
144
+ if not np.all(mask == False):
145
+ if class_id not in classes:
146
+ cls_name = self.class_names[class_id]
147
+ classes.append(class_id)
148
+ masks.append(mask)
149
+ num_class_obj[cls_name] += 1
150
+ else:
151
+ idx = classes.index(class_id)
152
+ masks[idx] += mask
153
+ masks[idx] = np.clip(masks[idx], 0, 1).astype(np.bool)
154
+ label[mask] = class_id
155
+
156
+ num = 0
157
+ for i, cls_name in enumerate(self.class_names):
158
+ if num_class_obj[cls_name] > 0:
159
+ for _ in range(num_class_obj[cls_name]):
160
+ if num >= len(texts):
161
+ break
162
+ texts[num] = f"a photo with a {cls_name}"
163
+ num += 1
164
+
165
+ classes = np.array(classes)
166
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
167
+ if len(masks) == 0:
168
+ # Some image does not have annotation (all ignored)
169
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
170
+ instances.gt_bboxes = torch.zeros((0, 4))
171
+ else:
172
+ masks = BitMasks(
173
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
174
+ )
175
+ instances.gt_masks = masks.tensor
176
+ # Placeholder bounding boxes for stuff regions. Note that these are not used during training.
177
+ instances.gt_bboxes = torch.stack([torch.tensor([0., 0., 1., 1.])] * instances.gt_masks.shape[0])
178
+ return instances, texts, label
179
+
180
+ def _get_instance_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
181
+ pan_seg_gt = pan_seg_gt.numpy()
182
+ instances = Instances(image_shape)
183
+
184
+ classes = []
185
+ texts = ["an instance photo"] * self.num_queries
186
+ masks = []
187
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
188
+
189
+ for segment_info in segments_info:
190
+ class_id = segment_info["category_id"]
191
+ if class_id in self.things:
192
+ if not segment_info["iscrowd"]:
193
+ mask = pan_seg_gt == segment_info["id"]
194
+ if not np.all(mask == False):
195
+ cls_name = self.class_names[class_id]
196
+ classes.append(class_id)
197
+ masks.append(mask)
198
+ num_class_obj[cls_name] += 1
199
+ label[mask] = class_id
200
+
201
+ num = 0
202
+ for i, cls_name in enumerate(self.class_names):
203
+ if num_class_obj[cls_name] > 0:
204
+ for _ in range(num_class_obj[cls_name]):
205
+ if num >= len(texts):
206
+ break
207
+ texts[num] = f"a photo with a {cls_name}"
208
+ num += 1
209
+
210
+ classes = np.array(classes)
211
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
212
+ if len(masks) == 0:
213
+ # Some image does not have annotation (all ignored)
214
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
215
+ instances.gt_bboxes = torch.zeros((0, 4))
216
+ else:
217
+ masks = BitMasks(
218
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
219
+ )
220
+ instances.gt_masks = masks.tensor
221
+ instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
222
+ return instances, texts, label
223
+
224
+ def _get_panoptic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
225
+ pan_seg_gt = pan_seg_gt.numpy()
226
+ instances = Instances(image_shape)
227
+
228
+ classes = []
229
+ texts = ["a panoptic photo"] * self.num_queries
230
+ masks = []
231
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
232
+
233
+ for segment_info in segments_info:
234
+ class_id = segment_info["category_id"]
235
+ if not segment_info["iscrowd"]:
236
+ mask = pan_seg_gt == segment_info["id"]
237
+ if not np.all(mask == False):
238
+ cls_name = self.class_names[class_id]
239
+ classes.append(class_id)
240
+ masks.append(mask)
241
+ num_class_obj[cls_name] += 1
242
+ label[mask] = class_id
243
+
244
+ num = 0
245
+ for i, cls_name in enumerate(self.class_names):
246
+ if num_class_obj[cls_name] > 0:
247
+ for _ in range(num_class_obj[cls_name]):
248
+ if num >= len(texts):
249
+ break
250
+ texts[num] = f"a photo with a {cls_name}"
251
+ num += 1
252
+
253
+ classes = np.array(classes)
254
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
255
+ if len(masks) == 0:
256
+ # Some image does not have annotation (all ignored)
257
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
258
+ instances.gt_bboxes = torch.zeros((0, 4))
259
+ else:
260
+ masks = BitMasks(
261
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
262
+ )
263
+ instances.gt_masks = masks.tensor
264
+ instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
265
+ for i in range(instances.gt_classes.shape[0]):
266
+ # Placeholder bounding boxes for stuff regions. Note that these are not used during training.
267
+ if instances.gt_classes[i].item() not in self.things:
268
+ instances.gt_bboxes[i] = torch.tensor([0., 0., 1., 1.])
269
+ return instances, texts, label
270
+
271
+ def __call__(self, dataset_dict):
272
+ """
273
+ Args:
274
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
275
+
276
+ Returns:
277
+ dict: a format that builtin models in detectron2 accept
278
+ """
279
+ assert self.is_train, "OneFormerUnifiedDatasetMapper should only be used for training!"
280
+
281
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
282
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
283
+ utils.check_image_size(dataset_dict, image)
284
+
285
+ # semantic segmentation
286
+ if "sem_seg_file_name" in dataset_dict:
287
+ # PyTorch transformation not implemented for uint16, so converting it to double first
288
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
289
+ else:
290
+ sem_seg_gt = None
291
+
292
+ # panoptic segmentation
293
+ if "pan_seg_file_name" in dataset_dict:
294
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
295
+ segments_info = dataset_dict["segments_info"]
296
+ else:
297
+ pan_seg_gt = None
298
+ segments_info = None
299
+
300
+ if pan_seg_gt is None:
301
+ raise ValueError(
302
+ "Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format(
303
+ dataset_dict["file_name"]
304
+ )
305
+ )
306
+
307
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
308
+ aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
309
+ image = aug_input.image
310
+ if sem_seg_gt is not None:
311
+ sem_seg_gt = aug_input.sem_seg
312
+
313
+ # apply the same transformation to panoptic segmentation
314
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
315
+
316
+ from panopticapi.utils import rgb2id
317
+
318
+ pan_seg_gt = rgb2id(pan_seg_gt)
319
+
320
+ # Pad image and segmentation label here!
321
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
322
+ if sem_seg_gt is not None:
323
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
324
+ pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long"))
325
+
326
+ if self.size_divisibility > 0:
327
+ image_size = (image.shape[-2], image.shape[-1])
328
+ padding_size = [
329
+ 0,
330
+ self.size_divisibility - image_size[1],
331
+ 0,
332
+ self.size_divisibility - image_size[0],
333
+ ]
334
+ image = F.pad(image, padding_size, value=128).contiguous()
335
+ if sem_seg_gt is not None:
336
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
337
+ pan_seg_gt = F.pad(
338
+ pan_seg_gt, padding_size, value=0
339
+ ).contiguous() # 0 is the VOID panoptic label
340
+
341
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
342
+
343
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
344
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
345
+ # Therefore it's important to use torch.Tensor.
346
+ dataset_dict["image"] = image
347
+
348
+ if "annotations" in dataset_dict:
349
+ raise ValueError("Pemantic segmentation dataset should not have 'annotations'.")
350
+
351
+ prob_task = np.random.uniform(0,1.)
352
+
353
+ num_class_obj = {}
354
+
355
+ for name in self.class_names:
356
+ num_class_obj[name] = 0
357
+
358
+ if prob_task < self.semantic_prob:
359
+ task = "The task is semantic"
360
+ instances, text, sem_seg = self._get_semantic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
361
+ elif prob_task < self.instance_prob:
362
+ task = "The task is instance"
363
+ instances, text, sem_seg = self._get_instance_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
364
+ else:
365
+ task = "The task is panoptic"
366
+ instances, text, sem_seg = self._get_panoptic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
367
+
368
+ dataset_dict["sem_seg"] = torch.from_numpy(sem_seg).long()
369
+ dataset_dict["instances"] = instances
370
+ dataset_dict["orig_shape"] = image_shape
371
+ dataset_dict["task"] = task
372
+ dataset_dict["text"] = text
373
+ dataset_dict["thing_ids"] = self.things
374
+
375
+ return dataset_dict
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from . import (
2
+ register_ade20k_panoptic,
3
+ register_cityscapes_panoptic,
4
+ register_coco_panoptic_annos_semseg,
5
+ register_ade20k_instance,
6
+ register_coco_panoptic2instance,
7
+ )
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/register_ade20k_instance.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_ade20k_instance.py
3
+ # ------------------------------------------------------------------------------
4
+
5
+ import json
6
+ import logging
7
+ import numpy as np
8
+ import os
9
+ from PIL import Image
10
+
11
+ from annotator.oneformer.detectron2.data import DatasetCatalog, MetadataCatalog
12
+ from annotator.oneformer.detectron2.data.datasets.coco import load_coco_json, register_coco_instances
13
+ from annotator.oneformer.detectron2.utils.file_io import PathManager
14
+
15
+ ADE_CATEGORIES = [{'id': 7, 'name': 'bed'}, {'id': 8, 'name': 'windowpane'}, {'id': 10, 'name': 'cabinet'}, {'id': 12, 'name': 'person'}, {'id': 14, 'name': 'door'}, {'id': 15, 'name': 'table'}, {'id': 18, 'name': 'curtain'}, {'id': 19, 'name': 'chair'}, {'id': 20, 'name': 'car'}, {'id': 22, 'name': 'painting'}, {'id': 23, 'name': 'sofa'}, {'id': 24, 'name': 'shelf'}, {'id': 27, 'name': 'mirror'}, {'id': 30, 'name': 'armchair'}, {'id': 31, 'name': 'seat'}, {'id': 32, 'name': 'fence'}, {'id': 33, 'name': 'desk'}, {'id': 35, 'name': 'wardrobe'}, {'id': 36, 'name': 'lamp'}, {'id': 37, 'name': 'bathtub'}, {'id': 38, 'name': 'railing'}, {'id': 39, 'name': 'cushion'}, {'id': 41, 'name': 'box'}, {'id': 42, 'name': 'column'}, {'id': 43, 'name': 'signboard'}, {'id': 44, 'name': 'chest of drawers'}, {'id': 45, 'name': 'counter'}, {'id': 47, 'name': 'sink'}, {'id': 49, 'name': 'fireplace'}, {'id': 50, 'name': 'refrigerator'}, {'id': 53, 'name': 'stairs'}, {'id': 55, 'name': 'case'}, {'id': 56, 'name': 'pool table'}, {'id': 57, 'name': 'pillow'}, {'id': 58, 'name': 'screen door'}, {'id': 62, 'name': 'bookcase'}, {'id': 64, 'name': 'coffee table'}, {'id': 65, 'name': 'toilet'}, {'id': 66, 'name': 'flower'}, {'id': 67, 'name': 'book'}, {'id': 69, 'name': 'bench'}, {'id': 70, 'name': 'countertop'}, {'id': 71, 'name': 'stove'}, {'id': 72, 'name': 'palm'}, {'id': 73, 'name': 'kitchen island'}, {'id': 74, 'name': 'computer'}, {'id': 75, 'name': 'swivel chair'}, {'id': 76, 'name': 'boat'}, {'id': 78, 'name': 'arcade machine'}, {'id': 80, 'name': 'bus'}, {'id': 81, 'name': 'towel'}, {'id': 82, 'name': 'light'}, {'id': 83, 'name': 'truck'}, {'id': 85, 'name': 'chandelier'}, {'id': 86, 'name': 'awning'}, {'id': 87, 'name': 'streetlight'}, {'id': 88, 'name': 'booth'}, {'id': 89, 'name': 'television receiver'}, {'id': 90, 'name': 'airplane'}, {'id': 92, 'name': 'apparel'}, {'id': 93, 'name': 'pole'}, {'id': 95, 'name': 'bannister'}, {'id': 97, 'name': 'ottoman'}, {'id': 98, 'name': 'bottle'}, {'id': 102, 'name': 'van'}, {'id': 103, 'name': 'ship'}, {'id': 104, 'name': 'fountain'}, {'id': 107, 'name': 'washer'}, {'id': 108, 'name': 'plaything'}, {'id': 110, 'name': 'stool'}, {'id': 111, 'name': 'barrel'}, {'id': 112, 'name': 'basket'}, {'id': 115, 'name': 'bag'}, {'id': 116, 'name': 'minibike'}, {'id': 118, 'name': 'oven'}, {'id': 119, 'name': 'ball'}, {'id': 120, 'name': 'food'}, {'id': 121, 'name': 'step'}, {'id': 123, 'name': 'trade name'}, {'id': 124, 'name': 'microwave'}, {'id': 125, 'name': 'pot'}, {'id': 126, 'name': 'animal'}, {'id': 127, 'name': 'bicycle'}, {'id': 129, 'name': 'dishwasher'}, {'id': 130, 'name': 'screen'}, {'id': 132, 'name': 'sculpture'}, {'id': 133, 'name': 'hood'}, {'id': 134, 'name': 'sconce'}, {'id': 135, 'name': 'vase'}, {'id': 136, 'name': 'traffic light'}, {'id': 137, 'name': 'tray'}, {'id': 138, 'name': 'ashcan'}, {'id': 139, 'name': 'fan'}, {'id': 142, 'name': 'plate'}, {'id': 143, 'name': 'monitor'}, {'id': 144, 'name': 'bulletin board'}, {'id': 146, 'name': 'radiator'}, {'id': 147, 'name': 'glass'}, {'id': 148, 'name': 'clock'}, {'id': 149, 'name': 'flag'}]
16
+
17
+
18
+ _PREDEFINED_SPLITS = {
19
+ # point annotations without masks
20
+ "ade20k_instance_train": (
21
+ "ADEChallengeData2016/images/training",
22
+ "ADEChallengeData2016/ade20k_instance_train.json",
23
+ ),
24
+ "ade20k_instance_val": (
25
+ "ADEChallengeData2016/images/validation",
26
+ "ADEChallengeData2016/ade20k_instance_val.json",
27
+ ),
28
+ }
29
+
30
+
31
+ def _get_ade_instances_meta():
32
+ thing_ids = [k["id"] for k in ADE_CATEGORIES]
33
+ assert len(thing_ids) == 100, len(thing_ids)
34
+ # Mapping from the incontiguous ADE category id to an id in [0, 99]
35
+ thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
36
+ thing_classes = [k["name"] for k in ADE_CATEGORIES]
37
+ ret = {
38
+ "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
39
+ "thing_classes": thing_classes,
40
+ }
41
+ return ret
42
+
43
+
44
+ def register_all_ade20k_instance(root):
45
+ for key, (image_root, json_file) in _PREDEFINED_SPLITS.items():
46
+ # Assume pre-defined datasets live in `./datasets`.
47
+ register_coco_instances(
48
+ key,
49
+ _get_ade_instances_meta(),
50
+ os.path.join(root, json_file) if "://" not in json_file else json_file,
51
+ os.path.join(root, image_root),
52
+ )
53
+
54
+
55
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
56
+ register_all_ade20k_instance(_root)
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/register_ade20k_panoptic.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_ade20k_panoptic.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import json
7
+ import os
8
+
9
+ from annotator.oneformer.detectron2.data import DatasetCatalog, MetadataCatalog
10
+ from annotator.oneformer.detectron2.utils.file_io import PathManager
11
+
12
+ ADE20K_150_CATEGORIES = [
13
+ {"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"},
14
+ {"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"},
15
+ {"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"},
16
+ {"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"},
17
+ {"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"},
18
+ {"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"},
19
+ {"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"},
20
+ {"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"},
21
+ {"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "},
22
+ {"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"},
23
+ {"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"},
24
+ {"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"},
25
+ {"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"},
26
+ {"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"},
27
+ {"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"},
28
+ {"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"},
29
+ {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"},
30
+ {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"},
31
+ {"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"},
32
+ {"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"},
33
+ {"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"},
34
+ {"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"},
35
+ {"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"},
36
+ {"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"},
37
+ {"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"},
38
+ {"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"},
39
+ {"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"},
40
+ {"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"},
41
+ {"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"},
42
+ {"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"},
43
+ {"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"},
44
+ {"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"},
45
+ {"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"},
46
+ {"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"},
47
+ {"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"},
48
+ {"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"},
49
+ {"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"},
50
+ {"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"},
51
+ {"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"},
52
+ {"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"},
53
+ {"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"},
54
+ {"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"},
55
+ {"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"},
56
+ {"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"},
57
+ {
58
+ "color": [6, 51, 255],
59
+ "id": 44,
60
+ "isthing": 1,
61
+ "name": "chest of drawers, chest, bureau, dresser",
62
+ },
63
+ {"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"},
64
+ {"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"},
65
+ {"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"},
66
+ {"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"},
67
+ {"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"},
68
+ {"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"},
69
+ {"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"},
70
+ {"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"},
71
+ {"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"},
72
+ {"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"},
73
+ {"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"},
74
+ {
75
+ "color": [255, 71, 0],
76
+ "id": 56,
77
+ "isthing": 1,
78
+ "name": "pool table, billiard table, snooker table",
79
+ },
80
+ {"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"},
81
+ {"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"},
82
+ {"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"},
83
+ {"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"},
84
+ {"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"},
85
+ {"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"},
86
+ {"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"},
87
+ {"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"},
88
+ {
89
+ "color": [0, 255, 133],
90
+ "id": 65,
91
+ "isthing": 1,
92
+ "name": "toilet, can, commode, crapper, pot, potty, stool, throne",
93
+ },
94
+ {"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"},
95
+ {"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"},
96
+ {"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"},
97
+ {"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"},
98
+ {"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"},
99
+ {"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"},
100
+ {"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"},
101
+ {"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"},
102
+ {"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"},
103
+ {"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"},
104
+ {"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"},
105
+ {"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"},
106
+ {"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"},
107
+ {"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"},
108
+ {"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"},
109
+ {"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"},
110
+ {"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"},
111
+ {"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"},
112
+ {"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"},
113
+ {"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"},
114
+ {"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"},
115
+ {"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"},
116
+ {"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"},
117
+ {"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"},
118
+ {"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"},
119
+ {"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"},
120
+ {"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"},
121
+ {"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"},
122
+ {"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"},
123
+ {
124
+ "color": [0, 122, 255],
125
+ "id": 95,
126
+ "isthing": 1,
127
+ "name": "bannister, banister, balustrade, balusters, handrail",
128
+ },
129
+ {
130
+ "color": [0, 255, 163],
131
+ "id": 96,
132
+ "isthing": 0,
133
+ "name": "escalator, moving staircase, moving stairway",
134
+ },
135
+ {
136
+ "color": [255, 153, 0],
137
+ "id": 97,
138
+ "isthing": 1,
139
+ "name": "ottoman, pouf, pouffe, puff, hassock",
140
+ },
141
+ {"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"},
142
+ {"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"},
143
+ {
144
+ "color": [143, 255, 0],
145
+ "id": 100,
146
+ "isthing": 0,
147
+ "name": "poster, posting, placard, notice, bill, card",
148
+ },
149
+ {"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"},
150
+ {"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"},
151
+ {"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"},
152
+ {"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"},
153
+ {
154
+ "color": [133, 0, 255],
155
+ "id": 105,
156
+ "isthing": 0,
157
+ "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
158
+ },
159
+ {"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"},
160
+ {
161
+ "color": [184, 0, 255],
162
+ "id": 107,
163
+ "isthing": 1,
164
+ "name": "washer, automatic washer, washing machine",
165
+ },
166
+ {"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"},
167
+ {"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"},
168
+ {"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"},
169
+ {"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"},
170
+ {"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"},
171
+ {"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"},
172
+ {"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"},
173
+ {"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"},
174
+ {"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"},
175
+ {"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"},
176
+ {"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"},
177
+ {"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"},
178
+ {"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"},
179
+ {"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"},
180
+ {"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"},
181
+ {"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"},
182
+ {"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"},
183
+ {"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"},
184
+ {"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"},
185
+ {"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"},
186
+ {"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"},
187
+ {"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"},
188
+ {"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"},
189
+ {"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"},
190
+ {"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"},
191
+ {"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"},
192
+ {"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"},
193
+ {"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"},
194
+ {"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"},
195
+ {"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"},
196
+ {"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"},
197
+ {"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"},
198
+ {"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"},
199
+ {"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"},
200
+ {"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"},
201
+ {"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"},
202
+ {"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"},
203
+ {"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"},
204
+ {"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"},
205
+ {"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"},
206
+ {"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"},
207
+ {"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"},
208
+ ]
209
+
210
+ ADE20k_COLORS = [k["color"] for k in ADE20K_150_CATEGORIES]
211
+
212
+ MetadataCatalog.get("ade20k_sem_seg_train").set(
213
+ stuff_colors=ADE20k_COLORS[:],
214
+ )
215
+
216
+ MetadataCatalog.get("ade20k_sem_seg_val").set(
217
+ stuff_colors=ADE20k_COLORS[:],
218
+ )
219
+
220
+
221
+ def load_ade20k_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, meta):
222
+ """
223
+ Args:
224
+ image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
225
+ gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
226
+ json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
227
+ Returns:
228
+ list[dict]: a list of dicts in Detectron2 standard format. (See
229
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
230
+ """
231
+
232
+ def _convert_category_id(segment_info, meta):
233
+ if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
234
+ segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
235
+ segment_info["category_id"]
236
+ ]
237
+ segment_info["isthing"] = True
238
+ else:
239
+ segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
240
+ segment_info["category_id"]
241
+ ]
242
+ segment_info["isthing"] = False
243
+ return segment_info
244
+
245
+ with PathManager.open(json_file) as f:
246
+ json_info = json.load(f)
247
+
248
+ ret = []
249
+ for ann in json_info["annotations"]:
250
+ image_id = ann["image_id"]
251
+ # TODO: currently we assume image and label has the same filename but
252
+ # different extension, and images have extension ".jpg" for COCO. Need
253
+ # to make image extension a user-provided argument if we extend this
254
+ # function to support other COCO-like datasets.
255
+ image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
256
+ label_file = os.path.join(gt_dir, ann["file_name"])
257
+ sem_label_file = os.path.join(semseg_dir, ann["file_name"])
258
+ segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
259
+ ret.append(
260
+ {
261
+ "file_name": image_file,
262
+ "image_id": image_id,
263
+ "pan_seg_file_name": label_file,
264
+ "sem_seg_file_name": sem_label_file,
265
+ "segments_info": segments_info,
266
+ }
267
+ )
268
+ assert len(ret), f"No images found in {image_dir}!"
269
+ assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
270
+ assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
271
+ assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"]
272
+ return ret
273
+
274
+
275
+ def register_ade20k_panoptic(
276
+ name, metadata, image_root, panoptic_root, semantic_root, panoptic_json, instances_json=None,
277
+ ):
278
+ """
279
+ Register a "standard" version of ADE20k panoptic segmentation dataset named `name`.
280
+ The dictionaries in this registered dataset follow detectron2's standard format.
281
+ Hence it's called "standard".
282
+ Args:
283
+ name (str): the name that identifies a dataset,
284
+ e.g. "ade20k_panoptic_train"
285
+ metadata (dict): extra metadata associated with this dataset.
286
+ image_root (str): directory which contains all the images
287
+ panoptic_root (str): directory which contains panoptic annotation images in COCO format
288
+ panoptic_json (str): path to the json panoptic annotation file in COCO format
289
+ semantic_root (str): directory which contains the semantic segmentation
290
+ annotation images used for training
291
+ instances_json (str): path to the json instance annotation file
292
+ """
293
+ panoptic_name = name
294
+ DatasetCatalog.register(
295
+ panoptic_name,
296
+ lambda: load_ade20k_panoptic_json(
297
+ panoptic_json, image_root, panoptic_root, semantic_root, metadata
298
+ ),
299
+ )
300
+ MetadataCatalog.get(panoptic_name).set(
301
+ panoptic_root=panoptic_root,
302
+ image_root=image_root,
303
+ panoptic_json=panoptic_json,
304
+ json_file=instances_json,
305
+ evaluator_type="ade20k_panoptic_seg",
306
+ ignore_label=255,
307
+ label_divisor=1000,
308
+ **metadata,
309
+ )
310
+
311
+
312
+ _PREDEFINED_SPLITS_ADE20K_PANOPTIC = {
313
+ "ade20k_panoptic_train": (
314
+ "ADEChallengeData2016/images/training",
315
+ "ADEChallengeData2016/ade20k_panoptic_train",
316
+ "ADEChallengeData2016/ade20k_panoptic_train.json",
317
+ "ADEChallengeData2016/annotations_detectron2/training",
318
+ "ADEChallengeData2016/ade20k_instance_train.json",
319
+ ),
320
+ "ade20k_panoptic_val": (
321
+ "ADEChallengeData2016/images/validation",
322
+ "ADEChallengeData2016/ade20k_panoptic_val",
323
+ "ADEChallengeData2016/ade20k_panoptic_val.json",
324
+ "ADEChallengeData2016/annotations_detectron2/validation",
325
+ "ADEChallengeData2016/ade20k_instance_val.json",
326
+ ),
327
+ }
328
+
329
+
330
+ def get_metadata():
331
+ meta = {}
332
+ # The following metadata maps contiguous id from [0, #thing categories +
333
+ # #stuff categories) to their names and colors. We have two copies of the
334
+ # same name and color under "thing_*" and "stuff_*" because the current
335
+ # visualization function in D2 handles thing and stuff classes differently
336
+ # due to some heuristic used in Panoptic FPN. We keep the same naming to
337
+ # enable reusing existing visualization functions.
338
+ thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES if k["isthing"] == 1]
339
+ thing_colors = [k["color"] for k in ADE20K_150_CATEGORIES if k["isthing"] == 1]
340
+ stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES]
341
+ stuff_colors = [k["color"] for k in ADE20K_150_CATEGORIES]
342
+
343
+ meta["thing_classes"] = thing_classes
344
+ meta["thing_colors"] = thing_colors
345
+ meta["stuff_classes"] = stuff_classes
346
+ meta["stuff_colors"] = stuff_colors
347
+
348
+ # Convert category id for training:
349
+ # category id: like semantic segmentation, it is the class id for each
350
+ # pixel. Since there are some classes not used in evaluation, the category
351
+ # id is not always contiguous and thus we have two sets of category ids:
352
+ # - original category id: category id in the original dataset, mainly
353
+ # used for evaluation.
354
+ # - contiguous category id: [0, #classes), in order to train the linear
355
+ # softmax classifier.
356
+ thing_dataset_id_to_contiguous_id = {}
357
+ stuff_dataset_id_to_contiguous_id = {}
358
+
359
+ for i, cat in enumerate(ADE20K_150_CATEGORIES):
360
+ if cat["isthing"]:
361
+ thing_dataset_id_to_contiguous_id[cat["id"]] = i
362
+ # else:
363
+ # stuff_dataset_id_to_contiguous_id[cat["id"]] = i
364
+
365
+ # in order to use sem_seg evaluator
366
+ stuff_dataset_id_to_contiguous_id[cat["id"]] = i
367
+
368
+ meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
369
+ meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
370
+
371
+ return meta
372
+
373
+
374
+ def register_all_ade20k_panoptic(root):
375
+ metadata = get_metadata()
376
+ for (
377
+ prefix,
378
+ (image_root, panoptic_root, panoptic_json, semantic_root, instance_json),
379
+ ) in _PREDEFINED_SPLITS_ADE20K_PANOPTIC.items():
380
+ # The "standard" version of COCO panoptic segmentation dataset,
381
+ # e.g. used by Panoptic-DeepLab
382
+ register_ade20k_panoptic(
383
+ prefix,
384
+ metadata,
385
+ os.path.join(root, image_root),
386
+ os.path.join(root, panoptic_root),
387
+ os.path.join(root, semantic_root),
388
+ os.path.join(root, panoptic_json),
389
+ os.path.join(root, instance_json),
390
+ )
391
+
392
+
393
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
394
+ register_all_ade20k_panoptic(_root)
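A small sanity sketch of the id bookkeeping performed by get_metadata() above; the concrete numbers assume the category list is exactly ADE20K_150_CATEGORIES as defined in this file:

meta = get_metadata()

# Every one of the 150 categories gets a contiguous "stuff" id equal to its
# index in ADE20K_150_CATEGORIES, while only the 100 "thing" categories also
# appear in the thing mapping (this is what lets the sem_seg evaluator be reused).
assert len(meta["stuff_dataset_id_to_contiguous_id"]) == 150
assert len(meta["thing_dataset_id_to_contiguous_id"]) == 100
assert meta["thing_dataset_id_to_contiguous_id"][7] == 7   # id 7 is "bed", a thing
assert 0 not in meta["thing_dataset_id_to_contiguous_id"]  # id 0 is "wall", stuff only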
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/register_cityscapes_panoptic.py ADDED
@@ -0,0 +1,199 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/cityscapes_panoptic.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import json
7
+ import logging
8
+ import os
9
+
10
+ from annotator.oneformer.detectron2.data import DatasetCatalog, MetadataCatalog
11
+ from annotator.oneformer.detectron2.data.datasets.builtin_meta import CITYSCAPES_CATEGORIES
12
+ from annotator.oneformer.detectron2.utils.file_io import PathManager
13
+
14
+ """
15
+ This file contains functions to register the Cityscapes panoptic dataset to the DatasetCatalog.
16
+ """
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def get_cityscapes_panoptic_files(image_dir, gt_dir, json_info):
23
+ files = []
24
+ # scan through the directory
25
+ cities = PathManager.ls(image_dir)
26
+ logger.info(f"{len(cities)} cities found in '{image_dir}'.")
27
+ image_dict = {}
28
+ for city in cities:
29
+ city_img_dir = os.path.join(image_dir, city)
30
+ for basename in PathManager.ls(city_img_dir):
31
+ image_file = os.path.join(city_img_dir, basename)
32
+
33
+ suffix = "_leftImg8bit.png"
34
+ assert basename.endswith(suffix), basename
35
+ basename = os.path.basename(basename)[: -len(suffix)]
36
+
37
+ image_dict[basename] = image_file
38
+
39
+ for ann in json_info["annotations"]:
40
+ image_file = image_dict.get(ann["image_id"], None)
41
+ assert image_file is not None, "No image {} found for annotation {}".format(
42
+ ann["image_id"], ann["file_name"]
43
+ )
44
+ label_file = os.path.join(gt_dir, ann["file_name"])
45
+ segments_info = ann["segments_info"]
46
+ files.append((image_file, label_file, segments_info))
47
+
48
+ assert len(files), "No images found in {}".format(image_dir)
49
+ assert PathManager.isfile(files[0][0]), files[0][0]
50
+ assert PathManager.isfile(files[0][1]), files[0][1]
51
+ return files
52
+
53
+
54
+ def load_cityscapes_panoptic(image_dir, gt_dir, gt_json, meta):
55
+ """
56
+ Args:
57
+ image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
58
+ gt_dir (str): path to the raw annotations. e.g.,
59
+ "~/cityscapes/gtFine/cityscapes_panoptic_train".
60
+ gt_json (str): path to the json file. e.g.,
61
+ "~/cityscapes/gtFine/cityscapes_panoptic_train.json".
62
+ meta (dict): dictionary containing "thing_dataset_id_to_contiguous_id"
63
+ and "stuff_dataset_id_to_contiguous_id" to map category ids to
64
+ contiguous ids for training.
65
+
66
+ Returns:
67
+ list[dict]: a list of dicts in Detectron2 standard format. (See
68
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
69
+ """
70
+
71
+ def _convert_category_id(segment_info, meta):
72
+ if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
73
+ segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
74
+ segment_info["category_id"]
75
+ ]
76
+ else:
77
+ segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
78
+ segment_info["category_id"]
79
+ ]
80
+ return segment_info
81
+
82
+ assert os.path.exists(
83
+ gt_json
84
+ ), "Please run `python cityscapesscripts/preparation/createPanopticImgs.py` to generate label files." # noqa
85
+
86
+
87
+ with open(gt_json) as f:
88
+ json_info = json.load(f)
89
+
90
+ files = get_cityscapes_panoptic_files(image_dir, gt_dir, json_info)
91
+ ret = []
92
+ for image_file, label_file, segments_info in files:
93
+ sem_label_file = (
94
+ image_file.replace("leftImg8bit", "gtFine").split(".")[0] + "_labelTrainIds.png"
95
+ )
96
+ segments_info = [_convert_category_id(x, meta) for x in segments_info]
97
+ ret.append(
98
+ {
99
+ "file_name": image_file,
100
+ "image_id": "_".join(
101
+ os.path.splitext(os.path.basename(image_file))[0].split("_")[:3]
102
+ ),
103
+ "sem_seg_file_name": sem_label_file,
104
+ "pan_seg_file_name": label_file,
105
+ "segments_info": segments_info,
106
+ }
107
+ )
108
+ assert len(ret), f"No images found in {image_dir}!"
109
+ assert PathManager.isfile(
110
+ ret[0]["sem_seg_file_name"]
111
+ ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py" # noqa
112
+ assert PathManager.isfile(
113
+ ret[0]["pan_seg_file_name"]
114
+ ), "Please generate panoptic annotation with python cityscapesscripts/preparation/createPanopticImgs.py" # noqa
115
+ return ret
116
+
117
+
118
+ _RAW_CITYSCAPES_PANOPTIC_SPLITS = {
119
+ "cityscapes_fine_panoptic_train": (
120
+ "cityscapes/leftImg8bit/train",
121
+ "cityscapes/gtFine/cityscapes_panoptic_train",
122
+ "cityscapes/gtFine/cityscapes_panoptic_train.json",
123
+ ),
124
+ "cityscapes_fine_panoptic_val": (
125
+ "cityscapes/leftImg8bit/val",
126
+ "cityscapes/gtFine/cityscapes_panoptic_val",
127
+ "cityscapes/gtFine/cityscapes_panoptic_val.json",
128
+ ),
129
+ # "cityscapes_fine_panoptic_test": not supported yet
130
+ }
131
+
132
+
133
+ def register_all_cityscapes_panoptic(root):
134
+ meta = {}
135
+ # The following metadata maps contiguous id from [0, #thing categories +
136
+ # #stuff categories) to their names and colors. We have two copies of the
137
+ # same name and color under "thing_*" and "stuff_*" because the current
138
+ # visualization function in D2 handles thing and stuff classes differently
139
+ # due to some heuristic used in Panoptic FPN. We keep the same naming to
140
+ # enable reusing existing visualization functions.
141
+ thing_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
142
+ thing_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
143
+ stuff_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
144
+ stuff_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
145
+
146
+ meta["thing_classes"] = thing_classes
147
+ meta["thing_colors"] = thing_colors
148
+ meta["stuff_classes"] = stuff_classes
149
+ meta["stuff_colors"] = stuff_colors
150
+
151
+ # There are three types of ids in cityscapes panoptic segmentation:
152
+ # (1) category id: like semantic segmentation, it is the class id for each
153
+ # pixel. Since there are some classes not used in evaluation, the category
154
+ # id is not always contiguous and thus we have two sets of category ids:
155
+ # - original category id: category id in the original dataset, mainly
156
+ # used for evaluation.
157
+ # - contiguous category id: [0, #classes), in order to train the classifier
158
+ # (2) instance id: this id is used to differentiate different instances from
159
+ # the same category. For "stuff" classes, the instance id is always 0; for
160
+ # "thing" classes, the instance id starts from 1 and 0 is reserved for
161
+ # ignored instances (e.g. crowd annotation).
162
+ # (3) panoptic id: this is the compact id that encode both category and
163
+ # instance id by: category_id * 1000 + instance_id.
164
+ thing_dataset_id_to_contiguous_id = {}
165
+ stuff_dataset_id_to_contiguous_id = {}
166
+
167
+ for k in CITYSCAPES_CATEGORIES:
168
+ if k["isthing"] == 1:
169
+ thing_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
170
+ else:
171
+ stuff_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
172
+
173
+ meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
174
+ meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
175
+
176
+ for key, (image_dir, gt_dir, gt_json) in _RAW_CITYSCAPES_PANOPTIC_SPLITS.items():
177
+ image_dir = os.path.join(root, image_dir)
178
+ gt_dir = os.path.join(root, gt_dir)
179
+ gt_json = os.path.join(root, gt_json)
180
+
181
+ if key in DatasetCatalog.list():
182
+ DatasetCatalog.remove(key)
183
+
184
+ DatasetCatalog.register(
185
+ key, lambda x=image_dir, y=gt_dir, z=gt_json: load_cityscapes_panoptic(x, y, z, meta)
186
+ )
187
+ MetadataCatalog.get(key).set(
188
+ panoptic_root=gt_dir,
189
+ image_root=image_dir,
190
+ panoptic_json=gt_json,
191
+ gt_dir=gt_dir.replace("cityscapes_panoptic_", ""),
192
+ evaluator_type="cityscapes_panoptic_seg",
193
+ ignore_label=255,
194
+ label_divisor=1000,
195
+ **meta,
196
+ )
197
+
198
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
199
+ register_all_cityscapes_panoptic(_root)
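The comment block in register_all_cityscapes_panoptic describes the compact panoptic id (category_id * 1000 + instance_id, matching label_divisor=1000). A tiny hypothetical helper, not part of this file, makes the convention concrete:

def split_panoptic_id(panoptic_id, label_divisor=1000):
    """Recover (category_id, instance_id) from a compact panoptic id."""
    return panoptic_id // label_divisor, panoptic_id % label_divisor

print(split_panoptic_id(26001))  # (26, 1): instance 1 of Cityscapes category 26 ("car")
print(split_panoptic_id(23000))  # (23, 0): a "stuff" segment, whose instance id is always 0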
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/register_coco_panoptic2instance.py ADDED
@@ -0,0 +1,44 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/builtin.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+
7
+ """
8
+ This file registers pre-defined datasets at hard-coded paths, and their metadata.
9
+
10
+ We hard-code metadata for common datasets. This will enable:
11
+ 1. Consistency check when loading the datasets
12
+ 2. Use models on these standard datasets directly and run demos,
13
+ without having to download the dataset annotations
14
+
15
+ We hard-code some paths to the dataset that's assumed to
16
+ exist in "./datasets/".
17
+
18
+ Users SHOULD NOT use this file to create new datasets / metadata for new datasets.
19
+ To add a new dataset, refer to the tutorial "docs/DATASETS.md".
20
+ """
21
+
22
+ import os
23
+ from annotator.oneformer.detectron2.data.datasets.builtin_meta import _get_builtin_metadata
24
+ from annotator.oneformer.detectron2.data.datasets.coco import register_coco_instances
25
+
26
+
27
+ _PREDEFINED_SPLITS_COCO = {
28
+ "coco_2017_val_panoptic2instance": ("coco/val2017", "coco/annotations/panoptic2instances_val2017.json"),
29
+ }
30
+
31
+
32
+ def register_panoptic2instances_coco(root):
33
+ for key, (image_root, json_file) in _PREDEFINED_SPLITS_COCO.items():
34
+ # Assume pre-defined datasets live in `./datasets`.
35
+ register_coco_instances(
36
+ key,
37
+ _get_builtin_metadata("coco"),
38
+ os.path.join(root, json_file) if "://" not in json_file else json_file,
39
+ os.path.join(root, image_root),
40
+ )
41
+
42
+
43
+ _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
44
+ register_panoptic2instances_coco(_root)
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/datasets/register_coco_panoptic_annos_semseg.py ADDED
@@ -0,0 +1,367 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_coco_panoptic_annos_semseg.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import json
7
+ import os
8
+
9
+ from annotator.oneformer.detectron2.data import DatasetCatalog, MetadataCatalog
10
+ from annotator.oneformer.detectron2.data.datasets import load_sem_seg
11
+ from annotator.oneformer.detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
12
+ from annotator.oneformer.detectron2.utils.file_io import PathManager
13
+ import contextlib
14
+ import logging
15
+ import io
16
+ from fvcore.common.timer import Timer
17
+ import annotator.oneformer.pycocotools.mask as mask_util
18
+ from annotator.oneformer.detectron2.structures import BoxMode
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ _PREDEFINED_SPLITS_COCO_PANOPTIC = {
25
+ "coco_2017_train_panoptic": (
26
+ # This is the original panoptic annotation directory
27
+ "coco/panoptic_train2017",
28
+ "coco/annotations/panoptic_train2017.json",
29
+ # This directory contains semantic annotations that are
30
+ # converted from panoptic annotations.
31
+ # It is used by PanopticFPN.
32
+ # You can use the script at detectron2/datasets/prepare_panoptic_fpn.py
33
+ # to create these directories.
34
+ "coco/panoptic_semseg_train2017",
35
+ ),
36
+ "coco_2017_val_panoptic": (
37
+ "coco/panoptic_val2017",
38
+ "coco/annotations/panoptic_val2017.json",
39
+ "coco/panoptic_semseg_val2017",
40
+ ),
41
+ }
42
+
43
+ def load_coco_instance_json(json_file, image_root, dataset_name=None):
44
+ from annotator.oneformer.pycocotools.coco import COCO
45
+
46
+ timer = Timer()
47
+ json_file = PathManager.get_local_path(json_file)
48
+ with contextlib.redirect_stdout(io.StringIO()):
49
+ coco_api = COCO(json_file)
50
+ if timer.seconds() > 1:
51
+ logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
52
+
53
+ id_map = None
54
+ if dataset_name is not None:
55
+ meta = MetadataCatalog.get(dataset_name)
56
+ cat_ids = sorted(coco_api.getCatIds())
57
+ cats = coco_api.loadCats(cat_ids)
58
+ # The categories in a custom json file may not be sorted.
59
+ thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
60
+ meta.thing_classes = thing_classes
61
+
62
+ # In COCO, certain category ids are artificially removed,
63
+ # and by convention they are always ignored.
64
+ # We deal with COCO's id issue and translate
65
+ # the category ids to contiguous ids in [0, 80).
66
+
67
+ # It works by looking at the "categories" field in the json, therefore
68
+ # if users' own json also have incontiguous ids, we'll
69
+ # apply this mapping as well but print a warning.
70
+ if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
71
+ if "coco" not in dataset_name:
72
+ logger.warning(
73
+ """
74
+ Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
75
+ """
76
+ )
77
+ id_map = {v: i for i, v in enumerate(cat_ids)}
78
+ meta.thing_dataset_id_to_contiguous_id = id_map
79
+
80
+ # sort indices for reproducible results
81
+ img_ids = sorted(coco_api.imgs.keys())
82
+ # imgs is a list of dicts, each looks something like:
83
+ # {'license': 4,
84
+ # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
85
+ # 'file_name': 'COCO_val2014_000000001268.jpg',
86
+ # 'height': 427,
87
+ # 'width': 640,
88
+ # 'date_captured': '2013-11-17 05:57:24',
89
+ # 'id': 1268}
90
+ imgs = coco_api.loadImgs(img_ids)
91
+ # anns is a list[list[dict]], where each dict is an annotation
92
+ # record for an object. The inner list enumerates the objects in an image
93
+ # and the outer list enumerates over images. Example of anns[0]:
94
+ # [{'segmentation': [[192.81,
95
+ # 247.09,
96
+ # ...
97
+ # 219.03,
98
+ # 249.06]],
99
+ # 'area': 1035.749,
100
+ # 'iscrowd': 0,
101
+ # 'image_id': 1268,
102
+ # 'bbox': [192.81, 224.8, 74.73, 33.43],
103
+ # 'category_id': 16,
104
+ # 'id': 42986},
105
+ # ...]
106
+ anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
107
+ total_num_valid_anns = sum([len(x) for x in anns])
108
+ total_num_anns = len(coco_api.anns)
109
+ if total_num_valid_anns < total_num_anns:
110
+ logger.warning(
111
+ f"{json_file} contains {total_num_anns} annotations, but only "
112
+ f"{total_num_valid_anns} of them match to images in the file."
113
+ )
114
+
115
+ if "minival" not in json_file:
116
+ # The popular valminusminival & minival annotations for COCO2014 contain this bug.
117
+ # However the ratio of buggy annotations there is tiny and does not affect accuracy.
118
+ # Therefore we explicitly white-list them.
119
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
120
+ assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
121
+ json_file
122
+ )
123
+
124
+ imgs_anns = list(zip(imgs, anns))
125
+ logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))
126
+
127
+ dataset_dicts = {}
128
+
129
+ ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"]
130
+
131
+ num_instances_without_valid_segmentation = 0
132
+
133
+ for (img_dict, anno_dict_list) in imgs_anns:
134
+ record = {}
135
+ record["file_name"] = os.path.join(image_root, img_dict["file_name"])
136
+ record["height"] = img_dict["height"]
137
+ record["width"] = img_dict["width"]
138
+ image_id = record["image_id"] = img_dict["id"]
139
+
140
+ objs = []
141
+ for anno in anno_dict_list:
142
+ # Check that the image_id in this annotation is the same as
143
+ # the image_id we're looking at.
144
+ # This fails only when the data parsing logic or the annotation file is buggy.
145
+
146
+ # The original COCO valminusminival2014 & minival2014 annotation files
147
+ # actually contains bugs that, together with certain ways of using COCO API,
148
+ # can trigger this assertion.
149
+ assert anno["image_id"] == image_id
150
+
151
+ assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'
152
+
153
+ obj = {key: anno[key] for key in ann_keys if key in anno}
154
+ if "bbox" in obj and len(obj["bbox"]) == 0:
155
+ raise ValueError(
156
+ f"One annotation of image {image_id} contains empty 'bbox' value! "
157
+ "This json does not have valid COCO format."
158
+ )
159
+
160
+ segm = anno.get("segmentation", None)
161
+ if segm: # either list[list[float]] or dict(RLE)
162
+ if isinstance(segm, dict):
163
+ if isinstance(segm["counts"], list):
164
+ # convert to compressed RLE
165
+ segm = mask_util.frPyObjects(segm, *segm["size"])
166
+ else:
167
+ # filter out invalid polygons (< 3 points)
168
+ segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
169
+ if len(segm) == 0:
170
+ num_instances_without_valid_segmentation += 1
171
+ continue # ignore this instance
172
+ obj["segmentation"] = segm
173
+
174
+ keypts = anno.get("keypoints", None)
175
+ if keypts: # list[int]
176
+ for idx, v in enumerate(keypts):
177
+ if idx % 3 != 2:
178
+ # COCO's segmentation coordinates are floating points in [0, H or W],
179
+ # but keypoint coordinates are integers in [0, H-1 or W-1]
180
+ # Therefore we assume the coordinates are "pixel indices" and
181
+ # add 0.5 to convert to floating point coordinates.
182
+ keypts[idx] = v + 0.5
183
+ obj["keypoints"] = keypts
184
+
185
+ obj["bbox_mode"] = BoxMode.XYWH_ABS
186
+ if id_map:
187
+ annotation_category_id = obj["category_id"]
188
+ try:
189
+ obj["category_id"] = id_map[annotation_category_id]
190
+ except KeyError as e:
191
+ raise KeyError(
192
+ f"Encountered category_id={annotation_category_id} "
193
+ "but this id does not exist in 'categories' of the json file."
194
+ ) from e
195
+ objs.append(obj)
196
+ record["annotations"] = objs
197
+ dataset_dicts[image_id] = record
198
+
199
+ if num_instances_without_valid_segmentation > 0:
200
+ logger.warning(
201
+ "Filtered out {} instances without valid segmentation. ".format(
202
+ num_instances_without_valid_segmentation
203
+ )
204
+ + "There might be issues in your dataset generation process. Please "
205
+ "check https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html carefully"
206
+ )
207
+ return dataset_dicts
208
+
209
+ def get_metadata():
210
+ meta = {}
211
+ # The following metadata maps contiguous id from [0, #thing categories +
212
+ # #stuff categories) to their names and colors. We have to replica of the
213
+ # same name and color under "thing_*" and "stuff_*" because the current
214
+ # visualization function in D2 handles thing and class classes differently
215
+ # due to some heuristic used in Panoptic FPN. We keep the same naming to
216
+ # enable reusing existing visualization functions.
217
+ thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
218
+ thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
219
+ stuff_classes = [k["name"] for k in COCO_CATEGORIES]
220
+ stuff_colors = [k["color"] for k in COCO_CATEGORIES]
221
+
222
+ meta["thing_classes"] = thing_classes
223
+ meta["thing_colors"] = thing_colors
224
+ meta["stuff_classes"] = stuff_classes
225
+ meta["stuff_colors"] = stuff_colors
226
+
227
+ # Convert category id for training:
228
+ # category id: like semantic segmentation, it is the class id for each
229
+ # pixel. Since there are some classes not used in evaluation, the category
230
+ # id is not always contiguous and thus we have two set of category ids:
231
+ # - original category id: category id in the original dataset, mainly
232
+ # used for evaluation.
233
+ # - contiguous category id: [0, #classes), in order to train the linear
234
+ # softmax classifier.
235
+ thing_dataset_id_to_contiguous_id = {}
236
+ stuff_dataset_id_to_contiguous_id = {}
237
+
238
+ for i, cat in enumerate(COCO_CATEGORIES):
239
+ if cat["isthing"]:
240
+ thing_dataset_id_to_contiguous_id[cat["id"]] = i
241
+ # else:
242
+ # stuff_dataset_id_to_contiguous_id[cat["id"]] = i
243
+
244
+ # in order to use sem_seg evaluator
245
+ stuff_dataset_id_to_contiguous_id[cat["id"]] = i
246
+
247
+ meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
248
+ meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
249
+
250
+ return meta
251
+
252
+
253
+ def load_coco_panoptic_json(json_file, instances_json, instances_name, image_dir, gt_dir, semseg_dir, meta):
254
+ """
255
+ Args:
256
+ image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
257
+ gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
258
+ json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
259
+ Returns:
260
+ list[dict]: a list of dicts in Detectron2 standard format. (See
261
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
262
+ """
263
+
264
+ def _convert_category_id(segment_info, meta):
265
+ if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
266
+ segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
267
+ segment_info["category_id"]
268
+ ]
269
+ segment_info["isthing"] = True
270
+ else:
271
+ segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
272
+ segment_info["category_id"]
273
+ ]
274
+ segment_info["isthing"] = False
275
+ return segment_info
276
+
277
+ with PathManager.open(json_file) as f:
278
+ json_info = json.load(f)
279
+
280
+ instance_data_dicts = load_coco_instance_json(instances_json, image_dir.replace("panoptic_", ""), instances_name)
281
+
282
+ ret = []
283
+ for ann in json_info["annotations"]:
284
+ image_id = int(ann["image_id"])
285
+ # TODO: currently we assume image and label has the same filename but
286
+ # different extension, and images have extension ".jpg" for COCO. Need
287
+ # to make image extension a user-provided argument if we extend this
288
+ # function to support other COCO-like datasets.
289
+ image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
290
+ label_file = os.path.join(gt_dir, ann["file_name"])
291
+ sem_label_file = os.path.join(semseg_dir, ann["file_name"])
292
+ segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
293
+ ret.append(
294
+ {
295
+ "file_name": image_file,
296
+ "image_id": image_id,
297
+ "pan_seg_file_name": label_file,
298
+ "sem_seg_file_name": sem_label_file,
299
+ "segments_info": segments_info,
300
+ "annotations": instance_data_dicts[image_id]["annotations"],
301
+ }
302
+ )
303
+ assert len(ret), f"No images found in {image_dir}!"
304
+ assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
305
+ assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
306
+ assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"]
307
+ return ret
308
+
309
+
310
+ def register_coco_panoptic_annos_sem_seg(
311
+ name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json, instances_name,
312
+ ):
313
+ panoptic_name = name
314
+ delattr(MetadataCatalog.get(panoptic_name), "thing_classes")
315
+ delattr(MetadataCatalog.get(panoptic_name), "thing_colors")
316
+ MetadataCatalog.get(panoptic_name).set(
317
+ thing_classes=metadata["thing_classes"],
318
+ thing_colors=metadata["thing_colors"],
319
+ # thing_dataset_id_to_contiguous_id=metadata["thing_dataset_id_to_contiguous_id"],
320
+ )
321
+
322
+ # the name is "coco_2017_train_panoptic_with_sem_seg" and "coco_2017_val_panoptic_with_sem_seg"
323
+ semantic_name = name + "_with_sem_seg"
324
+ DatasetCatalog.register(
325
+ semantic_name,
326
+ lambda: load_coco_panoptic_json(panoptic_json, instances_json, instances_name, image_root, panoptic_root, sem_seg_root, metadata),
327
+ )
328
+ MetadataCatalog.get(semantic_name).set(
329
+ sem_seg_root=sem_seg_root,
330
+ panoptic_root=panoptic_root,
331
+ image_root=image_root,
332
+ panoptic_json=panoptic_json,
333
+ json_file=instances_json,
334
+ evaluator_type="coco_panoptic_seg",
335
+ ignore_label=255,
336
+ label_divisor=1000,
337
+ **metadata,
338
+ )
339
+
340
+
341
+ def register_all_coco_panoptic_annos_sem_seg(root):
342
+ for (
343
+ prefix,
344
+ (panoptic_root, panoptic_json, semantic_root),
345
+ ) in _PREDEFINED_SPLITS_COCO_PANOPTIC.items():
346
+
347
+ prefix_instances = prefix[: -len("_panoptic")]
348
+ instances_meta = MetadataCatalog.get(prefix_instances)
349
+ image_root, instances_json = instances_meta.image_root, instances_meta.json_file
350
+
351
+ if 'val' in instances_json:
352
+ instances_json = instances_json.replace('instances_', 'panoptic2instances_')
353
+
354
+ register_coco_panoptic_annos_sem_seg(
355
+ prefix,
356
+ get_metadata(),
357
+ image_root,
358
+ os.path.join(root, panoptic_root),
359
+ os.path.join(root, panoptic_json),
360
+ os.path.join(root, semantic_root),
361
+ instances_json,
362
+ prefix_instances,
363
+ )
364
+
365
+
366
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
367
+ register_all_coco_panoptic_annos_sem_seg(_root)
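Once this module has been imported, each split above is re-registered under a "_with_sem_seg" suffix whose records combine panoptic, semantic and instance supervision. A minimal fetch sketch, assuming the COCO panoptic/semantic files exist under $DETECTRON2_DATASETS:

from annotator.oneformer.detectron2.data import DatasetCatalog

dicts = DatasetCatalog.get("coco_2017_val_panoptic_with_sem_seg")
record = dicts[0]
# Keys produced by load_coco_panoptic_json: file_name, image_id,
# pan_seg_file_name, sem_seg_file_name, segments_info, annotations.
print(sorted(record.keys()))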
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/data/tokenizer.py ADDED
@@ -0,0 +1,192 @@
1
+ # -------------------------------------------------------------------------
2
+ # MIT License
3
+ #
4
+ # Copyright (c) 2021 OpenAI
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ #
24
+ # Modified by Jiarui Xu
25
+ # -------------------------------------------------------------------------
26
+
27
+ import gzip
28
+ import html
29
+ import os
30
+ from functools import lru_cache
31
+
32
+ import ftfy
33
+ import regex as re
34
+ import torch
35
+
36
+
37
+ @lru_cache()
38
+ def default_bpe():
39
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bpe_simple_vocab_16e6.txt.gz')
40
+
41
+ @lru_cache()
42
+ def bytes_to_unicode():
43
+ """Returns list of utf-8 byte and a corresponding list of unicode strings.
44
+
45
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
46
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent
47
+ coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables
48
+ between utf-8 bytes and unicode strings. This also avoids mapping to whitespace/control characters that the bpe code barfs on.
49
+ """
50
+ bs = list(range(ord('!'), ord('~') + 1)) + list(range(ord('¡'), ord('¬') + 1)) + list(range(ord('®'), ord('ÿ') + 1))
51
+ cs = bs[:]
52
+ n = 0
53
+ for b in range(2**8):
54
+ if b not in bs:
55
+ bs.append(b)
56
+ cs.append(2**8 + n)
57
+ n += 1
58
+ cs = [chr(n) for n in cs]
59
+ return dict(zip(bs, cs))
60
+
61
+
62
+ def get_pairs(word):
63
+ """Return set of symbol pairs in a word.
64
+
65
+ Word is represented as tuple of symbols (symbols being variable-length strings).
66
+ """
67
+ pairs = set()
68
+ prev_char = word[0]
69
+ for char in word[1:]:
70
+ pairs.add((prev_char, char))
71
+ prev_char = char
72
+ return pairs
73
+
74
+
75
+ def basic_clean(text):
76
+ text = ftfy.fix_text(text)
77
+ text = html.unescape(html.unescape(text))
78
+ return text.strip()
79
+
80
+
81
+ def whitespace_clean(text):
82
+ text = re.sub(r'\s+', ' ', text)
83
+ text = text.strip()
84
+ return text
85
+
86
+ class Tokenize:
87
+
88
+ def __init__(self, tokenizer, max_seq_len=77, truncate=True):
89
+ self.tokenizer = tokenizer
90
+ self.max_seq_len = max_seq_len
91
+ self.truncate = truncate
92
+
93
+ def __call__(self, texts):
94
+ expanded_dim = False
95
+ if isinstance(texts, str):
96
+ texts = [texts]
97
+ expanded_dim = True
98
+
99
+ sot_token = self.tokenizer.encoder['<|startoftext|>']
100
+ eot_token = self.tokenizer.encoder['<|endoftext|>']
101
+ all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
102
+ result = torch.zeros(len(all_tokens), self.max_seq_len, dtype=torch.long)
103
+
104
+ for i, tokens in enumerate(all_tokens):
105
+ if len(tokens) > self.max_seq_len:
106
+ if self.truncate:
107
+ tokens = tokens[:self.max_seq_len]
108
+ tokens[-1] = eot_token
109
+ else:
110
+ raise RuntimeError(f'Input {texts[i]} is too long for context length {self.max_seq_len}')
111
+ result[i, :len(tokens)] = torch.tensor(tokens)
112
+
113
+ if expanded_dim:
114
+ return result[0]
115
+
116
+ return result
117
+
118
+
119
+ class SimpleTokenizer(object):
120
+
121
+ def __init__(self, bpe_path: str = default_bpe()):
122
+ self.byte_encoder = bytes_to_unicode()
123
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
124
+ merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
125
+ merges = merges[1:49152 - 256 - 2 + 1]
126
+ merges = [tuple(merge.split()) for merge in merges]
127
+ vocab = list(bytes_to_unicode().values())
128
+ vocab = vocab + [v + '</w>' for v in vocab]
129
+ for merge in merges:
130
+ vocab.append(''.join(merge))
131
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
132
+ self.encoder = dict(zip(vocab, range(len(vocab))))
133
+ self.decoder = {v: k for k, v in self.encoder.items()}
134
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
135
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
136
+ self.pat = re.compile(
137
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
138
+ re.IGNORECASE)
139
+
140
+ def bpe(self, token):
141
+ if token in self.cache:
142
+ return self.cache[token]
143
+ word = tuple(token[:-1]) + (token[-1] + '</w>', )
144
+ pairs = get_pairs(word)
145
+
146
+ if not pairs:
147
+ return token + '</w>'
148
+
149
+ while True:
150
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
151
+ if bigram not in self.bpe_ranks:
152
+ break
153
+ first, second = bigram
154
+ new_word = []
155
+ i = 0
156
+ while i < len(word):
157
+ try:
158
+ j = word.index(first, i)
159
+ new_word.extend(word[i:j])
160
+ i = j
161
+ except: # noqa: E722
162
+ new_word.extend(word[i:])
163
+ break
164
+
165
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
166
+ new_word.append(first + second)
167
+ i += 2
168
+ else:
169
+ new_word.append(word[i])
170
+ i += 1
171
+ new_word = tuple(new_word)
172
+ word = new_word
173
+ if len(word) == 1:
174
+ break
175
+ else:
176
+ pairs = get_pairs(word)
177
+ word = ' '.join(word)
178
+ self.cache[token] = word
179
+ return word
180
+
181
+ def encode(self, text):
182
+ bpe_tokens = []
183
+ text = whitespace_clean(basic_clean(text)).lower()
184
+ for token in re.findall(self.pat, text):
185
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
186
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
187
+ return bpe_tokens
188
+
189
+ def decode(self, tokens):
190
+ text = ''.join([self.decoder[token] for token in tokens])
191
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace').replace('</w>', ' ')
192
+ return text
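A short round-trip sketch for the tokenizer above, assuming bpe_simple_vocab_16e6.txt.gz sits next to this file as default_bpe() expects:

tokenizer = SimpleTokenizer()
tokenize = Tokenize(tokenizer, max_seq_len=77, truncate=True)

ids = tokenize("a photo of a cat")            # torch.LongTensor of shape (77,), zero padded
tokens = [t for t in ids.tolist() if t != 0]  # drop the padding
print(tokenizer.decode(tokens[1:-1]))         # strip <|startoftext|>/<|endoftext|> -> "a photo of a cat "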
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/demo/colormap.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ """
4
+ An awesome colormap for really neat visualizations.
5
+ Copied from Detectron, and removed gray colors.
6
+ """
7
+
8
+ import numpy as np
9
+ import random
10
+ random.seed(0)
11
+
12
+ __all__ = ["colormap", "random_color", "random_colors"]
13
+
14
+ # fmt: off
15
+ # RGB:
16
+ # _COLORS = np.array(
17
+ # [
18
+ # 0.000, 0.447, 0.741,
19
+ # 0.850, 0.325, 0.098,
20
+ # 0.929, 0.694, 0.125,
21
+ # 0.494, 0.184, 0.556,
22
+ # 0.466, 0.674, 0.188,
23
+ # 0.301, 0.745, 0.933,
24
+ # 0.635, 0.078, 0.184,
25
+ # 0.300, 0.300, 0.300,
26
+ # 0.600, 0.600, 0.600,
27
+ # 1.000, 0.000, 0.000,
28
+ # 1.000, 0.500, 0.000,
29
+ # 0.749, 0.749, 0.000,
30
+ # 0.000, 1.000, 0.000,
31
+ # 0.000, 0.000, 1.000,
32
+ # 0.667, 0.000, 1.000,
33
+ # 0.333, 0.333, 0.000,
34
+ # 0.333, 0.667, 0.000,
35
+ # 0.333, 1.000, 0.000,
36
+ # 0.667, 0.333, 0.000,
37
+ # 0.667, 0.667, 0.000,
38
+ # 0.667, 1.000, 0.000,
39
+ # 1.000, 0.333, 0.000,
40
+ # 1.000, 0.667, 0.000,
41
+ # 1.000, 1.000, 0.000,
42
+ # 0.000, 0.333, 0.500,
43
+ # 0.000, 0.667, 0.500,
44
+ # 0.000, 1.000, 0.500,
45
+ # 0.333, 0.000, 0.500,
46
+ # 0.333, 0.333, 0.500,
47
+ # 0.333, 0.667, 0.500,
48
+ # 0.333, 1.000, 0.500,
49
+ # 0.667, 0.000, 0.500,
50
+ # 0.667, 0.333, 0.500,
51
+ # 0.667, 0.667, 0.500,
52
+ # 0.667, 1.000, 0.500,
53
+ # 1.000, 0.000, 0.500,
54
+ # 1.000, 0.333, 0.500,
55
+ # 1.000, 0.667, 0.500,
56
+ # 1.000, 1.000, 0.500,
57
+ # 0.000, 0.333, 1.000,
58
+ # 0.000, 0.667, 1.000,
59
+ # 0.000, 1.000, 1.000,
60
+ # 0.333, 0.000, 1.000,
61
+ # 0.333, 0.333, 1.000,
62
+ # 0.333, 0.667, 1.000,
63
+ # 0.333, 1.000, 1.000,
64
+ # 0.667, 0.000, 1.000,
65
+ # 0.667, 0.333, 1.000,
66
+ # 0.667, 0.667, 1.000,
67
+ # 0.667, 1.000, 1.000,
68
+ # 1.000, 0.000, 1.000,
69
+ # 1.000, 0.333, 1.000,
70
+ # 1.000, 0.667, 1.000,
71
+ # 0.333, 0.000, 0.000,
72
+ # 0.500, 0.000, 0.000,
73
+ # 0.667, 0.000, 0.000,
74
+ # 0.833, 0.000, 0.000,
75
+ # 1.000, 0.000, 0.000,
76
+ # 0.000, 0.167, 0.000,
77
+ # 0.000, 0.333, 0.000,
78
+ # 0.000, 0.500, 0.000,
79
+ # 0.000, 0.667, 0.000,
80
+ # 0.000, 0.833, 0.000,
81
+ # 0.000, 1.000, 0.000,
82
+ # 0.000, 0.000, 0.167,
83
+ # 0.000, 0.000, 0.333,
84
+ # 0.000, 0.000, 0.500,
85
+ # 0.000, 0.000, 0.667,
86
+ # 0.000, 0.000, 0.833,
87
+ # 0.000, 0.000, 1.000,
88
+ # 0.000, 0.000, 0.000,
89
+ # 0.143, 0.143, 0.143,
90
+ # 0.857, 0.857, 0.857,
91
+ # 1.000, 1.000, 1.000
92
+ # ]
93
+ # ).astype(np.float32).reshape(-1, 3)
94
+ # fmt: on
95
+
96
+ _COLORS = []
97
+
98
+
99
+ def gen_color():
100
+ color = tuple(np.round(np.random.choice(range(256), size=3)/255, 3))
101
+ if color not in _COLORS and np.mean(color) != 0.0:
102
+ _COLORS.append(color)
103
+ else:
104
+ gen_color()
105
+
106
+
107
+ for _ in range(300):
108
+ gen_color()
109
+
110
+
111
+ def colormap(rgb=False, maximum=255):
112
+ """
113
+ Args:
114
+ rgb (bool): whether to return RGB colors or BGR colors.
115
+ maximum (int): either 255 or 1
116
+ Returns:
117
+ ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
118
+ """
119
+ assert maximum in [255, 1], maximum
120
+ c = _COLORS * maximum
121
+ if not rgb:
122
+ c = c[:, ::-1]
123
+ return c
124
+
125
+
126
+ def random_color(rgb=False, maximum=255):
127
+ """
128
+ Args:
129
+ rgb (bool): whether to return RGB colors or BGR colors.
130
+ maximum (int): either 255 or 1
131
+ Returns:
132
+ ndarray: a vector of 3 numbers
133
+ """
134
+ idx = np.random.randint(0, len(_COLORS))
135
+ ret = _COLORS[idx] * maximum
136
+ if not rgb:
137
+ ret = ret[::-1]
138
+ return ret
139
+
140
+
141
+ def random_colors(N, rgb=False, maximum=255):
142
+ """
143
+ Args:
144
+ N (int): number of unique colors needed
145
+ rgb (bool): whether to return RGB colors or BGR colors.
146
+ maximum (int): either 255 or 1
147
+ Returns:
148
+ ndarray: a list of random_color
149
+ """
150
+ indices = random.sample(range(len(_COLORS)), N)
151
+ ret = [_COLORS[i] * maximum for i in indices]
152
+ if not rgb:
153
+ ret = [x[::-1] for x in ret]
154
+ return ret
155
+
156
+
157
+ if __name__ == "__main__":
158
+ import cv2
159
+
160
+ size = 100
161
+ H, W = 10, 10
162
+ canvas = np.random.rand(H * size, W * size, 3).astype("float32")
163
+ for h in range(H):
164
+ for w in range(W):
165
+ idx = h * W + w
166
+ if idx >= len(_COLORS):
167
+ break
168
+ canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
169
+ cv2.imshow("a", canvas)
170
+ cv2.waitKey(0)
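A minimal usage sketch for the palette helpers above (not part of this commit). The import path is an assumption based on where this commit places the file, and maximum=1 is used because _COLORS is populated with RGB values already scaled to [0, 1], which is also how visualizer.py consumes it.

    # hypothetical usage of the palette helpers; import path is assumed
    from annotator.oneformer.oneformer.demo.colormap import random_color, random_colors

    color = random_color(rgb=True, maximum=1)        # one RGB color in [0, 1]
    palette = random_colors(5, rgb=True, maximum=1)  # five distinct RGB colors in [0, 1]
    assert len(palette) == 5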
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/demo/defaults.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import annotator.oneformer.detectron2.data.transforms as T
3
+ from annotator.oneformer.detectron2.checkpoint import DetectionCheckpointer
4
+ from annotator.oneformer.detectron2.data import (
5
+ MetadataCatalog,
6
+ )
7
+ from annotator.oneformer.detectron2.modeling import build_model
8
+
9
+
10
+ __all__ = [
11
+ "DefaultPredictor",
12
+ ]
13
+
14
+
15
+ class DefaultPredictor:
16
+ """
17
+ Create a simple end-to-end predictor with the given config that runs on
18
+ single device for a single input image.
19
+ Compared to using the model directly, this class does the following additions:
20
+ 1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
21
+ 2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
22
+ 3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
23
+ 4. Take one input image and produce a single output, instead of a batch.
24
+ This is meant for simple demo purposes, so it does the above steps automatically.
25
+ This is not meant for benchmarks or running complicated inference logic.
26
+ If you'd like to do anything more complicated, please refer to its source code as
27
+ examples to build and use the model manually.
28
+ Attributes:
29
+ metadata (Metadata): the metadata of the underlying dataset, obtained from
30
+ cfg.DATASETS.TEST.
31
+ Examples:
32
+ ::
33
+ pred = DefaultPredictor(cfg)
34
+ inputs = cv2.imread("input.jpg")
35
+ outputs = pred(inputs, "panoptic")
36
+ """
37
+
38
+ def __init__(self, cfg):
39
+ self.cfg = cfg.clone() # cfg can be modified by model
40
+ self.model = build_model(self.cfg)
41
+ self.model.eval()
42
+ if len(cfg.DATASETS.TEST):
43
+ self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
44
+
45
+ checkpointer = DetectionCheckpointer(self.model)
46
+ checkpointer.load(cfg.MODEL.WEIGHTS)
47
+
48
+ self.aug = T.ResizeShortestEdge(
49
+ [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
50
+ )
51
+
52
+ self.input_format = cfg.INPUT.FORMAT
53
+ assert self.input_format in ["RGB", "BGR"], self.input_format
54
+
55
+ def __call__(self, original_image, task):
56
+ """
57
+ Args:
58
+ original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
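+ task (str): the task name; the demo passes "panoptic", "semantic", or "instance", and __call__ wraps it into the prompt "The task is {task}" before inference.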
59
+ Returns:
60
+ predictions (dict):
61
+ the output of the model for one image only.
62
+ See :doc:`/tutorials/models` for details about the format.
63
+ """
64
+ with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
65
+ # Apply pre-processing to image.
66
+ if self.input_format == "RGB":
67
+ # whether the model expects BGR inputs or RGB
68
+ original_image = original_image[:, :, ::-1]
69
+ height, width = original_image.shape[:2]
70
+ image = self.aug.get_transform(original_image).apply_image(original_image)
71
+ image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
72
+
73
+ task = f"The task is {task}"
74
+
75
+ inputs = {"image": image, "height": height, "width": width, "task": task}
76
+ predictions = self.model([inputs])[0]
77
+ return predictions
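A hedged usage sketch for the predictor defined above (not part of this commit). `cfg` is assumed to be an already-populated OneFormer config node; only the call pattern is taken from this file.

    # hypothetical usage; `cfg` is an assumed, fully populated OneFormer CfgNode
    import cv2

    predictor = DefaultPredictor(cfg)           # builds the model and loads cfg.MODEL.WEIGHTS
    image_bgr = cv2.imread("input.jpg")         # H x W x 3, BGR, as expected by __call__
    outputs = predictor(image_bgr, "panoptic")  # task is "panoptic", "semantic", or "instance"
    panoptic_seg, segments_info = outputs["panoptic_seg"]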
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/demo/predictor.py ADDED
@@ -0,0 +1,190 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copied from: https://github.com/facebookresearch/detectron2/blob/master/demo/predictor.py
3
+ import atexit
4
+ import bisect
5
+ import multiprocessing as mp
6
+ from collections import deque
7
+
8
+ import cv2
9
+ import torch
10
+
11
+ from annotator.oneformer.detectron2.data import MetadataCatalog
12
+ from .defaults import DefaultPredictor
13
+ from annotator.oneformer.detectron2.utils.video_visualizer import VideoVisualizer
14
+ from .visualizer import ColorMode, Visualizer
15
+
16
+
17
+ class VisualizationDemo(object):
18
+ def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
19
+ """
20
+ Args:
21
+ cfg (CfgNode):
22
+ instance_mode (ColorMode):
23
+ parallel (bool): whether to run the model in different processes from visualization.
24
+ Useful since the visualization logic can be slow.
25
+ """
26
+ self.metadata = MetadataCatalog.get(
27
+ cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
28
+ )
29
+ if 'cityscapes_fine_sem_seg_val' in cfg.DATASETS.TEST[0]:
30
+ from cityscapesscripts.helpers.labels import labels
31
+ stuff_colors = [k.color for k in labels if k.trainId != 255]
32
+ self.metadata = self.metadata.set(stuff_colors=stuff_colors)
33
+ self.cpu_device = torch.device("cpu")
34
+ self.instance_mode = instance_mode
35
+
36
+ self.parallel = parallel
37
+ if parallel:
38
+ num_gpu = torch.cuda.device_count()
39
+ self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
40
+ else:
41
+ self.predictor = DefaultPredictor(cfg)
42
+
43
+ def run_on_image(self, image, task, sem_gt, pan_gt, ins_gt, box_gt):
44
+ """
45
+ Args:
46
+ image (np.ndarray): an image of shape (H, W, C) (in BGR order).
47
+ This is the format used by OpenCV.
48
+ Returns:
49
+ predictions (dict): the output of the model.
50
+ vis_output (VisImage): the visualized image output.
51
+ """
52
+ vis_output = None
53
+ # Convert image from OpenCV BGR format to Matplotlib RGB format.
54
+ image = image[:, :, ::-1]
55
+ vis_output = {}
56
+
57
+ if task == 'panoptic':
58
+ visualizer = Visualizer(image, metadata=self.metadata, instance_mode=0)
59
+ predictions = self.predictor(image, "panoptic")
60
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
61
+ vis_output['panoptic'] = visualizer.draw_panoptic_seg_predictions(
62
+ panoptic_seg.to(self.cpu_device), segments_info, alpha=1
63
+ )
64
+
65
+ # visualizer = Visualizer(image, metadata=self.metadata, instance_mode=0)
66
+ # vis_output['pan_gt'] = visualizer.draw_panoptic_seg(
67
+ # pan_gt[0].to(self.cpu_device), pan_gt[1], alpha=1
68
+ # )
69
+
70
+ if task == 'panoptic' or task == 'semantic':
71
+ visualizer = Visualizer(image, metadata=self.metadata, instance_mode=1)
72
+ predictions = self.predictor(image, "semantic")
73
+ vis_output['semantic'] = visualizer.draw_sem_seg(
74
+ predictions["sem_seg"].argmax(dim=0).to(self.cpu_device), alpha=1
75
+ )
76
+
77
+ # visualizer = Visualizer(image, metadata=self.metadata, instance_mode=1)
78
+ # vis_output['gt_sem'] = visualizer.draw_sem_seg(
79
+ # sem_gt.to(self.cpu_device), alpha=1
80
+ # )
81
+
82
+ if task == 'panoptic' or task == 'instance':
83
+ visualizer = Visualizer(image, metadata=self.metadata, instance_mode=2)
84
+ predictions = self.predictor(image, "instance")
85
+ instances = predictions["instances"].to(self.cpu_device)
86
+ vis_output['instance'] = visualizer.draw_instance_predictions(predictions=instances, alpha=1)
87
+
88
+ if 'boxes' in predictions:
89
+ boxes, labels, scores = predictions["boxes"]
90
+ visualizer = Visualizer(image, False, metadata=self.metadata, instance_mode=0)
91
+ vis_output['boxes'] = visualizer.draw_box_predictions(
92
+ boxes.to(self.cpu_device), labels.to(self.cpu_device), scores.to(self.cpu_device))
93
+
94
+
95
+ # visualizer = Visualizer(image, metadata=self.metadata, instance_mode=2)
96
+ # vis_output['ins_gt'] = visualizer.draw_instance_predictions(predictions=ins_gt.to(self.cpu_device), alpha=1)
97
+ # vis_output['input'] = visualizer.get_image(image)
98
+
99
+ return predictions, vis_output
100
+
101
+
102
+ class AsyncPredictor:
103
+ """
104
+ A predictor that runs the model asynchronously, possibly on >1 GPUs.
105
+ Because rendering the visualization takes considerably amount of time,
106
+ this helps improve throughput a little bit when rendering videos.
107
+ """
108
+
109
+ class _StopToken:
110
+ pass
111
+
112
+ class _PredictWorker(mp.Process):
113
+ def __init__(self, cfg, task_queue, result_queue):
114
+ self.cfg = cfg
115
+ self.task_queue = task_queue
116
+ self.result_queue = result_queue
117
+ super().__init__()
118
+
119
+ def run(self):
120
+ predictor = DefaultPredictor(self.cfg)
121
+
122
+ while True:
123
+ task = self.task_queue.get()
124
+ if isinstance(task, AsyncPredictor._StopToken):
125
+ break
126
+ idx, data = task
127
+ result = predictor(data)
128
+ self.result_queue.put((idx, result))
129
+
130
+ def __init__(self, cfg, num_gpus: int = 1):
131
+ """
132
+ Args:
133
+ cfg (CfgNode):
134
+ num_gpus (int): if 0, will run on CPU
135
+ """
136
+ num_workers = max(num_gpus, 1)
137
+ self.task_queue = mp.Queue(maxsize=num_workers * 3)
138
+ self.result_queue = mp.Queue(maxsize=num_workers * 3)
139
+ self.procs = []
140
+ for gpuid in range(max(num_gpus, 1)):
141
+ cfg = cfg.clone()
142
+ cfg.defrost()
143
+ cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
144
+ self.procs.append(
145
+ AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
146
+ )
147
+
148
+ self.put_idx = 0
149
+ self.get_idx = 0
150
+ self.result_rank = []
151
+ self.result_data = []
152
+
153
+ for p in self.procs:
154
+ p.start()
155
+ atexit.register(self.shutdown)
156
+
157
+ def put(self, image):
158
+ self.put_idx += 1
159
+ self.task_queue.put((self.put_idx, image))
160
+
161
+ def get(self):
162
+ self.get_idx += 1 # the index needed for this request
163
+ if len(self.result_rank) and self.result_rank[0] == self.get_idx:
164
+ res = self.result_data[0]
165
+ del self.result_data[0], self.result_rank[0]
166
+ return res
167
+
168
+ while True:
169
+ # make sure the results are returned in the correct order
170
+ idx, res = self.result_queue.get()
171
+ if idx == self.get_idx:
172
+ return res
173
+ insert = bisect.bisect(self.result_rank, idx)
174
+ self.result_rank.insert(insert, idx)
175
+ self.result_data.insert(insert, res)
176
+
177
+ def __len__(self):
178
+ return self.put_idx - self.get_idx
179
+
180
+ def __call__(self, image):
181
+ self.put(image)
182
+ return self.get()
183
+
184
+ def shutdown(self):
185
+ for _ in self.procs:
186
+ self.task_queue.put(AsyncPredictor._StopToken())
187
+
188
+ @property
189
+ def default_buffer_size(self):
190
+ return len(self.procs) * 5
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/demo/visualizer.py ADDED
@@ -0,0 +1,1350 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import colorsys
3
+ import logging
4
+ import math
5
+ import numpy as np
6
+ from enum import Enum, unique
7
+ import cv2
8
+ import matplotlib as mpl
9
+ import matplotlib.colors as mplc
10
+ import matplotlib.figure as mplfigure
11
+ import annotator.oneformer.pycocotools.mask as mask_util
12
+ import torch
13
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
14
+ from PIL import Image
15
+
16
+ from annotator.oneformer.detectron2.data import MetadataCatalog
17
+ from annotator.oneformer.detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
18
+ from annotator.oneformer.detectron2.utils.file_io import PathManager
19
+ import random
20
+ random.seed(0)
21
+ from .colormap import random_color, _COLORS
22
+ logger = logging.getLogger(__name__)
23
+
24
+ __all__ = ["ColorMode", "VisImage", "Visualizer"]
25
+
26
+
27
+ _SMALL_OBJECT_AREA_THRESH = 1000
28
+ _LARGE_MASK_AREA_THRESH = 120000
29
+ _OFF_WHITE = (1.0, 1.0, 1.0)
30
+ _BLACK = (0, 0, 0)
31
+ _RED = (1.0, 0, 0)
32
+
33
+ _KEYPOINT_THRESHOLD = 0.05
34
+
35
+
36
+ def instance_color(rgb=False, idx=1, maximum=255):
37
+ """
38
+ Args:
39
+ rgb (bool): whether to return RGB colors or BGR colors.
40
+ maximum (int): either 255 or 1
41
+ Returns:
42
+ ndarray: a vector of 3 numbers
43
+ """
44
+ ret = _COLORS[idx] * maximum
45
+ if not rgb:
46
+ ret = ret[::-1]
47
+ return ret
48
+
49
+ @unique
50
+ class ColorMode(Enum):
51
+ """
52
+ Enum of different color modes to use for instance visualizations.
53
+ """
54
+
55
+ IMAGE = 0
56
+ """
57
+ Picks a random color for every instance and overlay segmentations with low opacity.
58
+ """
59
+ SEGMENTATION = 1
60
+ """
61
+ Let instances of the same category have similar colors
62
+ (from metadata.thing_colors), and overlay them with
63
+ high opacity. This provides more attention on the quality of segmentation.
64
+ """
65
+ IMAGE_BW = 2
66
+ """
67
+ Same as IMAGE, but convert all areas without masks to gray-scale.
68
+ Only available for drawing per-instance mask predictions.
69
+ """
70
+
71
+
72
+ class GenericMask:
73
+ """
74
+ Attribute:
75
+ polygons (list[ndarray]): list[ndarray]: polygons for this mask.
76
+ Each ndarray has format [x, y, x, y, ...]
77
+ mask (ndarray): a binary mask
78
+ """
79
+
80
+ def __init__(self, mask_or_polygons, height, width):
81
+ self._mask = self._polygons = self._has_holes = None
82
+ self.height = height
83
+ self.width = width
84
+
85
+ m = mask_or_polygons
86
+ if isinstance(m, dict):
87
+ # RLEs
88
+ assert "counts" in m and "size" in m
89
+ if isinstance(m["counts"], list): # uncompressed RLEs
90
+ h, w = m["size"]
91
+ assert h == height and w == width
92
+ m = mask_util.frPyObjects(m, h, w)
93
+ self._mask = mask_util.decode(m)[:, :]
94
+ return
95
+
96
+ if isinstance(m, list): # list[ndarray]
97
+ self._polygons = [np.asarray(x).reshape(-1) for x in m]
98
+ return
99
+
100
+ if isinstance(m, np.ndarray): # assumed to be a binary mask
101
+ assert m.shape[1] != 2, m.shape
102
+ assert m.shape == (
103
+ height,
104
+ width,
105
+ ), f"mask shape: {m.shape}, target dims: {height}, {width}"
106
+ self._mask = m.astype("uint8")
107
+ return
108
+
109
+ raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m)))
110
+
111
+ @property
112
+ def mask(self):
113
+ if self._mask is None:
114
+ self._mask = self.polygons_to_mask(self._polygons)
115
+ return self._mask
116
+
117
+ @property
118
+ def polygons(self):
119
+ if self._polygons is None:
120
+ self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
121
+ return self._polygons
122
+
123
+ @property
124
+ def has_holes(self):
125
+ if self._has_holes is None:
126
+ if self._mask is not None:
127
+ self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
128
+ else:
129
+ self._has_holes = False # if original format is polygon, does not have holes
130
+ return self._has_holes
131
+
132
+ def mask_to_polygons(self, mask):
133
+ # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
134
+ # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
135
+ # Internal contours (holes) are placed in hierarchy-2.
136
+ # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
137
+ mask = np.ascontiguousarray(mask) # some versions of cv2 do not support non-contiguous arrays
138
+ res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
139
+ hierarchy = res[-1]
140
+ if hierarchy is None: # empty mask
141
+ return [], False
142
+ has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
143
+ res = res[-2]
144
+ res = [x.flatten() for x in res]
145
+ # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
146
+ # We add 0.5 to turn them into real-value coordinate space. A better solution
147
+ # would be to first +0.5 and then dilate the returned polygon by 0.5.
148
+ res = [x + 0.5 for x in res if len(x) >= 6]
149
+ return res, has_holes
150
+
151
+ def polygons_to_mask(self, polygons):
152
+ rle = mask_util.frPyObjects(polygons, self.height, self.width)
153
+ rle = mask_util.merge(rle)
154
+ return mask_util.decode(rle)[:, :]
155
+
156
+ def area(self):
157
+ return self.mask.sum()
158
+
159
+ def bbox(self):
160
+ p = mask_util.frPyObjects(self.polygons, self.height, self.width)
161
+ p = mask_util.merge(p)
162
+ bbox = mask_util.toBbox(p)
163
+ bbox[2] += bbox[0]
164
+ bbox[3] += bbox[1]
165
+ return bbox
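A brief sketch of the three input formats that GenericMask above normalizes (not part of this commit); shapes and coordinates are illustrative only.

    # hypothetical inputs, all describing masks inside a 4x4 image
    import numpy as np

    binary = GenericMask(np.zeros((4, 4), dtype=np.uint8), 4, 4)   # binary ndarray mask
    polygon = GenericMask([[0.5, 0.5, 3.5, 0.5, 3.5, 3.5]], 4, 4)  # list of flat [x, y, ...] polygons
    # rle = GenericMask({"counts": [...], "size": [4, 4]}, 4, 4)   # uncompressed COCO-style RLE
    area = polygon.area()  # rasterizes the polygon lazily via polygons_to_mask()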
166
+
167
+
168
+ class _PanopticPrediction:
169
+ """
170
+ Unify different panoptic annotation/prediction formats
171
+ """
172
+
173
+ def __init__(self, panoptic_seg, segments_info, metadata=None):
174
+ if segments_info is None:
175
+ assert metadata is not None
176
+ # If "segments_info" is None, we assume "panoptic_img" is a
177
+ # H*W int32 image storing the panoptic_id in the format of
178
+ # category_id * label_divisor + instance_id. We reserve -1 for
179
+ # VOID label.
180
+ label_divisor = metadata.label_divisor
181
+ segments_info = []
182
+ for panoptic_label in np.unique(panoptic_seg.numpy()):
183
+ if panoptic_label == -1:
184
+ # VOID region.
185
+ continue
186
+ pred_class = panoptic_label // label_divisor
187
+ isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
188
+ segments_info.append(
189
+ {
190
+ "id": int(panoptic_label),
191
+ "category_id": int(pred_class),
192
+ "isthing": bool(isthing),
193
+ }
194
+ )
195
+ del metadata
196
+
197
+ self._seg = panoptic_seg
198
+
199
+ self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info
200
+ segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
201
+ areas = areas.numpy()
202
+ sorted_idxs = np.argsort(-areas)
203
+ self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
204
+ self._seg_ids = self._seg_ids.tolist()
205
+ for sid, area in zip(self._seg_ids, self._seg_areas):
206
+ if sid in self._sinfo:
207
+ self._sinfo[sid]["area"] = float(area)
208
+
209
+ def non_empty_mask(self):
210
+ """
211
+ Returns:
212
+ (H, W) array, a mask for all pixels that have a prediction
213
+ """
214
+ empty_ids = []
215
+ for id in self._seg_ids:
216
+ if id not in self._sinfo:
217
+ empty_ids.append(id)
218
+ if len(empty_ids) == 0:
219
+ return np.zeros(self._seg.shape, dtype=np.uint8)
220
+ assert (
221
+ len(empty_ids) == 1
222
+ ), ">1 ids correspond to no labels. This is currently not supported"
223
+ return (self._seg != empty_ids[0]).numpy().astype(bool)
224
+
225
+ def semantic_masks(self):
226
+ for sid in self._seg_ids:
227
+ sinfo = self._sinfo.get(sid)
228
+ if sinfo is None or sinfo["isthing"]:
229
+ # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
230
+ continue
231
+ yield (self._seg == sid).numpy().astype(bool), sinfo
232
+
233
+ def instance_masks(self):
234
+ for sid in self._seg_ids:
235
+ sinfo = self._sinfo.get(sid)
236
+ if sinfo is None or not sinfo["isthing"]:
237
+ continue
238
+ mask = (self._seg == sid).numpy().astype(bool)
239
+ if mask.sum() > 0:
240
+ yield mask, sinfo
241
+
242
+
243
+ def _create_text_labels(classes, scores, class_names, is_crowd=None):
244
+ """
245
+ Args:
246
+ classes (list[int] or None):
247
+ scores (list[float] or None):
248
+ class_names (list[str] or None):
249
+ is_crowd (list[bool] or None):
250
+ Returns:
251
+ list[str] or None
252
+ """
253
+ labels = None
254
+ if classes is not None:
255
+ if class_names is not None and len(class_names) > 0:
256
+ labels = [class_names[i] for i in classes]
257
+ else:
258
+ labels = [str(i) for i in classes]
259
+ if scores is not None:
260
+ if labels is None:
261
+ labels = ["{:.0f}%".format(s * 100) for s in scores]
262
+ else:
263
+ labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
264
+ if labels is not None and is_crowd is not None:
265
+ labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
266
+ return labels
267
+
268
+
269
+ class VisImage:
270
+ def __init__(self, img, scale=1.0):
271
+ """
272
+ Args:
273
+ img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
274
+ scale (float): scale the input image
275
+ """
276
+ self.img = img
277
+ self.scale = scale
278
+ self.width, self.height = img.shape[1], img.shape[0]
279
+ self._setup_figure(img)
280
+
281
+ def _setup_figure(self, img):
282
+ """
283
+ Args:
284
+ Same as in :meth:`__init__()`.
285
+ Returns:
286
+ fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
287
+ ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
288
+ """
289
+ fig = mplfigure.Figure(frameon=False)
290
+ self.dpi = fig.get_dpi()
291
+ # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
292
+ # (https://github.com/matplotlib/matplotlib/issues/15363)
293
+ fig.set_size_inches(
294
+ (self.width * self.scale + 1e-2) / self.dpi,
295
+ (self.height * self.scale + 1e-2) / self.dpi,
296
+ )
297
+ self.canvas = FigureCanvasAgg(fig)
298
+ # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
299
+ ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
300
+ ax.axis("off")
301
+ self.fig = fig
302
+ self.ax = ax
303
+ self.reset_image(img)
304
+
305
+ def reset_image(self, img):
306
+ """
307
+ Args:
308
+ img: same as in __init__
309
+ """
310
+ img = img.astype("uint8")
311
+ self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
312
+
313
+ def save(self, filepath):
314
+ """
315
+ Args:
316
+ filepath (str): a string that contains the absolute path, including the file name, where
317
+ the visualized image will be saved.
318
+ """
319
+ self.fig.savefig(filepath)
320
+
321
+ def get_image(self):
322
+ """
323
+ Returns:
324
+ ndarray:
325
+ the visualized image of shape (H, W, 3) (RGB) in uint8 type.
326
+ The shape is scaled w.r.t the input image using the given `scale` argument.
327
+ """
328
+ canvas = self.canvas
329
+ s, (width, height) = canvas.print_to_buffer()
330
+ # buf = io.BytesIO() # works for cairo backend
331
+ # canvas.print_rgba(buf)
332
+ # width, height = self.width, self.height
333
+ # s = buf.getvalue()
334
+
335
+ buffer = np.frombuffer(s, dtype="uint8")
336
+
337
+ img_rgba = buffer.reshape(height, width, 4)
338
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
339
+ return rgb.astype("uint8")
340
+
341
+
342
+ class Visualizer:
343
+ """
344
+ Visualizer that draws data about detection/segmentation on images.
345
+ It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
346
+ that draw primitive objects to images, as well as high-level wrappers like
347
+ `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
348
+ that draw composite data in some pre-defined style.
349
+ Note that the exact visualization style for the high-level wrappers are subject to change.
350
+ Style such as color, opacity, label contents, visibility of labels, or even the visibility
351
+ of objects themselves (e.g. when the object is too small) may change according
352
+ to different heuristics, as long as the results still look visually reasonable.
353
+ To obtain a consistent style, you can implement custom drawing functions with the
354
+ abovementioned primitive methods instead. If you need more customized visualization
355
+ styles, you can process the data yourself following their format documented in
356
+ tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
357
+ intend to satisfy everyone's preference on drawing styles.
358
+ This visualizer focuses on high rendering quality rather than performance. It is not
359
+ designed to be used for real-time applications.
360
+ """
361
+
362
+ # TODO implement a fast, rasterized version using OpenCV
363
+
364
+ def __init__(self, img_rgb, is_img=True, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):
365
+ """
366
+ Args:
367
+ img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
368
+ the height and width of the image respectively. C is the number of
369
+ color channels. The image is required to be in RGB format since that
370
+ is a requirement of the Matplotlib library. The image is also expected
371
+ to be in the range [0, 255].
372
+ metadata (Metadata): dataset metadata (e.g. class names and colors)
373
+ instance_mode (ColorMode): defines one of the pre-defined style for drawing
374
+ instances on an image.
375
+ """
376
+ if is_img:
377
+ self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
378
+ else:
379
+ self.img = np.zeros_like(img_rgb).clip(0, 255).astype(np.uint8) + 255
380
+ if metadata is None:
381
+ metadata = MetadataCatalog.get("__nonexist__")
382
+ self.metadata = metadata
383
+ self.output = VisImage(self.img, scale=scale)
384
+ self.cpu_device = torch.device("cpu")
385
+
386
+ # too small texts are useless, therefore clamp to 9
387
+ self._default_font_size = max(
388
+ np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
389
+ )
390
+ self._instance_mode = instance_mode
391
+ self.keypoint_threshold = _KEYPOINT_THRESHOLD
392
+
393
+ def get_image(self, img):
394
+ img = np.asarray(img).clip(0, 255).astype(np.uint8)
395
+ return VisImage(img, scale=1.0)
396
+
397
+ def draw_box_predictions(
398
+ self,
399
+ boxes=None,
400
+ labels=None,
401
+ scores=None,
402
+ assigned_colors=None
403
+ ):
404
+ """
405
+ Args:
406
+ boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
407
+ or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
408
+ or a :class:`RotatedBoxes`,
409
+ or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
410
+ for the N objects in a single image,
411
+ labels (list[str]): the text to be displayed for each instance.
412
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
413
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
414
+ for full list of formats that the colors are accepted in.
415
+ Returns:
416
+ output (VisImage): image object with visualizations.
417
+ """
418
+ num_instances = 0
419
+ boxes = self._convert_boxes(boxes)
420
+ classes = labels.tolist()
421
+ scores = scores.tolist()
422
+ labels = _create_text_labels(classes, scores, self.metadata.get("stuff_classes", None))
423
+ num_instances = len(boxes)
424
+ assert len(labels) == num_instances
425
+ if assigned_colors is None:
426
+ # assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
427
+ assigned_colors = [instance_color(rgb=True, idx=i, maximum=1) for i in range(num_instances)]
428
+ if num_instances == 0:
429
+ return self.output
430
+
431
+ # Display in largest to smallest order to reduce occlusion.
432
+ areas = None
433
+ areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
434
+
435
+ if areas is not None:
436
+ sorted_idxs = np.argsort(-areas).tolist()
437
+ # Re-order overlapped instances in descending order.
438
+ boxes = boxes[sorted_idxs] if boxes is not None else None
439
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
440
+ assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
441
+
442
+ for i in range(num_instances):
443
+ color = assigned_colors[i]
444
+ if boxes is not None:
445
+ self.draw_box(boxes[i], edge_color=color)
446
+
447
+ if labels is not None:
448
+ # first get a box
449
+ if boxes is not None:
450
+ x0, y0, x1, y1 = boxes[i]
451
+ text_pos = (x0, y0) # if drawing boxes, put text on the box corner.
452
+ horiz_align = "left"
453
+ else:
454
+ continue # drawing the box confidence for keypoints isn't very useful.
455
+ # for small objects, draw text at the side to avoid occlusion
456
+ instance_area = (y1 - y0) * (x1 - x0)
457
+ if (
458
+ instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
459
+ or y1 - y0 < 40 * self.output.scale
460
+ ):
461
+ if y1 >= self.output.height - 5:
462
+ text_pos = (x1, y0)
463
+ else:
464
+ text_pos = (x0, y1)
465
+
466
+ height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
467
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
468
+ font_size = (
469
+ np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
470
+ * 0.5
471
+ * self._default_font_size
472
+ )
473
+ self.draw_text(
474
+ labels[i],
475
+ text_pos,
476
+ color=lighter_color,
477
+ horizontal_alignment=horiz_align,
478
+ font_size=font_size,
479
+ )
480
+
481
+ return self.output
482
+
483
+
484
+ def draw_instance_predictions(self, predictions, alpha=0.8, is_text=True):
485
+ """
486
+ Draw instance-level prediction results on an image.
487
+ Args:
488
+ predictions (Instances): the output of an instance detection/segmentation
489
+ model. Following fields will be used to draw:
490
+ "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
491
+ Returns:
492
+ output (VisImage): image object with visualizations.
493
+ """
494
+ boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
495
+ scores = predictions.scores if predictions.has("scores") else None
496
+ classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
497
+ labels = _create_text_labels(classes, scores, self.metadata.get("stuff_classes", None))
498
+ keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
499
+
500
+ if predictions.has("pred_masks"):
501
+ masks = np.asarray(predictions.pred_masks)
502
+ masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
503
+ else:
504
+ masks = None
505
+
506
+ if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("stuff_colors"):
507
+ # colors = [
508
+ # self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
509
+ # ]
510
+ colors = [
511
+ instance_color(rgb=True, idx=c, maximum=1) for c in classes
512
+ ]
513
+ else:
514
+ colors = None
515
+
516
+ if self._instance_mode == ColorMode.IMAGE_BW:
517
+ self.output.reset_image(
518
+ self._create_grayscale_image(
519
+ (predictions.pred_masks.any(dim=0) > 0).numpy()
520
+ if predictions.has("pred_masks")
521
+ else None
522
+ )
523
+ )
524
+
525
+ self.overlay_instances(
526
+ masks=masks,
527
+ boxes=boxes,
528
+ labels=labels,
529
+ keypoints=keypoints,
530
+ assigned_colors=colors,
531
+ alpha=alpha,
532
+ is_text=is_text,
533
+ )
534
+ return self.output
535
+
536
+ def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8, is_text=True, edge_color=_OFF_WHITE):
537
+ """
538
+ Draw semantic segmentation predictions/labels.
539
+ Args:
540
+ sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
541
+ Each value is the integer label of the pixel.
542
+ area_threshold (int): segments with less than `area_threshold` are not drawn.
543
+ alpha (float): the larger it is, the more opaque the segmentations are.
544
+ Returns:
545
+ output (VisImage): image object with visualizations.
546
+ """
547
+ if isinstance(sem_seg, torch.Tensor):
548
+ sem_seg = sem_seg.numpy()
549
+ labels, areas = np.unique(sem_seg, return_counts=True)
550
+ sorted_idxs = np.argsort(-areas).tolist()
551
+ labels = labels[sorted_idxs]
552
+ for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
553
+ try:
554
+ mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
555
+ except (AttributeError, IndexError):
556
+ mask_color = None
557
+
558
+ binary_mask = (sem_seg == label).astype(np.uint8)
559
+ text = self.metadata.stuff_classes[label]
560
+ self.draw_binary_mask(
561
+ binary_mask,
562
+ color=mask_color,
563
+ edge_color=edge_color,
564
+ text=text,
565
+ alpha=alpha,
566
+ area_threshold=area_threshold,
567
+ is_text=is_text,
568
+ )
569
+ return self.output
570
+
571
+ def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7, is_text=True,):
572
+ """
573
+ Draw panoptic prediction annotations or results.
574
+ Args:
575
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
576
+ segment.
577
+ segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
578
+ If it is a ``list[dict]``, each dict contains keys "id", "category_id".
579
+ If None, category id of each pixel is computed by
580
+ ``pixel // metadata.label_divisor``.
581
+ area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
582
+ Returns:
583
+ output (VisImage): image object with visualizations.
584
+ """
585
+ pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
586
+
587
+ if self._instance_mode == ColorMode.IMAGE_BW:
588
+ self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
589
+
590
+ # draw mask for all semantic segments first i.e. "stuff"
591
+ for mask, sinfo in pred.semantic_masks():
592
+ category_idx = sinfo["category_id"]
593
+ try:
594
+ mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
595
+ except AttributeError:
596
+ mask_color = None
597
+
598
+ text = self.metadata.stuff_classes[category_idx]
599
+ self.draw_binary_mask(
600
+ mask,
601
+ color=mask_color,
602
+ edge_color=_OFF_WHITE,
603
+ text=text,
604
+ alpha=alpha,
605
+ area_threshold=area_threshold,
606
+ is_text=is_text,
607
+ )
608
+
609
+ # draw mask for all instances second
610
+ all_instances = list(pred.instance_masks())
611
+ if len(all_instances) == 0:
612
+ return self.output
613
+ masks, sinfo = list(zip(*all_instances))
614
+ category_ids = [x["category_id"] for x in sinfo]
615
+
616
+ try:
617
+ scores = [x["score"] for x in sinfo]
618
+ except KeyError:
619
+ scores = None
620
+ labels = _create_text_labels(
621
+ category_ids, scores, self.metadata.stuff_classes, [x.get("iscrowd", 0) for x in sinfo]
622
+ )
623
+
624
+ try:
625
+ colors = [
626
+ self._jitter([x / 255 for x in self.metadata.stuff_colors[c]]) for c in category_ids
627
+ ]
628
+ except AttributeError:
629
+ colors = None
630
+ self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha, is_text=is_text)
631
+
632
+ return self.output
633
+
634
+ draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility
635
+
636
+ def draw_dataset_dict(self, dic):
637
+ """
638
+ Draw annotations/segmentaions in Detectron2 Dataset format.
639
+ Args:
640
+ dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.
641
+ Returns:
642
+ output (VisImage): image object with visualizations.
643
+ """
644
+ annos = dic.get("annotations", None)
645
+ if annos:
646
+ if "segmentation" in annos[0]:
647
+ masks = [x["segmentation"] for x in annos]
648
+ else:
649
+ masks = None
650
+ if "keypoints" in annos[0]:
651
+ keypts = [x["keypoints"] for x in annos]
652
+ keypts = np.array(keypts).reshape(len(annos), -1, 3)
653
+ else:
654
+ keypts = None
655
+
656
+ boxes = [
657
+ BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
658
+ if len(x["bbox"]) == 4
659
+ else x["bbox"]
660
+ for x in annos
661
+ ]
662
+
663
+ colors = None
664
+ category_ids = [x["category_id"] for x in annos]
665
+ if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("stuff_colors"):
666
+ colors = [
667
+ self._jitter([x / 255 for x in self.metadata.stuff_colors[c]])
668
+ for c in category_ids
669
+ ]
670
+ names = self.metadata.get("stuff_classes", None)
671
+ labels = _create_text_labels(
672
+ category_ids,
673
+ scores=None,
674
+ class_names=names,
675
+ is_crowd=[x.get("iscrowd", 0) for x in annos],
676
+ )
677
+ self.overlay_instances(
678
+ labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
679
+ )
680
+
681
+ sem_seg = dic.get("sem_seg", None)
682
+ if sem_seg is None and "sem_seg_file_name" in dic:
683
+ with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
684
+ sem_seg = Image.open(f)
685
+ sem_seg = np.asarray(sem_seg, dtype="uint8")
686
+ if sem_seg is not None:
687
+ self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)
688
+
689
+ pan_seg = dic.get("pan_seg", None)
690
+ # if pan_seg is None and "pan_seg_file_name" in dic:
691
+ # with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
692
+ # pan_seg = Image.open(f)
693
+ # pan_seg = np.asarray(pan_seg)
694
+ # from panopticapi.utils import rgb2id
695
+ #
696
+ # pan_seg = rgb2id(pan_seg)
697
+ if pan_seg is not None:
698
+ segments_info = dic["segments_info"]
699
+ pan_seg = torch.tensor(pan_seg)
700
+ self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5)
701
+ return self.output
702
+
703
+ def overlay_instances(
704
+ self,
705
+ *,
706
+ boxes=None,
707
+ labels=None,
708
+ masks=None,
709
+ keypoints=None,
710
+ assigned_colors=None,
711
+ alpha=0.5,
712
+ is_text=True,
713
+ ):
714
+ """
715
+ Args:
716
+ boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
717
+ or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
718
+ or a :class:`RotatedBoxes`,
719
+ or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
720
+ for the N objects in a single image,
721
+ labels (list[str]): the text to be displayed for each instance.
722
+ masks (masks-like object): Supported types are:
723
+ * :class:`detectron2.structures.PolygonMasks`,
724
+ :class:`detectron2.structures.BitMasks`.
725
+ * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
726
+ The first level of the list corresponds to individual instances. The second
727
+ level to all the polygon that compose the instance, and the third level
728
+ to the polygon coordinates. The third level should have the format of
729
+ [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
730
+ * list[ndarray]: each ndarray is a binary mask of shape (H, W).
731
+ * list[dict]: each dict is a COCO-style RLE.
732
+ keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
733
+ where the N is the number of instances and K is the number of keypoints.
734
+ The last dimension corresponds to (x, y, visibility or score).
735
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
736
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
737
+ for full list of formats that the colors are accepted in.
738
+ Returns:
739
+ output (VisImage): image object with visualizations.
740
+ """
741
+ num_instances = 0
742
+ if boxes is not None:
743
+ boxes = self._convert_boxes(boxes)
744
+ num_instances = len(boxes)
745
+ if masks is not None:
746
+ masks = self._convert_masks(masks)
747
+ if num_instances:
748
+ assert len(masks) == num_instances
749
+ else:
750
+ num_instances = len(masks)
751
+ if keypoints is not None:
752
+ if num_instances:
753
+ assert len(keypoints) == num_instances
754
+ else:
755
+ num_instances = len(keypoints)
756
+ keypoints = self._convert_keypoints(keypoints)
757
+ if labels is not None:
758
+ assert len(labels) == num_instances
759
+ if assigned_colors is None:
760
+ # assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
761
+ assigned_colors = [instance_color(rgb=True, idx=i, maximum=1) for i in range(num_instances)]
762
+ if num_instances == 0:
763
+ return self.output
764
+ if boxes is not None and boxes.shape[1] == 5:
765
+ return self.overlay_rotated_instances(
766
+ boxes=boxes, labels=labels, assigned_colors=assigned_colors
767
+ )
768
+
769
+ # Display in largest to smallest order to reduce occlusion.
770
+ areas = None
771
+ if boxes is not None:
772
+ areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
773
+ elif masks is not None:
774
+ areas = np.asarray([x.area() for x in masks])
775
+
776
+ if areas is not None:
777
+ sorted_idxs = np.argsort(-areas).tolist()
778
+ # Re-order overlapped instances in descending order.
779
+ boxes = boxes[sorted_idxs] if boxes is not None else None
780
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
781
+ masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
782
+ assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
783
+ keypoints = keypoints[sorted_idxs] if keypoints is not None else None
784
+
785
+ for i in range(num_instances):
786
+ color = assigned_colors[i]
787
+ if boxes is not None:
788
+ self.draw_box(boxes[i], edge_color=color)
789
+
790
+ if masks is not None:
791
+ for segment in masks[i].polygons:
792
+ self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)
793
+
794
+ if labels is not None:
795
+ # first get a box
796
+ if boxes is not None:
797
+ x0, y0, x1, y1 = boxes[i]
798
+ text_pos = (x0, y0) # if drawing boxes, put text on the box corner.
799
+ horiz_align = "left"
800
+ elif masks is not None:
801
+ # skip small mask without polygon
802
+ if len(masks[i].polygons) == 0:
803
+ continue
804
+
805
+ x0, y0, x1, y1 = masks[i].bbox()
806
+
807
+ # draw text in the center (defined by median) when box is not drawn
808
+ # median is less sensitive to outliers.
809
+ text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
810
+ horiz_align = "center"
811
+ else:
812
+ continue # drawing the box confidence for keypoints isn't very useful.
813
+ # for small objects, draw text at the side to avoid occlusion
814
+ instance_area = (y1 - y0) * (x1 - x0)
815
+ if (
816
+ instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
817
+ or y1 - y0 < 40 * self.output.scale
818
+ ):
819
+ if y1 >= self.output.height - 5:
820
+ text_pos = (x1, y0)
821
+ else:
822
+ text_pos = (x0, y1)
823
+
824
+ height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
825
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
826
+ font_size = (
827
+ np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
828
+ * 0.5
829
+ * self._default_font_size
830
+ )
831
+ if is_text:
832
+ self.draw_text(
833
+ labels[i],
834
+ text_pos,
835
+ color=lighter_color,
836
+ horizontal_alignment=horiz_align,
837
+ font_size=font_size,
838
+ )
839
+
840
+ # draw keypoints
841
+ if keypoints is not None:
842
+ for keypoints_per_instance in keypoints:
843
+ self.draw_and_connect_keypoints(keypoints_per_instance)
844
+
845
+ return self.output
846
+
847
+ def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
848
+ """
849
+ Args:
850
+ boxes (ndarray): an Nx5 numpy array of
851
+ (x_center, y_center, width, height, angle_degrees) format
852
+ for the N objects in a single image.
853
+ labels (list[str]): the text to be displayed for each instance.
854
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
855
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
856
+ for full list of formats that the colors are accepted in.
857
+ Returns:
858
+ output (VisImage): image object with visualizations.
859
+ """
860
+ num_instances = len(boxes)
861
+
862
+ if assigned_colors is None:
863
+ # assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
864
+ assigned_colors = [instance_color(rgb=True, idx=i, maximum=1) for i in range(num_instances)]
865
+ if num_instances == 0:
866
+ return self.output
867
+
868
+ # Display in largest to smallest order to reduce occlusion.
869
+ if boxes is not None:
870
+ areas = boxes[:, 2] * boxes[:, 3]
871
+
872
+ sorted_idxs = np.argsort(-areas).tolist()
873
+ # Re-order overlapped instances in descending order.
874
+ boxes = boxes[sorted_idxs]
875
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
876
+ colors = [assigned_colors[idx] for idx in sorted_idxs]
877
+
878
+ for i in range(num_instances):
879
+ self.draw_rotated_box_with_label(
880
+ boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
881
+ )
882
+
883
+ return self.output
884
+
885
+ def draw_and_connect_keypoints(self, keypoints):
886
+ """
887
+ Draws keypoints of an instance and follows the rules for keypoint connections
888
+ to draw lines between appropriate keypoints. This follows color heuristics for
889
+ line color.
890
+ Args:
891
+ keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
892
+ and the last dimension corresponds to (x, y, probability).
893
+ Returns:
894
+ output (VisImage): image object with visualizations.
895
+ """
896
+ visible = {}
897
+ keypoint_names = self.metadata.get("keypoint_names")
898
+ for idx, keypoint in enumerate(keypoints):
899
+
900
+ # draw keypoint
901
+ x, y, prob = keypoint
902
+ if prob > self.keypoint_threshold:
903
+ self.draw_circle((x, y), color=_RED)
904
+ if keypoint_names:
905
+ keypoint_name = keypoint_names[idx]
906
+ visible[keypoint_name] = (x, y)
907
+
908
+ if self.metadata.get("keypoint_connection_rules"):
909
+ for kp0, kp1, color in self.metadata.keypoint_connection_rules:
910
+ if kp0 in visible and kp1 in visible:
911
+ x0, y0 = visible[kp0]
912
+ x1, y1 = visible[kp1]
913
+ color = tuple(x / 255.0 for x in color)
914
+ self.draw_line([x0, x1], [y0, y1], color=color)
915
+
916
+ # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
917
+ # Note that this strategy is specific to person keypoints.
918
+ # For other keypoints, it should just do nothing
919
+ try:
920
+ ls_x, ls_y = visible["left_shoulder"]
921
+ rs_x, rs_y = visible["right_shoulder"]
922
+ mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
923
+ except KeyError:
924
+ pass
925
+ else:
926
+ # draw line from nose to mid-shoulder
927
+ nose_x, nose_y = visible.get("nose", (None, None))
928
+ if nose_x is not None:
929
+ self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)
930
+
931
+ try:
932
+ # draw line from mid-shoulder to mid-hip
933
+ lh_x, lh_y = visible["left_hip"]
934
+ rh_x, rh_y = visible["right_hip"]
935
+ except KeyError:
936
+ pass
937
+ else:
938
+ mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
939
+ self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
940
+ return self.output
941
+
942
+ """
943
+ Primitive drawing functions:
944
+ """
945
+
946
+ def draw_text(
947
+ self,
948
+ text,
949
+ position,
950
+ *,
951
+ font_size=None,
952
+ color="g",
953
+ horizontal_alignment="center",
954
+ rotation=0,
955
+ ):
956
+ """
957
+ Args:
958
+ text (str): class label
959
+ position (tuple): a tuple of the x and y coordinates to place text on image.
960
+ font_size (int, optional): font of the text. If not provided, a font size
961
+ proportional to the image width is calculated and used.
962
+ color: color of the text. Refer to `matplotlib.colors` for full list
963
+ of formats that are accepted.
964
+ horizontal_alignment (str): see `matplotlib.text.Text`
965
+ rotation: rotation angle in degrees CCW
966
+ Returns:
967
+ output (VisImage): image object with text drawn.
968
+ """
969
+ if not font_size:
970
+ font_size = self._default_font_size
971
+
972
+ # since the text background is dark, we don't want the text to be dark
973
+ color = np.maximum(list(mplc.to_rgb(color)), 0.2)
974
+ color[np.argmax(color)] = max(0.8, np.max(color))
975
+
976
+ x, y = position
977
+ self.output.ax.text(
978
+ x,
979
+ y,
980
+ text,
981
+ size=font_size * self.output.scale,
982
+ family="sans-serif",
983
+ bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
984
+ verticalalignment="top",
985
+ horizontalalignment=horizontal_alignment,
986
+ color=color,
987
+ zorder=10,
988
+ rotation=rotation,
989
+ )
990
+ return self.output
991
+
992
+ def draw_box(self, box_coord, alpha=1.0, edge_color="g", line_style="-"):
993
+ """
994
+ Args:
995
+ box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
996
+ are the coordinates of the image's top left corner. x1 and y1 are the
997
+ coordinates of the image's bottom right corner.
998
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
999
+ edge_color: color of the outline of the box. Refer to `matplotlib.colors`
1000
+ for full list of formats that are accepted.
1001
+ line_style (string): the string to use to create the outline of the boxes.
1002
+ Returns:
1003
+ output (VisImage): image object with box drawn.
1004
+ """
1005
+ x0, y0, x1, y1 = box_coord
1006
+ width = x1 - x0
1007
+ height = y1 - y0
1008
+
1009
+ linewidth = 2
1010
+
1011
+ self.output.ax.add_patch(
1012
+ mpl.patches.Rectangle(
1013
+ (x0, y0),
1014
+ width,
1015
+ height,
1016
+ fill=False,
1017
+ edgecolor=edge_color,
1018
+ linewidth=linewidth * self.output.scale,
1019
+ alpha=alpha,
1020
+ linestyle=line_style,
1021
+ )
1022
+ )
1023
+ return self.output
1024
+
1025
+ def draw_rotated_box_with_label(
1026
+ self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
1027
+ ):
1028
+ """
1029
+ Draw a rotated box with label on its top-left corner.
1030
+ Args:
1031
+ rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
1032
+ where cnt_x and cnt_y are the center coordinates of the box.
1033
+ w and h are the width and height of the box. angle represents how
1034
+ many degrees the box is rotated CCW with regard to the 0-degree box.
1035
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1036
+ edge_color: color of the outline of the box. Refer to `matplotlib.colors`
1037
+ for full list of formats that are accepted.
1038
+ line_style (string): the string to use to create the outline of the boxes.
1039
+ label (string): label for rotated box. It will not be rendered when set to None.
1040
+ Returns:
1041
+ output (VisImage): image object with box drawn.
1042
+ """
1043
+ cnt_x, cnt_y, w, h, angle = rotated_box
1044
+ area = w * h
1045
+ # use thinner lines when the box is small
1046
+ linewidth = self._default_font_size / (
1047
+ 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
1048
+ )
1049
+
1050
+ theta = angle * math.pi / 180.0
1051
+ c = math.cos(theta)
1052
+ s = math.sin(theta)
1053
+ rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
1054
+ # x: left->right ; y: top->down
1055
+ rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
1056
+ for k in range(4):
1057
+ j = (k + 1) % 4
1058
+ self.draw_line(
1059
+ [rotated_rect[k][0], rotated_rect[j][0]],
1060
+ [rotated_rect[k][1], rotated_rect[j][1]],
1061
+ color=edge_color,
1062
+ linestyle="--" if k == 1 else line_style,
1063
+ linewidth=linewidth,
1064
+ )
1065
+
1066
+ if label is not None:
1067
+ text_pos = rotated_rect[1] # topleft corner
1068
+
1069
+ height_ratio = h / np.sqrt(self.output.height * self.output.width)
1070
+ label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
1071
+ font_size = (
1072
+ np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size
1073
+ )
1074
+ self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)
1075
+
1076
+ return self.output
1077
+
1078
+ def draw_circle(self, circle_coord, color, radius=3):
1079
+ """
1080
+ Args:
1081
+ circle_coord (list(int) or tuple(int)): contains the x and y coordinates
1082
+ of the center of the circle.
1083
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1084
+ formats that are accepted.
1085
+ radius (int): radius of the circle.
1086
+ Returns:
1087
+ output (VisImage): image object with box drawn.
1088
+ """
1089
+ x, y = circle_coord
1090
+ self.output.ax.add_patch(
1091
+ mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
1092
+ )
1093
+ return self.output
1094
+
1095
+ def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
1096
+ """
1097
+ Args:
1098
+ x_data (list[int]): a list containing x values of all the points being drawn.
1099
+ Length of list should match the length of y_data.
1100
+ y_data (list[int]): a list containing y values of all the points being drawn.
1101
+ Length of list should match the length of x_data.
1102
+ color: color of the line. Refer to `matplotlib.colors` for a full list of
1103
+ formats that are accepted.
1104
+ linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
1105
+ for a full list of formats that are accepted.
1106
+ linewidth (float or None): width of the line. When it's None,
1107
+ a default value will be computed and used.
1108
+ Returns:
1109
+ output (VisImage): image object with line drawn.
1110
+ """
1111
+ if linewidth is None:
1112
+ linewidth = self._default_font_size / 3
1113
+ linewidth = max(linewidth, 1)
1114
+ self.output.ax.add_line(
1115
+ mpl.lines.Line2D(
1116
+ x_data,
1117
+ y_data,
1118
+ linewidth=linewidth * self.output.scale,
1119
+ color=color,
1120
+ linestyle=linestyle,
1121
+ )
1122
+ )
1123
+ return self.output
1124
+
1125
+ def draw_binary_mask(
1126
+ self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=10, is_text=True,
1127
+ ):
1128
+ """
1129
+ Args:
1130
+ binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
1131
+ W is the image width. Each value in the array is either a 0 or 1 value of uint8
1132
+ type.
1133
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
1134
+ formats that are accepted. If None, will pick a random color.
1135
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
1136
+ full list of formats that are accepted.
1137
+ text (str): if not None, will be drawn on the object.
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1139
+ area_threshold (float): a connected component smaller than this area will not be shown.
1140
+ Returns:
1141
+ output (VisImage): image object with mask drawn.
1142
+ """
1143
+ if color is None:
1144
+ color = random_color(rgb=True, maximum=1)
1145
+ color = mplc.to_rgb(color)
1146
+
1147
+ has_valid_segment = False
1148
+ binary_mask = binary_mask.astype("uint8") # opencv needs uint8
1149
+ mask = GenericMask(binary_mask, self.output.height, self.output.width)
1150
+ shape2d = (binary_mask.shape[0], binary_mask.shape[1])
1151
+
1152
+ if not mask.has_holes:
1153
+ # draw polygons for regular masks
1154
+ for segment in mask.polygons:
1155
+ # area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
1156
+ # if area < (area_threshold or 0):
1157
+ # continue
1158
+ has_valid_segment = True
1159
+ segment = segment.reshape(-1, 2)
1160
+ self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)
1161
+ else:
1162
+ # TODO: Use Path/PathPatch to draw vector graphics:
1163
+ # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
1164
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
1165
+ rgba[:, :, :3] = color
1166
+ rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
1167
+ has_valid_segment = True
1168
+ self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
1169
+
1170
+ if is_text:
1171
+ if text is not None and has_valid_segment:
1172
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
1173
+ self._draw_text_in_mask(binary_mask, text, lighter_color)
1174
+ return self.output
1175
+
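Both the hole-handling branch of `draw_binary_mask` and `draw_soft_mask` below composite the mask by building an H x W x 4 RGBA image whose first three channels hold the mask color and whose alpha channel is the (soft) mask scaled by the blending coefficient, then drawing it over the existing axes with `imshow`. A self-contained matplotlib sketch of that compositing (the figure setup is illustrative and stands in for the VisImage plumbing in this file):

import numpy as np
import matplotlib.pyplot as plt

h, w = 64, 96
base = np.random.rand(h, w, 3)            # stand-in for the input image
mask = np.zeros((h, w), dtype=np.float32)
mask[16:48, 24:72] = 1.0                  # binary (or soft) mask with values in [0, 1]

rgba = np.zeros((h, w, 4), dtype=np.float32)
rgba[:, :, :3] = (0.0, 1.0, 0.0)          # mask color
rgba[:, :, 3] = mask * 0.5                # alpha channel = mask * blending coefficient

fig, ax = plt.subplots()
ax.imshow(base, extent=(0, w, h, 0))      # extent keeps image coordinates (y pointing down)
ax.imshow(rgba, extent=(0, w, h, 0))      # fully transparent outside the mask
plt.show()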
1176
+ def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
1177
+ """
1178
+ Args:
1179
+ soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
1180
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
1181
+ formats that are accepted. If None, will pick a random color.
1182
+ text (str): if not None, will be drawn on the object.
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1184
+ Returns:
1185
+ output (VisImage): image object with mask drawn.
1186
+ """
1187
+ if color is None:
1188
+ color = random_color(rgb=True, maximum=1)
1189
+ color = mplc.to_rgb(color)
1190
+
1191
+ shape2d = (soft_mask.shape[0], soft_mask.shape[1])
1192
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
1193
+ rgba[:, :, :3] = color
1194
+ rgba[:, :, 3] = soft_mask * alpha
1195
+ self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
1196
+
1197
+ if text is not None:
1198
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
1199
+ binary_mask = (soft_mask > 0.5).astype("uint8")
1200
+ # self._draw_text_in_mask(binary_mask, text, lighter_color)
1201
+ return self.output
1202
+
1203
+ def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
1204
+ """
1205
+ Args:
1206
+ segment: numpy array of shape Nx2, containing all the points in the polygon.
1207
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1208
+ formats that are accepted.
1209
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
1210
+ full list of formats that are accepted. If not provided, a darker shade
1211
+ of the polygon color will be used instead.
1212
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1213
+ Returns:
1214
+ output (VisImage): image object with polygon drawn.
1215
+ """
1216
+ if edge_color is None:
1217
+ # make edge color darker than the polygon color
1218
+ if alpha > 0.8:
1219
+ edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
1220
+ else:
1221
+ edge_color = color
1222
+ edge_color = mplc.to_rgb(edge_color) + (1,)
1223
+
1224
+ polygon = mpl.patches.Polygon(
1225
+ segment,
1226
+ fill=True,
1227
+ facecolor=mplc.to_rgb(color) + (alpha,),
1228
+ edgecolor=edge_color,
1229
+ linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
1230
+ )
1231
+ self.output.ax.add_patch(polygon)
1232
+ return self.output
1233
+
1234
+ """
1235
+ Internal methods:
1236
+ """
1237
+
1238
+ def _jitter(self, color):
1239
+ """
1240
+ Randomly modifies given color to produce a slightly different color than the color given.
1241
+ Args:
1242
+ color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
1243
+ picked. The values in the list are in the [0.0, 1.0] range.
1244
+ Returns:
1245
+ jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
1246
+ color after being jittered. The values in the list are in the [0.0, 1.0] range.
1247
+ """
1248
+ color = mplc.to_rgb(color)
1249
+ vec = np.random.rand(3)
1250
+ # better to do it in another color space
1251
+ vec = vec / np.linalg.norm(vec) * 0.5
1252
+ res = np.clip(vec + color, 0, 1)
1253
+ return tuple(res)
1254
+
1255
+ def _create_grayscale_image(self, mask=None):
1256
+ """
1257
+ Create a grayscale version of the original image.
1258
+ The colors in masked area, if given, will be kept.
1259
+ """
1260
+ img_bw = self.img.astype("f4").mean(axis=2)
1261
+ img_bw = np.stack([img_bw] * 3, axis=2)
1262
+ if mask is not None:
1263
+ img_bw[mask] = self.img[mask]
1264
+ return img_bw
1265
+
1266
+ def _change_color_brightness(self, color, brightness_factor):
1267
+ """
1268
+ Depending on the brightness_factor, gives a lighter or darker color, i.e. a color with
+ lower or higher lightness than the original color.
1270
+ Args:
1271
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1272
+ formats that are accepted.
1273
+ brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
1274
+ 0 will correspond to no change, a factor in [-1.0, 0) range will result in
1275
+ a darker color and a factor in (0, 1.0] range will result in a lighter color.
1276
+ Returns:
1277
+ modified_color (tuple[double]): a tuple containing the RGB values of the
1278
+ modified color. Each value in the tuple is in the [0.0, 1.0] range.
1279
+ """
1280
+ assert brightness_factor >= -1.0 and brightness_factor <= 1.0
1281
+ color = mplc.to_rgb(color)
1282
+ polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
1283
+ modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
1284
+ modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
1285
+ modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
1286
+ modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
1287
+ return modified_color
1288
+
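`_change_color_brightness` converts the color to HLS, scales only the lightness channel by (1 + brightness_factor), clamps it to [0, 1], and converts back to RGB; hue and saturation are untouched. A standalone sketch of that transform (the helper name is illustrative):

import colorsys
import matplotlib.colors as mplc

def change_brightness(color, brightness_factor):
    # brightness_factor in [-1, 1]: negative darkens, positive lightens.
    r, g, b = mplc.to_rgb(color)
    h, l, s = colorsys.rgb_to_hls(r, g, b)
    l = min(max(l + brightness_factor * l, 0.0), 1.0)  # scale and clamp the lightness only
    return colorsys.hls_to_rgb(h, l, s)

print(change_brightness("red", 0.7))   # a lighter red, as used for label text
print(change_brightness("red", -0.7))  # a darker red, as used for polygon edges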
1289
+ def _convert_boxes(self, boxes):
1290
+ """
1291
+ Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.
1292
+ """
1293
+ if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
1294
+ return boxes.tensor.detach().numpy()
1295
+ else:
1296
+ return np.asarray(boxes)
1297
+
1298
+ def _convert_masks(self, masks_or_polygons):
1299
+ """
1300
+ Convert different format of masks or polygons to a tuple of masks and polygons.
1301
+ Returns:
1302
+ list[GenericMask]:
1303
+ """
1304
+
1305
+ m = masks_or_polygons
1306
+ if isinstance(m, PolygonMasks):
1307
+ m = m.polygons
1308
+ if isinstance(m, BitMasks):
1309
+ m = m.tensor.numpy()
1310
+ if isinstance(m, torch.Tensor):
1311
+ m = m.numpy()
1312
+ ret = []
1313
+ for x in m:
1314
+ if isinstance(x, GenericMask):
1315
+ ret.append(x)
1316
+ else:
1317
+ ret.append(GenericMask(x, self.output.height, self.output.width))
1318
+ return ret
1319
+
1320
+ def _draw_text_in_mask(self, binary_mask, text, color):
1321
+ """
1322
+ Find proper places to draw text given a binary mask.
1323
+ """
1324
+ # TODO sometimes drawn on wrong objects. the heuristics here can improve.
1325
+ _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
1326
+ if stats[1:, -1].size == 0:
1327
+ return
1328
+ largest_component_id = np.argmax(stats[1:, -1]) + 1
1329
+
1330
+ # draw text on the largest component, as well as other very large components.
1331
+ for cid in range(1, _num_cc):
1332
+ if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
1333
+ # median is more stable than centroid
1334
+ # center = centroids[largest_component_id]
1335
+ center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
1336
+ self.draw_text(text, center, color=color)
1337
+
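The heuristic in `_draw_text_in_mask` uses `cv2.connectedComponentsWithStats`, whose last stats column is the component area, and anchors the label at the median pixel of the largest component; the median is less likely than the centroid to fall outside a concave mask. A small sketch of that anchor computation (the toy mask is illustrative):

import cv2
import numpy as np

mask = np.zeros((8, 8), dtype=np.uint8)
mask[1:4, 1:7] = 1   # a single elongated component

num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(mask, 8)
largest = np.argmax(stats[1:, -1]) + 1                               # skip label 0 (background)
center = np.median((cc_labels == largest).nonzero(), axis=1)[::-1]   # (x, y) order for draw_text
print(center)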
1338
+ def _convert_keypoints(self, keypoints):
1339
+ if isinstance(keypoints, Keypoints):
1340
+ keypoints = keypoints.tensor
1341
+ keypoints = np.asarray(keypoints)
1342
+ return keypoints
1343
+
1344
+ def get_output(self):
1345
+ """
1346
+ Returns:
1347
+ output (VisImage): the image output containing the visualizations added
1348
+ to the image.
1349
+ """
1350
+ return self.output
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .detection_coco_evaluator import *
2
+ from .coco_evaluator import *
3
+ from .cityscapes_evaluation import CityscapesInstanceEvaluator
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/cityscapes_evaluation.py ADDED
@@ -0,0 +1,201 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/cityscapes_evaluation.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import glob
7
+ import logging
8
+ import numpy as np
9
+ import os
10
+ import tempfile
11
+ from collections import OrderedDict
12
+ import torch
13
+ from PIL import Image
14
+
15
+ from annotator.oneformer.detectron2.data import MetadataCatalog
16
+ from annotator.oneformer.detectron2.utils import comm
17
+ from annotator.oneformer.detectron2.utils.file_io import PathManager
18
+
19
+ from .evaluator import DatasetEvaluator
20
+
21
+
22
+ class CityscapesEvaluator(DatasetEvaluator):
23
+ """
24
+ Base class for evaluation using cityscapes API.
25
+ """
26
+
27
+ def __init__(self, dataset_name):
28
+ """
29
+ Args:
30
+ dataset_name (str): the name of the dataset.
31
+ It must have the following metadata associated with it:
32
+ "thing_classes", "gt_dir".
33
+ """
34
+ self._metadata = MetadataCatalog.get(dataset_name)
35
+ self._cpu_device = torch.device("cpu")
36
+ self._logger = logging.getLogger(__name__)
37
+
38
+ def reset(self):
39
+ self._working_dir = tempfile.TemporaryDirectory(prefix="cityscapes_eval_")
40
+ self._temp_dir = self._working_dir.name
41
+ # All workers will write to the same results directory
42
+ # TODO this does not work in distributed training
43
+ assert (
44
+ comm.get_local_size() == comm.get_world_size()
45
+ ), "CityscapesEvaluator currently does not work with multiple machines."
46
+ self._temp_dir = comm.all_gather(self._temp_dir)[0]
47
+ if self._temp_dir != self._working_dir.name:
48
+ self._working_dir.cleanup()
49
+ self._logger.info(
50
+ "Writing cityscapes results to temporary directory {} ...".format(self._temp_dir)
51
+ )
52
+
53
+
54
+ class CityscapesInstanceEvaluator(CityscapesEvaluator):
55
+ """
56
+ Evaluate instance segmentation results on cityscapes dataset using cityscapes API.
57
+
58
+ Note:
59
+ * It does not work in multi-machine distributed training.
60
+ * It contains a synchronization, therefore has to be used on all ranks.
61
+ * Only the main process runs evaluation.
62
+ """
63
+
64
+ def process(self, inputs, outputs):
65
+ from cityscapesscripts.helpers.labels import name2label
66
+
67
+ for input, output in zip(inputs, outputs):
68
+ file_name = input["file_name"]
69
+ basename = os.path.splitext(os.path.basename(file_name))[0]
70
+ pred_txt = os.path.join(self._temp_dir, basename + "_pred.txt")
71
+
72
+ if "instances" in output:
73
+ output = output["instances"].to(self._cpu_device)
74
+ num_instances = len(output)
75
+ with open(pred_txt, "w") as fout:
76
+ for i in range(num_instances):
77
+ pred_class = output.pred_classes[i]
78
+ classes = self._metadata.stuff_classes[pred_class]
79
+ class_id = name2label[classes].id
80
+ score = output.scores[i]
81
+ mask = output.pred_masks[i].numpy().astype("uint8")
82
+ png_filename = os.path.join(
83
+ self._temp_dir, basename + "_{}_{}.png".format(i, classes)
84
+ )
85
+
86
+ Image.fromarray(mask * 255).save(png_filename)
87
+ fout.write(
88
+ "{} {} {}\n".format(os.path.basename(png_filename), class_id, score)
89
+ )
90
+ else:
91
+ # Cityscapes requires a prediction file for every ground truth image.
92
+ with open(pred_txt, "w") as fout:
93
+ pass
94
+
95
+ def evaluate(self):
96
+ """
97
+ Returns:
98
+ dict: has a key "segm", whose value is a dict of "AP" and "AP50".
99
+ """
100
+ comm.synchronize()
101
+ if comm.get_rank() > 0:
102
+ return
103
+ import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as cityscapes_eval
104
+
105
+ self._logger.info("Evaluating results under {} ...".format(self._temp_dir))
106
+
107
+ # set some global states in cityscapes evaluation API, before evaluating
108
+ cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir)
109
+ cityscapes_eval.args.predictionWalk = None
110
+ cityscapes_eval.args.JSONOutput = False
111
+ cityscapes_eval.args.colorized = False
112
+ cityscapes_eval.args.gtInstancesFile = os.path.join(self._temp_dir, "gtInstances.json")
113
+
114
+ # These lines are adopted from
115
+ # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
116
+ gt_dir = PathManager.get_local_path(self._metadata.gt_dir)
117
+ groundTruthImgList = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_instanceIds.png"))
118
+ assert len(
119
+ groundTruthImgList
120
+ ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
121
+ cityscapes_eval.args.groundTruthSearch
122
+ )
123
+ predictionImgList = []
124
+ for gt in groundTruthImgList:
125
+ predictionImgList.append(cityscapes_eval.getPrediction(gt, cityscapes_eval.args))
126
+ results = cityscapes_eval.evaluateImgLists(
127
+ predictionImgList, groundTruthImgList, cityscapes_eval.args
128
+ )["averages"]
129
+
130
+ ret = OrderedDict()
131
+ ret["segm"] = {"AP": results["allAp"] * 100, "AP50": results["allAp50%"] * 100}
132
+ self._working_dir.cleanup()
133
+ return ret
134
+
135
+
136
+ class CityscapesSemSegEvaluator(CityscapesEvaluator):
137
+ """
138
+ Evaluate semantic segmentation results on cityscapes dataset using cityscapes API.
139
+
140
+ Note:
141
+ * It does not work in multi-machine distributed training.
142
+ * It contains a synchronization, therefore has to be used on all ranks.
143
+ * Only the main process runs evaluation.
144
+ """
145
+
146
+ def process(self, inputs, outputs):
147
+ from cityscapesscripts.helpers.labels import trainId2label
148
+
149
+ for input, output in zip(inputs, outputs):
150
+ file_name = input["file_name"]
151
+ basename = os.path.splitext(os.path.basename(file_name))[0]
152
+ pred_filename = os.path.join(self._temp_dir, basename + "_pred.png")
153
+
154
+ output = output["sem_seg"].argmax(dim=0).to(self._cpu_device).numpy()
155
+ pred = 255 * np.ones(output.shape, dtype=np.uint8)
156
+ for train_id, label in trainId2label.items():
157
+ if label.ignoreInEval:
158
+ continue
159
+ pred[output == train_id] = label.id
160
+ Image.fromarray(pred).save(pred_filename)
161
+
162
+ def evaluate(self):
163
+ comm.synchronize()
164
+ if comm.get_rank() > 0:
165
+ return
166
+ # Load the Cityscapes eval script *after* setting the required env var,
167
+ # since the script reads CITYSCAPES_DATASET into global variables at load time.
168
+ import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as cityscapes_eval
169
+
170
+ self._logger.info("Evaluating results under {} ...".format(self._temp_dir))
171
+
172
+ # set some global states in cityscapes evaluation API, before evaluating
173
+ cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir)
174
+ cityscapes_eval.args.predictionWalk = None
175
+ cityscapes_eval.args.JSONOutput = False
176
+ cityscapes_eval.args.colorized = False
177
+
178
+ # These lines are adopted from
179
+ # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalPixelLevelSemanticLabeling.py # noqa
180
+ gt_dir = PathManager.get_local_path(self._metadata.gt_dir)
181
+ groundTruthImgList = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_labelIds.png"))
182
+ assert len(
183
+ groundTruthImgList
184
+ ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
185
+ cityscapes_eval.args.groundTruthSearch
186
+ )
187
+ predictionImgList = []
188
+ for gt in groundTruthImgList:
189
+ predictionImgList.append(cityscapes_eval.getPrediction(cityscapes_eval.args, gt))
190
+ results = cityscapes_eval.evaluateImgLists(
191
+ predictionImgList, groundTruthImgList, cityscapes_eval.args
192
+ )
193
+ ret = OrderedDict()
194
+ ret["sem_seg"] = {
195
+ "IoU": 100.0 * results["averageScoreClasses"],
196
+ "iIoU": 100.0 * results["averageScoreInstClasses"],
197
+ "IoU_sup": 100.0 * results["averageScoreCategories"],
198
+ "iIoU_sup": 100.0 * results["averageScoreInstCategories"],
199
+ }
200
+ self._working_dir.cleanup()
201
+ return ret
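Both evaluators in this file follow detectron2's DatasetEvaluator protocol: `reset()` creates a temporary results directory, `process(inputs, outputs)` writes per-image prediction files in the cityscapesScripts layout, and `evaluate()` (main process only) runs the cityscapesScripts tooling over them. A hedged sketch of that driving loop, with the model, data loader, and dataset name as placeholders rather than values from this repository:

import torch

def run_cityscapes_instance_eval(model, data_loader, dataset_name="cityscapes_fine_instance_seg_val"):
    # Placeholder loop showing the reset -> process -> evaluate cycle.
    evaluator = CityscapesInstanceEvaluator(dataset_name)
    evaluator.reset()
    with torch.no_grad():
        for inputs in data_loader:
            outputs = model(inputs)            # each output dict is expected to carry "instances"
            evaluator.process(inputs, outputs)
    return evaluator.evaluate()                # {"segm": {"AP": ..., "AP50": ...}} on the main process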
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/coco_evaluator.py ADDED
@@ -0,0 +1,563 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/coco_evaluation.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import contextlib
7
+ import copy
8
+ import io
9
+ import itertools
10
+ import json
11
+ import logging
12
+ import numpy as np
13
+ import os
14
+ import pickle
15
+ from collections import OrderedDict
16
+ import annotator.oneformer.pycocotools.mask as mask_util
17
+ import torch
18
+ from annotator.oneformer.pycocotools.coco import COCO
19
+ from annotator.oneformer.pycocotools.cocoeval import COCOeval
20
+ from tabulate import tabulate
21
+
22
+ import annotator.oneformer.detectron2.utils.comm as comm
23
+ from annotator.oneformer.detectron2.config import CfgNode
24
+ from annotator.oneformer.detectron2.data import MetadataCatalog
25
+ from annotator.oneformer.detectron2.data.datasets.coco import convert_to_coco_json
26
+ from annotator.oneformer.detectron2.structures import Boxes, BoxMode, pairwise_iou
27
+ from annotator.oneformer.detectron2.utils.file_io import PathManager
28
+ from annotator.oneformer.detectron2.utils.logger import create_small_table
29
+
30
+ from .evaluator import DatasetEvaluator
31
+
32
+ try:
33
+ from annotator.oneformer.detectron2.evaluation.fast_eval_api import COCOeval_opt
34
+ except ImportError:
35
+ COCOeval_opt = COCOeval
36
+
37
+
38
+ class COCOEvaluator(DatasetEvaluator):
39
+ """
40
+ Evaluate AP for instance detection/segmentation, AP
41
+ for keypoint detection outputs using COCO's metrics.
42
+ See http://cocodataset.org/#detection-eval and
43
+ http://cocodataset.org/#keypoints-eval to understand its metrics.
44
+ The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
45
+ the metric cannot be computed (e.g. due to no predictions made).
46
+
47
+ In addition to COCO, this evaluator is able to support any bounding box detection,
48
+ instance segmentation, or keypoint detection dataset.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ dataset_name,
54
+ tasks=None,
55
+ distributed=True,
56
+ output_dir=None,
57
+ *,
58
+ max_dets_per_image=None,
59
+ use_fast_impl=True,
60
+ kpt_oks_sigmas=(),
61
+ allow_cached_coco=True,
62
+ ):
63
+ """
64
+ Args:
65
+ dataset_name (str): name of the dataset to be evaluated.
66
+ It must have either the following corresponding metadata:
67
+
68
+ "json_file": the path to the COCO format annotation
69
+
70
+ Or it must be in detectron2's standard dataset format
71
+ so it can be converted to COCO format automatically.
72
+ tasks (tuple[str]): tasks that can be evaluated under the given
73
+ configuration. A task is one of "bbox", "segm", "keypoints".
74
+ By default, will infer this automatically from predictions.
75
+ distributed (bool): if True, will collect results from all ranks and run evaluation
76
+ in the main process.
77
+ Otherwise, will only evaluate the results in the current process.
78
+ output_dir (str): optional, an output directory to dump all
79
+ results predicted on the dataset. The dump contains two files:
80
+
81
+ 1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
82
+ contains all the results in the format they are produced by the model.
83
+ 2. "coco_instances_results.json" a json file in COCO's result format.
84
+ max_dets_per_image (int): limit on the maximum number of detections per image.
85
+ By default in COCO, this limit is 100, but this can be customized
86
+ to be greater, as is needed in evaluation metrics AP fixed and AP pool
87
+ (see https://arxiv.org/pdf/2102.01066.pdf)
88
+ This doesn't affect keypoint evaluation.
89
+ use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
90
+ Although the results should be very close to the official implementation in COCO
91
+ API, it is still recommended to compute results with the official API for use in
92
+ papers. The faster implementation also uses more RAM.
93
+ kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
94
+ See http://cocodataset.org/#keypoints-eval
95
+ When empty, it will use the defaults in COCO.
96
+ Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
97
+ allow_cached_coco (bool): Whether to use cached coco json from previous validation
98
+ runs. You should set this to False if you need to use different validation data.
99
+ Defaults to True.
100
+ """
101
+ self._logger = logging.getLogger(__name__)
102
+ self._distributed = distributed
103
+ self._output_dir = output_dir
104
+
105
+ if use_fast_impl and (COCOeval_opt is COCOeval):
106
+ self._logger.info("Fast COCO eval is not built. Falling back to official COCO eval.")
107
+ use_fast_impl = False
108
+ self._use_fast_impl = use_fast_impl
109
+
110
+ # COCOeval requires the limit on the number of detections per image (maxDets) to be a list
111
+ # with at least 3 elements. The default maxDets in COCOeval is [1, 10, 100], in which the
112
+ # 3rd element (100) is used as the limit on the number of detections per image when
113
+ # evaluating AP. COCOEvaluator expects an integer for max_dets_per_image, so for COCOeval,
114
+ # we reformat max_dets_per_image into [1, 10, max_dets_per_image], based on the defaults.
115
+ if max_dets_per_image is None:
116
+ max_dets_per_image = [1, 10, 100]
117
+ else:
118
+ max_dets_per_image = [1, 10, max_dets_per_image]
119
+ self._max_dets_per_image = max_dets_per_image
120
+
121
+ if tasks is not None and isinstance(tasks, CfgNode):
122
+ kpt_oks_sigmas = (
123
+ tasks.TEST.KEYPOINT_OKS_SIGMAS if not kpt_oks_sigmas else kpt_oks_sigmas
124
+ )
125
+ self._logger.warn(
126
+ "COCO Evaluator instantiated using config, this is deprecated behavior."
127
+ " Please pass in explicit arguments instead."
128
+ )
129
+ self._tasks = None # Inferring it from predictions should be better
130
+ else:
131
+ self._tasks = tasks
132
+
133
+ self._cpu_device = torch.device("cpu")
134
+
135
+ self._metadata = MetadataCatalog.get(dataset_name)
136
+ if not hasattr(self._metadata, "json_file"):
137
+ if output_dir is None:
138
+ raise ValueError(
139
+ "output_dir must be provided to COCOEvaluator "
140
+ "for datasets not in COCO format."
141
+ )
142
+ self._logger.info(f"Trying to convert '{dataset_name}' to COCO format ...")
143
+
144
+ cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json")
145
+ self._metadata.json_file = cache_path
146
+ convert_to_coco_json(dataset_name, cache_path, allow_cached=allow_cached_coco)
147
+
148
+ json_file = PathManager.get_local_path(self._metadata.json_file)
149
+ with contextlib.redirect_stdout(io.StringIO()):
150
+ self._coco_api = COCO(json_file)
151
+
152
+ # Test set json files do not contain annotations (evaluation must be
153
+ # performed using the COCO evaluation server).
154
+ self._do_evaluation = "annotations" in self._coco_api.dataset
155
+ if self._do_evaluation:
156
+ self._kpt_oks_sigmas = kpt_oks_sigmas
157
+
158
+ def reset(self):
159
+ self._predictions = []
160
+
161
+ def process(self, inputs, outputs):
162
+ """
163
+ Args:
164
+ inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
165
+ It is a list of dict. Each dict corresponds to an image and
166
+ contains keys like "height", "width", "file_name", "image_id".
167
+ outputs: the outputs of a COCO model. It is a list of dicts with key
168
+ "instances" that contains :class:`Instances`.
169
+ """
170
+ for input, output in zip(inputs, outputs):
171
+ prediction = {"image_id": input["image_id"]}
172
+
173
+ if "instances" in output:
174
+ instances = output["instances"].to(self._cpu_device)
175
+ prediction["instances"] = instances_to_coco_json(instances, input["image_id"])
176
+ if len(prediction) > 1:
177
+ self._predictions.append(prediction)
178
+
179
+ def evaluate(self, img_ids=None):
180
+ """
181
+ Args:
182
+ img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
183
+ """
184
+ if self._distributed:
185
+ comm.synchronize()
186
+ predictions = comm.gather(self._predictions, dst=0)
187
+ predictions = list(itertools.chain(*predictions))
188
+
189
+ if not comm.is_main_process():
190
+ return {}
191
+ else:
192
+ predictions = self._predictions
193
+
194
+ if len(predictions) == 0:
195
+ self._logger.warning("[COCOEvaluator] Did not receive valid predictions.")
196
+ return {}
197
+
198
+ if self._output_dir:
199
+ PathManager.mkdirs(self._output_dir)
200
+ file_path = os.path.join(self._output_dir, "instances_predictions.pth")
201
+ with PathManager.open(file_path, "wb") as f:
202
+ torch.save(predictions, f)
203
+
204
+ self._results = OrderedDict()
205
+ if "instances" in predictions[0]:
206
+ self._eval_predictions(predictions, img_ids=img_ids)
207
+ # Copy so the caller can do whatever with results
208
+ return copy.deepcopy(self._results)
209
+
210
+ def _tasks_from_predictions(self, predictions):
211
+ """
212
+ Get COCO API "tasks" (i.e. iou_type) from COCO-format predictions.
213
+ """
214
+ tasks = set() # initialize before the loop so `tasks` is always bound
+ for pred in predictions:
+ if "segmentation" in pred:
+ tasks.add("segm")
+ if "keypoints" in pred:
+ tasks.add("keypoints")
219
+ return sorted(tasks)
220
+
221
+ def _eval_predictions(self, predictions, img_ids=None):
222
+ """
223
+ Evaluate predictions. Fill self._results with the metrics of the tasks.
224
+ """
225
+ self._logger.info("Preparing results for COCO format ...")
226
+ coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
227
+ tasks = self._tasks or self._tasks_from_predictions(coco_results)
228
+
229
+ # unmap the category ids for COCO
230
+ if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
231
+ dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
232
+ all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
233
+ num_classes = len(all_contiguous_ids)
234
+ assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
235
+
236
+ reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
237
+ for result in coco_results:
238
+ category_id = result["category_id"]
239
+ assert category_id < num_classes, (
240
+ f"A prediction has class={category_id}, "
241
+ f"but the dataset only has {num_classes} classes and "
242
+ f"predicted class id should be in [0, {num_classes - 1}]."
243
+ )
244
+ result["category_id"] = reverse_id_mapping[category_id]
245
+
246
+ if self._output_dir:
247
+ file_path = os.path.join(self._output_dir, "coco_instances_results.json")
248
+ self._logger.info("Saving results to {}".format(file_path))
249
+ with PathManager.open(file_path, "w") as f:
250
+ f.write(json.dumps(coco_results))
251
+ f.flush()
252
+
253
+ if not self._do_evaluation:
254
+ self._logger.info("Annotations are not available for evaluation.")
255
+ return
256
+
257
+ self._logger.info(
258
+ "Evaluating predictions with {} COCO API...".format(
259
+ "unofficial" if self._use_fast_impl else "official"
260
+ )
261
+ )
262
+ for task in sorted(tasks):
263
+ assert task in {"segm", "keypoints"}, f"Got unknown task: {task}!"
264
+ coco_eval = (
265
+ _evaluate_predictions_on_coco(
266
+ self._coco_api,
267
+ coco_results,
268
+ task,
269
+ kpt_oks_sigmas=self._kpt_oks_sigmas,
270
+ use_fast_impl=self._use_fast_impl,
271
+ img_ids=img_ids,
272
+ max_dets_per_image=self._max_dets_per_image,
273
+ )
274
+ if len(coco_results) > 0
275
+ else None # cocoapi does not handle empty results very well
276
+ )
277
+
278
+ res = self._derive_coco_results(
279
+ coco_eval, task, class_names=self._metadata.get("thing_classes")
280
+ )
281
+ self._results[task] = res
282
+
283
+ def _derive_coco_results(self, coco_eval, iou_type, class_names=None):
284
+ """
285
+ Derive the desired score numbers from summarized COCOeval.
286
+
287
+ Args:
288
+ coco_eval (None or COCOEval): None represents no predictions from model.
289
+ iou_type (str):
290
+ class_names (None or list[str]): if provided, will use it to predict
291
+ per-category AP.
292
+
293
+ Returns:
294
+ a dict of {metric name: score}
295
+ """
296
+
297
+ metrics = {
298
+ "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
299
+ "keypoints": ["AP", "AP50", "AP75", "APm", "APl"],
300
+ }[iou_type]
301
+
302
+ if coco_eval is None:
303
+ self._logger.warn("No predictions from the model!")
304
+ return {metric: float("nan") for metric in metrics}
305
+
306
+ # the standard metrics
307
+ results = {
308
+ metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan")
309
+ for idx, metric in enumerate(metrics)
310
+ }
311
+ self._logger.info(
312
+ "Evaluation results for {}: \n".format(iou_type) + create_small_table(results)
313
+ )
314
+ if not np.isfinite(sum(results.values())):
315
+ self._logger.info("Some metrics cannot be computed and are shown as NaN.")
316
+
317
+ if class_names is None or len(class_names) <= 1:
318
+ return results
319
+ # Compute per-category AP
320
+ # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa
321
+ precisions = coco_eval.eval["precision"]
322
+ # precision has dims (iou, recall, cls, area range, max dets)
323
+ assert len(class_names) == precisions.shape[2]
324
+
325
+ results_per_category = []
326
+ for idx, name in enumerate(class_names):
327
+ # area range index 0: all area ranges
328
+ # max dets index -1: typically 100 per image
329
+ precision = precisions[:, :, idx, 0, -1]
330
+ precision = precision[precision > -1]
331
+ ap = np.mean(precision) if precision.size else float("nan")
332
+ results_per_category.append(("{}".format(name), float(ap * 100)))
333
+
334
+ # tabulate it
335
+ N_COLS = min(6, len(results_per_category) * 2)
336
+ results_flatten = list(itertools.chain(*results_per_category))
337
+ results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
338
+ table = tabulate(
339
+ results_2d,
340
+ tablefmt="pipe",
341
+ floatfmt=".3f",
342
+ headers=["category", "AP"] * (N_COLS // 2),
343
+ numalign="left",
344
+ )
345
+ self._logger.info("Per-category {} AP: \n".format(iou_type) + table)
346
+
347
+ results.update({"AP-" + name: ap for name, ap in results_per_category})
348
+ return results
349
+
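`_derive_coco_results` reads per-category AP directly from `coco_eval.eval["precision"]`, a 5-D array indexed as (iou threshold, recall threshold, class, area range, max detections); averaging the valid (> -1) entries of the slice `[:, :, cls, 0, -1]` gives the per-class AP over all areas at the largest maxDets. A toy sketch of that reduction on a random array (shapes follow COCO's defaults of 10 IoU thresholds and 101 recall points; the values are synthetic):

import numpy as np

T, R, K, A, M = 10, 101, 3, 4, 3            # iou thresholds, recall points, classes, area ranges, maxDets
precisions = np.random.rand(T, R, K, A, M)
precisions[precisions < 0.05] = -1          # COCOeval marks undefined entries with -1

per_class_ap = []
for cls in range(K):
    p = precisions[:, :, cls, 0, -1]        # all areas, largest maxDets
    p = p[p > -1]
    per_class_ap.append(100 * p.mean() if p.size else float("nan"))
print(per_class_ap)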
350
+
351
+ def instances_to_coco_json(instances, img_id):
352
+ """
353
+ Dump an "Instances" object to a COCO-format json that's used for evaluation.
354
+
355
+ Args:
356
+ instances (Instances):
357
+ img_id (int): the image id
358
+
359
+ Returns:
360
+ list[dict]: list of json annotations in COCO format.
361
+ """
362
+ num_instance = len(instances)
363
+ if num_instance == 0:
364
+ return []
365
+
366
+ scores = instances.scores.tolist()
367
+ classes = instances.pred_classes.tolist()
368
+
369
+ has_mask = instances.has("pred_masks")
370
+ if has_mask:
371
+ # use RLE to encode the masks, because they are too large and takes memory
372
+ # since this evaluator stores outputs of the entire dataset
373
+ rles = [
374
+ mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
375
+ for mask in instances.pred_masks
376
+ ]
377
+ for rle in rles:
378
+ # "counts" is an array encoded by mask_util as a byte-stream. Python3's
379
+ # json writer which always produces strings cannot serialize a bytestream
380
+ # unless you decode it. Thankfully, utf-8 works out (which is also what
381
+ # the annotator.oneformer.pycocotools/_mask.pyx does).
382
+ rle["counts"] = rle["counts"].decode("utf-8")
383
+
384
+ has_keypoints = instances.has("pred_keypoints")
385
+ if has_keypoints:
386
+ keypoints = instances.pred_keypoints
387
+
388
+ results = []
389
+ for k in range(num_instance):
390
+ result = {
391
+ "image_id": img_id,
392
+ "category_id": classes[k],
393
+ "score": scores[k],
394
+ }
395
+ if has_mask:
396
+ result["segmentation"] = rles[k]
397
+ if has_keypoints:
398
+ # In COCO annotations,
399
+ # keypoints coordinates are pixel indices.
400
+ # However our predictions are floating point coordinates.
401
+ # Therefore we subtract 0.5 to be consistent with the annotation format.
402
+ # This is the inverse of data loading logic in `datasets/coco.py`.
403
+ keypoints[k][:, :2] -= 0.5
404
+ result["keypoints"] = keypoints[k].flatten().tolist()
405
+ results.append(result)
406
+ return results
407
+
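The mask branch of `instances_to_coco_json` stores each mask as a compressed RLE instead of a dense array and decodes the byte-string "counts" field so the result stays JSON-serializable. A standalone sketch of that round trip against plain pycocotools (this repository vendors the same module under annotator.oneformer.pycocotools; the toy mask is illustrative):

import json
import numpy as np
import pycocotools.mask as mask_util

mask = np.zeros((32, 32), dtype=np.uint8)
mask[8:24, 8:24] = 1

rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")   # bytes -> str so json.dumps can serialize it
print(json.dumps({"image_id": 1, "category_id": 0, "score": 0.9, "segmentation": rle}))

# Decoding recovers the original binary mask.
decoded = mask_util.decode([{"size": rle["size"], "counts": rle["counts"].encode("utf-8")}])[:, :, 0]
assert (decoded == mask).all()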
408
+ def _evaluate_predictions_on_coco(
409
+ coco_gt,
410
+ coco_results,
411
+ iou_type,
412
+ kpt_oks_sigmas=None,
413
+ use_fast_impl=True,
414
+ img_ids=None,
415
+ max_dets_per_image=None,
416
+ ):
417
+ """
418
+ Evaluate the coco results using COCOEval API.
419
+ """
420
+ assert len(coco_results) > 0
421
+
422
+ if iou_type == "segm":
423
+ coco_results = copy.deepcopy(coco_results)
424
+ # When evaluating mask AP, if the results contain bbox, cocoapi will
425
+ # use the box area as the area of the instance, instead of the mask area.
426
+ # This leads to a different definition of small/medium/large.
427
+ # We remove the bbox field to let mask AP use mask area.
428
+ for c in coco_results:
429
+ c.pop("bbox", None)
430
+
431
+ coco_dt = coco_gt.loadRes(coco_results)
432
+ coco_eval = (COCOeval_opt if use_fast_impl else COCOeval)(coco_gt, coco_dt, iou_type)
433
+ # For COCO, the default max_dets_per_image is [1, 10, 100].
434
+ if max_dets_per_image is None:
435
+ max_dets_per_image = [1, 10, 100] # Default from COCOEval
436
+ else:
437
+ assert (
438
+ len(max_dets_per_image) >= 3
439
+ ), "COCOeval requires maxDets (and max_dets_per_image) to have length at least 3"
440
+ # In the case that user supplies a custom input for max_dets_per_image,
441
+ # apply COCOevalMaxDets to evaluate AP with the custom input.
442
+ if max_dets_per_image[2] != 100:
443
+ coco_eval = COCOevalMaxDets(coco_gt, coco_dt, iou_type)
444
+ if iou_type != "keypoints":
445
+ coco_eval.params.maxDets = max_dets_per_image
446
+
447
+ if img_ids is not None:
448
+ coco_eval.params.imgIds = img_ids
449
+
450
+ if iou_type == "keypoints":
451
+ # Use the COCO default keypoint OKS sigmas unless overrides are specified
452
+ if kpt_oks_sigmas:
453
+ assert hasattr(coco_eval.params, "kpt_oks_sigmas"), "annotator.oneformer.pycocotools is too old!"
454
+ coco_eval.params.kpt_oks_sigmas = np.array(kpt_oks_sigmas)
455
+ # COCOAPI requires every detection and every gt to have keypoints, so
456
+ # we just take the first entry from both
457
+ num_keypoints_dt = len(coco_results[0]["keypoints"]) // 3
458
+ num_keypoints_gt = len(next(iter(coco_gt.anns.values()))["keypoints"]) // 3
459
+ num_keypoints_oks = len(coco_eval.params.kpt_oks_sigmas)
460
+ assert num_keypoints_oks == num_keypoints_dt == num_keypoints_gt, (
461
+ f"[COCOEvaluator] Prediction contain {num_keypoints_dt} keypoints. "
462
+ f"Ground truth contains {num_keypoints_gt} keypoints. "
463
+ f"The length of cfg.TEST.KEYPOINT_OKS_SIGMAS is {num_keypoints_oks}. "
464
+ "They have to agree with each other. For meaning of OKS, please refer to "
465
+ "http://cocodataset.org/#keypoints-eval."
466
+ )
467
+
468
+ coco_eval.evaluate()
469
+ coco_eval.accumulate()
470
+ coco_eval.summarize()
471
+
472
+ return coco_eval
473
+
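`_evaluate_predictions_on_coco` wraps the standard three-step COCOeval flow: load the detections into a COCO results object with `loadRes`, configure `params` (iou type, maxDets, image ids, OKS sigmas), then call `evaluate()`, `accumulate()`, and `summarize()`. A hedged standalone sketch of the same flow against plain pycocotools (both file paths are placeholders):

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO("annotations/instances_val2017.json")       # placeholder ground-truth file
coco_dt = coco_gt.loadRes("coco_instances_results.json")   # placeholder results file

coco_eval = COCOeval(coco_gt, coco_dt, iouType="segm")
coco_eval.params.maxDets = [1, 10, 100]   # COCOeval expects a list of at least three limits
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()                     # prints the AP/AR table; numbers land in coco_eval.stats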
474
+
475
+ class COCOevalMaxDets(COCOeval):
476
+ """
477
+ Modified version of COCOeval for evaluating AP with a custom
478
+ maxDets (by default for COCO, maxDets is 100)
479
+ """
480
+
481
+ def summarize(self):
482
+ """
483
+ Compute and display summary metrics for evaluation results given
484
+ a custom value for max_dets_per_image
485
+ """
486
+
487
+ def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100):
488
+ p = self.params
489
+ iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
490
+ titleStr = "Average Precision" if ap == 1 else "Average Recall"
491
+ typeStr = "(AP)" if ap == 1 else "(AR)"
492
+ iouStr = (
493
+ "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
494
+ if iouThr is None
495
+ else "{:0.2f}".format(iouThr)
496
+ )
497
+
498
+ aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
499
+ mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
500
+ if ap == 1:
501
+ # dimension of precision: [TxRxKxAxM]
502
+ s = self.eval["precision"]
503
+ # IoU
504
+ if iouThr is not None:
505
+ t = np.where(iouThr == p.iouThrs)[0]
506
+ s = s[t]
507
+ s = s[:, :, :, aind, mind]
508
+ else:
509
+ # dimension of recall: [TxKxAxM]
510
+ s = self.eval["recall"]
511
+ if iouThr is not None:
512
+ t = np.where(iouThr == p.iouThrs)[0]
513
+ s = s[t]
514
+ s = s[:, :, aind, mind]
515
+ if len(s[s > -1]) == 0:
516
+ mean_s = -1
517
+ else:
518
+ mean_s = np.mean(s[s > -1])
519
+ print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
520
+ return mean_s
521
+
522
+ def _summarizeDets():
523
+ stats = np.zeros((12,))
524
+ # Evaluate AP using the custom limit on maximum detections per image
525
+ stats[0] = _summarize(1, maxDets=self.params.maxDets[2])
526
+ stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2])
527
+ stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2])
528
+ stats[3] = _summarize(1, areaRng="small", maxDets=self.params.maxDets[2])
529
+ stats[4] = _summarize(1, areaRng="medium", maxDets=self.params.maxDets[2])
530
+ stats[5] = _summarize(1, areaRng="large", maxDets=self.params.maxDets[2])
531
+ stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
532
+ stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
533
+ stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
534
+ stats[9] = _summarize(0, areaRng="small", maxDets=self.params.maxDets[2])
535
+ stats[10] = _summarize(0, areaRng="medium", maxDets=self.params.maxDets[2])
536
+ stats[11] = _summarize(0, areaRng="large", maxDets=self.params.maxDets[2])
537
+ return stats
538
+
539
+ def _summarizeKps():
540
+ stats = np.zeros((10,))
541
+ stats[0] = _summarize(1, maxDets=20)
542
+ stats[1] = _summarize(1, maxDets=20, iouThr=0.5)
543
+ stats[2] = _summarize(1, maxDets=20, iouThr=0.75)
544
+ stats[3] = _summarize(1, maxDets=20, areaRng="medium")
545
+ stats[4] = _summarize(1, maxDets=20, areaRng="large")
546
+ stats[5] = _summarize(0, maxDets=20)
547
+ stats[6] = _summarize(0, maxDets=20, iouThr=0.5)
548
+ stats[7] = _summarize(0, maxDets=20, iouThr=0.75)
549
+ stats[8] = _summarize(0, maxDets=20, areaRng="medium")
550
+ stats[9] = _summarize(0, maxDets=20, areaRng="large")
551
+ return stats
552
+
553
+ if not self.eval:
554
+ raise Exception("Please run accumulate() first")
555
+ iouType = self.params.iouType
556
+ if iouType == "segm":
557
+ summarize = _summarizeDets
558
+ elif iouType == "keypoints":
559
+ summarize = _summarizeKps
560
+ self.stats = summarize()
561
+
562
+ def __str__(self):
563
+ self.summarize()
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/detection_coco_evaluator.py ADDED
@@ -0,0 +1,723 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/coco_evaluation.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import contextlib
7
+ import copy
8
+ import io
9
+ import itertools
10
+ import json
11
+ import logging
12
+ import numpy as np
13
+ import os
14
+ import pickle
15
+ from collections import OrderedDict
16
+ import annotator.oneformer.pycocotools.mask as mask_util
17
+ import torch
18
+ from annotator.oneformer.pycocotools.coco import COCO
19
+ from annotator.oneformer.pycocotools.cocoeval import COCOeval
20
+ from tabulate import tabulate
21
+
22
+ import annotator.oneformer.detectron2.utils.comm as comm
23
+ from annotator.oneformer.detectron2.config import CfgNode
24
+ from annotator.oneformer.detectron2.data import MetadataCatalog
25
+ from annotator.oneformer.detectron2.data.datasets.coco import convert_to_coco_json
26
+ from annotator.oneformer.detectron2.structures import Boxes, BoxMode, pairwise_iou
27
+ from annotator.oneformer.detectron2.utils.file_io import PathManager
28
+ from annotator.oneformer.detectron2.utils.logger import create_small_table
29
+
30
+ from .evaluator import DatasetEvaluator
31
+
32
+ try:
33
+ from annotator.oneformer.detectron2.evaluation.fast_eval_api import COCOeval_opt
34
+ except ImportError:
35
+ COCOeval_opt = COCOeval
36
+
37
+
38
+ class DetectionCOCOEvaluator(DatasetEvaluator):
39
+ """
40
+ Evaluate AR for object proposals, AP for instance detection/segmentation, AP
41
+ for keypoint detection outputs using COCO's metrics.
42
+ See http://cocodataset.org/#detection-eval and
43
+ http://cocodataset.org/#keypoints-eval to understand its metrics.
44
+ The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
45
+ the metric cannot be computed (e.g. due to no predictions made).
46
+
47
+ In addition to COCO, this evaluator is able to support any bounding box detection,
48
+ instance segmentation, or keypoint detection dataset.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ dataset_name,
54
+ tasks=None,
55
+ distributed=True,
56
+ output_dir=None,
57
+ *,
58
+ max_dets_per_image=None,
59
+ use_fast_impl=True,
60
+ kpt_oks_sigmas=(),
61
+ allow_cached_coco=True,
62
+ ):
63
+ """
64
+ Args:
65
+ dataset_name (str): name of the dataset to be evaluated.
66
+ It must have either the following corresponding metadata:
67
+
68
+ "json_file": the path to the COCO format annotation
69
+
70
+ Or it must be in detectron2's standard dataset format
71
+ so it can be converted to COCO format automatically.
72
+ tasks (tuple[str]): tasks that can be evaluated under the given
73
+ configuration. A task is one of "bbox", "segm", "keypoints".
74
+ By default, will infer this automatically from predictions.
75
+ distributed (bool): if True, will collect results from all ranks and run evaluation
76
+ in the main process.
77
+ Otherwise, will only evaluate the results in the current process.
78
+ output_dir (str): optional, an output directory to dump all
79
+ results predicted on the dataset. The dump contains two files:
80
+
81
+ 1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
82
+ contains all the results in the format they are produced by the model.
83
+ 2. "coco_instances_results.json" a json file in COCO's result format.
84
+ max_dets_per_image (int): limit on the maximum number of detections per image.
85
+ By default in COCO, this limit is 100, but this can be customized
86
+ to be greater, as is needed in evaluation metrics AP fixed and AP pool
87
+ (see https://arxiv.org/pdf/2102.01066.pdf)
88
+ This doesn't affect keypoint evaluation.
89
+ use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
90
+ Although the results should be very close to the official implementation in COCO
91
+ API, it is still recommended to compute results with the official API for use in
92
+ papers. The faster implementation also uses more RAM.
93
+ kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
94
+ See http://cocodataset.org/#keypoints-eval
95
+ When empty, it will use the defaults in COCO.
96
+ Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
97
+ allow_cached_coco (bool): Whether to use cached coco json from previous validation
98
+ runs. You should set this to False if you need to use different validation data.
99
+ Defaults to True.
100
+ """
101
+ self._logger = logging.getLogger(__name__)
102
+ self._distributed = distributed
103
+ self._output_dir = output_dir
104
+
105
+ if use_fast_impl and (COCOeval_opt is COCOeval):
106
+ self._logger.info("Fast COCO eval is not built. Falling back to official COCO eval.")
107
+ use_fast_impl = False
108
+ self._use_fast_impl = use_fast_impl
109
+
110
+ # COCOeval requires the limit on the number of detections per image (maxDets) to be a list
111
+ # with at least 3 elements. The default maxDets in COCOeval is [1, 10, 100], in which the
112
+ # 3rd element (100) is used as the limit on the number of detections per image when
113
+ # evaluating AP. COCOEvaluator expects an integer for max_dets_per_image, so for COCOeval,
114
+ # we reformat max_dets_per_image into [1, 10, max_dets_per_image], based on the defaults.
115
+ if max_dets_per_image is None:
116
+ max_dets_per_image = [1, 10, 100]
117
+ else:
118
+ max_dets_per_image = [1, 10, max_dets_per_image]
119
+ self._max_dets_per_image = max_dets_per_image
120
+
121
+ if tasks is not None and isinstance(tasks, CfgNode):
122
+ kpt_oks_sigmas = (
123
+ tasks.TEST.KEYPOINT_OKS_SIGMAS if not kpt_oks_sigmas else kpt_oks_sigmas
124
+ )
125
+ self._logger.warn(
126
+ "COCO Evaluator instantiated using config, this is deprecated behavior."
127
+ " Please pass in explicit arguments instead."
128
+ )
129
+ self._tasks = None # Inferring it from predictions should be better
130
+ else:
131
+ self._tasks = tasks
132
+
133
+ self._cpu_device = torch.device("cpu")
134
+
135
+ self._metadata = MetadataCatalog.get(dataset_name)
136
+ if not hasattr(self._metadata, "json_file"):
137
+ if output_dir is None:
138
+ raise ValueError(
139
+ "output_dir must be provided to COCOEvaluator "
140
+ "for datasets not in COCO format."
141
+ )
142
+ self._logger.info(f"Trying to convert '{dataset_name}' to COCO format ...")
143
+
144
+ cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json")
145
+ self._metadata.json_file = cache_path
146
+ convert_to_coco_json(dataset_name, cache_path, allow_cached=allow_cached_coco)
147
+
148
+ json_file = PathManager.get_local_path(self._metadata.json_file)
149
+ with contextlib.redirect_stdout(io.StringIO()):
150
+ self._coco_api = COCO(json_file)
151
+
152
+ # Test set json files do not contain annotations (evaluation must be
153
+ # performed using the COCO evaluation server).
154
+ self._do_evaluation = "annotations" in self._coco_api.dataset
155
+ if self._do_evaluation:
156
+ self._kpt_oks_sigmas = kpt_oks_sigmas
157
+
158
+ def reset(self):
159
+ self._predictions = []
160
+
161
+ def process(self, inputs, outputs):
162
+ """
163
+ Args:
164
+ inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
165
+ It is a list of dict. Each dict corresponds to an image and
166
+ contains keys like "height", "width", "file_name", "image_id".
167
+ outputs: the outputs of a COCO model. It is a list of dicts with key
168
+ "box_instances" that contains :class:`Instances`.
169
+ """
170
+ for input, output in zip(inputs, outputs):
171
+ prediction = {"image_id": input["image_id"]}
172
+
173
+ if "box_instances" in output:
174
+ instances = output["box_instances"].to(self._cpu_device)
175
+ prediction["box_instances"] = instances_to_coco_json(instances, input["image_id"])
176
+ if "proposals" in output:
177
+ prediction["proposals"] = output["proposals"].to(self._cpu_device)
178
+ if len(prediction) > 1:
179
+ self._predictions.append(prediction)
180
+
181
+ def evaluate(self, img_ids=None):
182
+ """
183
+ Args:
184
+ img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
185
+ """
186
+ if self._distributed:
187
+ comm.synchronize()
188
+ predictions = comm.gather(self._predictions, dst=0)
189
+ predictions = list(itertools.chain(*predictions))
190
+
191
+ if not comm.is_main_process():
192
+ return {}
193
+ else:
194
+ predictions = self._predictions
195
+
196
+ if len(predictions) == 0:
197
+ self._logger.warning("[COCOEvaluator] Did not receive valid predictions.")
198
+ return {}
199
+
200
+ if self._output_dir:
201
+ PathManager.mkdirs(self._output_dir)
202
+ file_path = os.path.join(self._output_dir, "instances_predictions.pth")
203
+ with PathManager.open(file_path, "wb") as f:
204
+ torch.save(predictions, f)
205
+
206
+ self._results = OrderedDict()
207
+ if "proposals" in predictions[0]:
208
+ self._eval_box_proposals(predictions)
209
+ if "box_instances" in predictions[0]:
210
+ self._eval_predictions(predictions, img_ids=img_ids)
211
+ # Copy so the caller can do whatever with results
212
+ return copy.deepcopy(self._results)
213
+
214
+ def _tasks_from_predictions(self, predictions):
215
+ """
216
+ Get COCO API "tasks" (i.e. iou_type) from COCO-format predictions.
217
+ """
218
+ tasks = {"bbox"}
219
+ for pred in predictions:
220
+ if "keypoints" in pred:
221
+ tasks.add("keypoints")
222
+ return sorted(tasks)
223
+
224
+ def _eval_predictions(self, predictions, img_ids=None):
225
+ """
226
+ Evaluate predictions. Fill self._results with the metrics of the tasks.
227
+ """
228
+ self._logger.info("Preparing results for COCO format ...")
229
+ coco_results = list(itertools.chain(*[x["box_instances"] for x in predictions]))
230
+ tasks = self._tasks or self._tasks_from_predictions(coco_results)
231
+
232
+ # unmap the category ids for COCO
233
+ if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
234
+ dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
235
+ all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
236
+ num_classes = len(all_contiguous_ids)
237
+ assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
238
+
239
+ reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
240
+ for result in coco_results:
241
+ category_id = result["category_id"]
242
+ assert category_id < num_classes, (
243
+ f"A prediction has class={category_id}, "
244
+ f"but the dataset only has {num_classes} classes and "
245
+ f"predicted class id should be in [0, {num_classes - 1}]."
246
+ )
247
+ result["category_id"] = reverse_id_mapping[category_id]
248
+
249
+ if self._output_dir:
250
+ file_path = os.path.join(self._output_dir, "coco_instances_results.json")
251
+ self._logger.info("Saving results to {}".format(file_path))
252
+ with PathManager.open(file_path, "w") as f:
253
+ f.write(json.dumps(coco_results))
254
+ f.flush()
255
+
256
+ if not self._do_evaluation:
257
+ self._logger.info("Annotations are not available for evaluation.")
258
+ return
259
+
260
+ self._logger.info(
261
+ "Evaluating predictions with {} COCO API...".format(
262
+ "unofficial" if self._use_fast_impl else "official"
263
+ )
264
+ )
265
+ for task in sorted(tasks):
266
+ assert task in {"bbox", "keypoints"}, f"Got unknown task: {task}!"
267
+ coco_eval = (
268
+ _evaluate_predictions_on_coco(
269
+ self._coco_api,
270
+ coco_results,
271
+ task,
272
+ kpt_oks_sigmas=self._kpt_oks_sigmas,
273
+ use_fast_impl=self._use_fast_impl,
274
+ img_ids=img_ids,
275
+ max_dets_per_image=self._max_dets_per_image,
276
+ )
277
+ if len(coco_results) > 0
278
+ else None # cocoapi does not handle empty results very well
279
+ )
280
+
281
+ res = self._derive_coco_results(
282
+ coco_eval, task, class_names=self._metadata.get("thing_classes")
283
+ )
284
+ self._results[task] = res
285
+
286
+ def _eval_box_proposals(self, predictions):
287
+ """
288
+ Evaluate the box proposals in predictions.
289
+ Fill self._results with the metrics for "box_proposals" task.
290
+ """
291
+ if self._output_dir:
292
+ # Saving generated box proposals to file.
293
+ # Predicted box_proposals are in XYXY_ABS mode.
294
+ bbox_mode = BoxMode.XYXY_ABS.value
295
+ ids, boxes, objectness_logits = [], [], []
296
+ for prediction in predictions:
297
+ ids.append(prediction["image_id"])
298
+ boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy())
299
+ objectness_logits.append(prediction["proposals"].objectness_logits.numpy())
300
+
301
+ proposal_data = {
302
+ "boxes": boxes,
303
+ "objectness_logits": objectness_logits,
304
+ "ids": ids,
305
+ "bbox_mode": bbox_mode,
306
+ }
307
+ with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f:
308
+ pickle.dump(proposal_data, f)
309
+
310
+ if not self._do_evaluation:
311
+ self._logger.info("Annotations are not available for evaluation.")
312
+ return
313
+
314
+ self._logger.info("Evaluating bbox proposals ...")
315
+ res = {}
316
+ areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
317
+ for limit in [100, 1000]:
318
+ for area, suffix in areas.items():
319
+ stats = _evaluate_box_proposals(predictions, self._coco_api, area=area, limit=limit)
320
+ key = "AR{}@{:d}".format(suffix, limit)
321
+ res[key] = float(stats["ar"].item() * 100)
322
+ self._logger.info("Proposal metrics: \n" + create_small_table(res))
323
+ self._results["box_proposals"] = res
324
+
325
+ def _derive_coco_results(self, coco_eval, iou_type, class_names=None):
326
+ """
327
+ Derive the desired score numbers from summarized COCOeval.
328
+
329
+ Args:
330
+ coco_eval (None or COCOEval): None represents no predictions from model.
331
+ iou_type (str):
332
+ class_names (None or list[str]): if provided, will use it to compute
333
+ per-category AP.
334
+
335
+ Returns:
336
+ a dict of {metric name: score}
337
+ """
338
+
339
+ metrics = {
340
+ "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
341
+ "keypoints": ["AP", "AP50", "AP75", "APm", "APl"],
342
+ }[iou_type]
343
+
344
+ if coco_eval is None:
345
+ self._logger.warning("No predictions from the model!")
346
+ return {metric: float("nan") for metric in metrics}
347
+
348
+ # the standard metrics
349
+ results = {
350
+ metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan")
351
+ for idx, metric in enumerate(metrics)
352
+ }
353
+ self._logger.info(
354
+ "Evaluation results for {}: \n".format(iou_type) + create_small_table(results)
355
+ )
356
+ if not np.isfinite(sum(results.values())):
357
+ self._logger.info("Some metrics cannot be computed and is shown as NaN.")
358
+
359
+ if class_names is None or len(class_names) <= 1:
360
+ return results
361
+ # Compute per-category AP
362
+ # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa
363
+ precisions = coco_eval.eval["precision"]
364
+ # precision has dims (iou, recall, cls, area range, max dets)
365
+ assert len(class_names) == precisions.shape[2]
366
+
367
+ results_per_category = []
368
+ for idx, name in enumerate(class_names):
369
+ # area range index 0: all area ranges
370
+ # max dets index -1: typically 100 per image
371
+ precision = precisions[:, :, idx, 0, -1]
372
+ precision = precision[precision > -1]
373
+ ap = np.mean(precision) if precision.size else float("nan")
374
+ results_per_category.append(("{}".format(name), float(ap * 100)))
375
+
376
+ # tabulate it
377
+ N_COLS = min(6, len(results_per_category) * 2)
378
+ results_flatten = list(itertools.chain(*results_per_category))
379
+ results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
380
+ table = tabulate(
381
+ results_2d,
382
+ tablefmt="pipe",
383
+ floatfmt=".3f",
384
+ headers=["category", "AP"] * (N_COLS // 2),
385
+ numalign="left",
386
+ )
387
+ self._logger.info("Per-category {} AP: \n".format(iou_type) + table)
388
+
389
+ results.update({"AP-" + name: ap for name, ap in results_per_category})
390
+ return results
391
+
392
+
393
+ def instances_to_coco_json(instances, img_id):
394
+ """
395
+ Dump an "Instances" object to a COCO-format json that's used for evaluation.
396
+
397
+ Args:
398
+ instances (Instances):
399
+ img_id (int): the image id
400
+
401
+ Returns:
402
+ list[dict]: list of json annotations in COCO format.
403
+ """
404
+ num_instance = len(instances)
405
+ if num_instance == 0:
406
+ return []
407
+
408
+ boxes = instances.pred_boxes.tensor.numpy()
409
+ boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
410
+ boxes = boxes.tolist()
411
+ scores = instances.scores.tolist()
412
+ classes = instances.pred_classes.tolist()
413
+
414
+ has_mask = instances.has("pred_masks")
415
+ if has_mask:
416
+ # use RLE to encode the masks, because they are too large and take too much memory
417
+ # since this evaluator stores outputs of the entire dataset
418
+ rles = [
419
+ mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
420
+ for mask in instances.pred_masks
421
+ ]
422
+ for rle in rles:
423
+ # "counts" is an array encoded by mask_util as a byte-stream. Python3's
424
+ # json writer which always produces strings cannot serialize a bytestream
425
+ # unless you decode it. Thankfully, utf-8 works out (which is also what
426
+ # the annotator.oneformer.pycocotools/_mask.pyx does).
427
+ rle["counts"] = rle["counts"].decode("utf-8")
428
+
429
+ has_keypoints = instances.has("pred_keypoints")
430
+ if has_keypoints:
431
+ keypoints = instances.pred_keypoints
432
+
433
+ results = []
434
+ for k in range(num_instance):
435
+ result = {
436
+ "image_id": img_id,
437
+ "category_id": classes[k],
438
+ "bbox": boxes[k],
439
+ "score": scores[k],
440
+ }
441
+ if has_mask:
442
+ result["segmentation"] = rles[k]
443
+ if has_keypoints:
444
+ # In COCO annotations,
445
+ # keypoints coordinates are pixel indices.
446
+ # However our predictions are floating point coordinates.
447
+ # Therefore we subtract 0.5 to be consistent with the annotation format.
448
+ # This is the inverse of data loading logic in `datasets/coco.py`.
449
+ keypoints[k][:, :2] -= 0.5
450
+ result["keypoints"] = keypoints[k].flatten().tolist()
451
+ results.append(result)
452
+ return results
453
+
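# A minimal sketch (hypothetical tensors) of instances_to_coco_json above: wrap a
# single predicted box in a detectron2 Instances object and dump it to COCO dicts.
import torch
from annotator.oneformer.detectron2.structures import Boxes, Instances

inst = Instances((480, 640))                                         # (height, width)
inst.pred_boxes = Boxes(torch.tensor([[10.0, 20.0, 110.0, 220.0]]))  # XYXY_ABS
inst.scores = torch.tensor([0.9])
inst.pred_classes = torch.tensor([3])
coco_dicts = instances_to_coco_json(inst, img_id=1)
# -> one dict with "bbox" converted to XYWH_ABS: [10.0, 20.0, 100.0, 200.0]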
454
+
455
+ # inspired from Detectron:
456
+ # https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa
457
+ def _evaluate_box_proposals(dataset_predictions, coco_api, thresholds=None, area="all", limit=None):
458
+ """
459
+ Evaluate detection proposal recall metrics. This function is a much
460
+ faster alternative to the official COCO API recall evaluation code. However,
461
+ it produces slightly different results.
462
+ """
463
+ # Record max overlap value for each gt box
464
+ # Return vector of overlap values
465
+ areas = {
466
+ "all": 0,
467
+ "small": 1,
468
+ "medium": 2,
469
+ "large": 3,
470
+ "96-128": 4,
471
+ "128-256": 5,
472
+ "256-512": 6,
473
+ "512-inf": 7,
474
+ }
475
+ area_ranges = [
476
+ [0**2, 1e5**2], # all
477
+ [0**2, 32**2], # small
478
+ [32**2, 96**2], # medium
479
+ [96**2, 1e5**2], # large
480
+ [96**2, 128**2], # 96-128
481
+ [128**2, 256**2], # 128-256
482
+ [256**2, 512**2], # 256-512
483
+ [512**2, 1e5**2],
484
+ ] # 512-inf
485
+ assert area in areas, "Unknown area range: {}".format(area)
486
+ area_range = area_ranges[areas[area]]
487
+ gt_overlaps = []
488
+ num_pos = 0
489
+
490
+ for prediction_dict in dataset_predictions:
491
+ predictions = prediction_dict["proposals"]
492
+
493
+ # sort predictions in descending order
494
+ # TODO maybe remove this and make it explicit in the documentation
495
+ inds = predictions.objectness_logits.sort(descending=True)[1]
496
+ predictions = predictions[inds]
497
+
498
+ ann_ids = coco_api.getAnnIds(imgIds=prediction_dict["image_id"])
499
+ anno = coco_api.loadAnns(ann_ids)
500
+ gt_boxes = [
501
+ BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
502
+ for obj in anno
503
+ if obj["iscrowd"] == 0
504
+ ]
505
+ gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes
506
+ gt_boxes = Boxes(gt_boxes)
507
+ gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0])
508
+
509
+ if len(gt_boxes) == 0 or len(predictions) == 0:
510
+ continue
511
+
512
+ valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
513
+ gt_boxes = gt_boxes[valid_gt_inds]
514
+
515
+ num_pos += len(gt_boxes)
516
+
517
+ if len(gt_boxes) == 0:
518
+ continue
519
+
520
+ if limit is not None and len(predictions) > limit:
521
+ predictions = predictions[:limit]
522
+
523
+ overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes)
524
+
525
+ _gt_overlaps = torch.zeros(len(gt_boxes))
526
+ for j in range(min(len(predictions), len(gt_boxes))):
527
+ # find which proposal box maximally covers each gt box
528
+ # and get the iou amount of coverage for each gt box
529
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
530
+
531
+ # find which gt box is 'best' covered (i.e. 'best' = most iou)
532
+ gt_ovr, gt_ind = max_overlaps.max(dim=0)
533
+ assert gt_ovr >= 0
534
+ # find the proposal box that covers the best covered gt box
535
+ box_ind = argmax_overlaps[gt_ind]
536
+ # record the iou coverage of this gt box
537
+ _gt_overlaps[j] = overlaps[box_ind, gt_ind]
538
+ assert _gt_overlaps[j] == gt_ovr
539
+ # mark the proposal box and the gt box as used
540
+ overlaps[box_ind, :] = -1
541
+ overlaps[:, gt_ind] = -1
542
+
543
+ # append recorded iou coverage level
544
+ gt_overlaps.append(_gt_overlaps)
545
+ gt_overlaps = (
546
+ torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32)
547
+ )
548
+ gt_overlaps, _ = torch.sort(gt_overlaps)
549
+
550
+ if thresholds is None:
551
+ step = 0.05
552
+ thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
553
+ recalls = torch.zeros_like(thresholds)
554
+ # compute recall for each iou threshold
555
+ for i, t in enumerate(thresholds):
556
+ recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
557
+ # ar = 2 * np.trapz(recalls, thresholds)
558
+ ar = recalls.mean()
559
+ return {
560
+ "ar": ar,
561
+ "recalls": recalls,
562
+ "thresholds": thresholds,
563
+ "gt_overlaps": gt_overlaps,
564
+ "num_pos": num_pos,
565
+ }
566
+
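# A minimal sketch, assuming `predictions` and `coco_api` are prepared exactly as in
# _eval_box_proposals above; the returned "ar" entry is a 0-dim tensor in [0, 1].
stats = _evaluate_box_proposals(predictions, coco_api, area="all", limit=100)
ar_at_100 = float(stats["ar"].item() * 100)   # same scaling as the reported "AR@100"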
567
+
568
+ def _evaluate_predictions_on_coco(
569
+ coco_gt,
570
+ coco_results,
571
+ iou_type,
572
+ kpt_oks_sigmas=None,
573
+ use_fast_impl=True,
574
+ img_ids=None,
575
+ max_dets_per_image=None,
576
+ ):
577
+ """
578
+ Evaluate the coco results using COCOEval API.
579
+ """
580
+ assert len(coco_results) > 0
581
+
582
+ if iou_type == "segm":
583
+ coco_results = copy.deepcopy(coco_results)
584
+ # When evaluating mask AP, if the results contain bbox, cocoapi will
585
+ # use the box area as the area of the instance, instead of the mask area.
586
+ # This leads to a different definition of small/medium/large.
587
+ # We remove the bbox field to let mask AP use mask area.
588
+ for c in coco_results:
589
+ c.pop("bbox", None)
590
+
591
+ coco_dt = coco_gt.loadRes(coco_results)
592
+ coco_eval = (COCOeval_opt if use_fast_impl else COCOeval)(coco_gt, coco_dt, iou_type)
593
+ # For COCO, the default max_dets_per_image is [1, 10, 100].
594
+ if max_dets_per_image is None:
595
+ max_dets_per_image = [1, 10, 100] # Default from COCOEval
596
+ else:
597
+ assert (
598
+ len(max_dets_per_image) >= 3
599
+ ), "COCOeval requires maxDets (and max_dets_per_image) to have length at least 3"
600
+ # In the case that user supplies a custom input for max_dets_per_image,
601
+ # apply COCOevalMaxDets to evaluate AP with the custom input.
602
+ if max_dets_per_image[2] != 100:
603
+ coco_eval = COCOevalMaxDets(coco_gt, coco_dt, iou_type)
604
+ if iou_type != "keypoints":
605
+ coco_eval.params.maxDets = max_dets_per_image
606
+
607
+ if img_ids is not None:
608
+ coco_eval.params.imgIds = img_ids
609
+
610
+ if iou_type == "keypoints":
611
+ # Use the COCO default keypoint OKS sigmas unless overrides are specified
612
+ if kpt_oks_sigmas:
613
+ assert hasattr(coco_eval.params, "kpt_oks_sigmas"), "annotator.oneformer.pycocotools is too old!"
614
+ coco_eval.params.kpt_oks_sigmas = np.array(kpt_oks_sigmas)
615
+ # COCOAPI requires every detection and every gt to have keypoints, so
616
+ # we just take the first entry from both
617
+ num_keypoints_dt = len(coco_results[0]["keypoints"]) // 3
618
+ num_keypoints_gt = len(next(iter(coco_gt.anns.values()))["keypoints"]) // 3
619
+ num_keypoints_oks = len(coco_eval.params.kpt_oks_sigmas)
620
+ assert num_keypoints_oks == num_keypoints_dt == num_keypoints_gt, (
621
+ f"[COCOEvaluator] Prediction contain {num_keypoints_dt} keypoints. "
622
+ f"Ground truth contains {num_keypoints_gt} keypoints. "
623
+ f"The length of cfg.TEST.KEYPOINT_OKS_SIGMAS is {num_keypoints_oks}. "
624
+ "They have to agree with each other. For meaning of OKS, please refer to "
625
+ "http://cocodataset.org/#keypoints-eval."
626
+ )
627
+
628
+ coco_eval.evaluate()
629
+ coco_eval.accumulate()
630
+ coco_eval.summarize()
631
+
632
+ return coco_eval
633
+
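# A minimal sketch (hypothetical file name and ids) of calling the helper above
# directly: load ground truth with the vendored pycocotools and evaluate one result.
from annotator.oneformer.pycocotools.coco import COCO

coco_gt = COCO("instances_val.json")   # hypothetical COCO-format annotation file
coco_results = [
    # image_id / category_id must exist in the ground-truth file
    {"image_id": 1, "category_id": 1, "bbox": [10.0, 20.0, 100.0, 200.0], "score": 0.9},
]
coco_eval = _evaluate_predictions_on_coco(coco_gt, coco_results, "bbox")
print(coco_eval.stats[0])              # AP @ IoU=0.50:0.95, maxDets=100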
634
+
635
+ class COCOevalMaxDets(COCOeval):
636
+ """
637
+ Modified version of COCOeval for evaluating AP with a custom
638
+ maxDets (by default for COCO, maxDets is 100)
639
+ """
640
+
641
+ def summarize(self):
642
+ """
643
+ Compute and display summary metrics for evaluation results given
644
+ a custom value for max_dets_per_image
645
+ """
646
+
647
+ def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100):
648
+ p = self.params
649
+ iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
650
+ titleStr = "Average Precision" if ap == 1 else "Average Recall"
651
+ typeStr = "(AP)" if ap == 1 else "(AR)"
652
+ iouStr = (
653
+ "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
654
+ if iouThr is None
655
+ else "{:0.2f}".format(iouThr)
656
+ )
657
+
658
+ aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
659
+ mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
660
+ if ap == 1:
661
+ # dimension of precision: [TxRxKxAxM]
662
+ s = self.eval["precision"]
663
+ # IoU
664
+ if iouThr is not None:
665
+ t = np.where(iouThr == p.iouThrs)[0]
666
+ s = s[t]
667
+ s = s[:, :, :, aind, mind]
668
+ else:
669
+ # dimension of recall: [TxKxAxM]
670
+ s = self.eval["recall"]
671
+ if iouThr is not None:
672
+ t = np.where(iouThr == p.iouThrs)[0]
673
+ s = s[t]
674
+ s = s[:, :, aind, mind]
675
+ if len(s[s > -1]) == 0:
676
+ mean_s = -1
677
+ else:
678
+ mean_s = np.mean(s[s > -1])
679
+ print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
680
+ return mean_s
681
+
682
+ def _summarizeDets():
683
+ stats = np.zeros((12,))
684
+ # Evaluate AP using the custom limit on maximum detections per image
685
+ stats[0] = _summarize(1, maxDets=self.params.maxDets[2])
686
+ stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2])
687
+ stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2])
688
+ stats[3] = _summarize(1, areaRng="small", maxDets=self.params.maxDets[2])
689
+ stats[4] = _summarize(1, areaRng="medium", maxDets=self.params.maxDets[2])
690
+ stats[5] = _summarize(1, areaRng="large", maxDets=self.params.maxDets[2])
691
+ stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
692
+ stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
693
+ stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
694
+ stats[9] = _summarize(0, areaRng="small", maxDets=self.params.maxDets[2])
695
+ stats[10] = _summarize(0, areaRng="medium", maxDets=self.params.maxDets[2])
696
+ stats[11] = _summarize(0, areaRng="large", maxDets=self.params.maxDets[2])
697
+ return stats
698
+
699
+ def _summarizeKps():
700
+ stats = np.zeros((10,))
701
+ stats[0] = _summarize(1, maxDets=20)
702
+ stats[1] = _summarize(1, maxDets=20, iouThr=0.5)
703
+ stats[2] = _summarize(1, maxDets=20, iouThr=0.75)
704
+ stats[3] = _summarize(1, maxDets=20, areaRng="medium")
705
+ stats[4] = _summarize(1, maxDets=20, areaRng="large")
706
+ stats[5] = _summarize(0, maxDets=20)
707
+ stats[6] = _summarize(0, maxDets=20, iouThr=0.5)
708
+ stats[7] = _summarize(0, maxDets=20, iouThr=0.75)
709
+ stats[8] = _summarize(0, maxDets=20, areaRng="medium")
710
+ stats[9] = _summarize(0, maxDets=20, areaRng="large")
711
+ return stats
712
+
713
+ if not self.eval:
714
+ raise Exception("Please run accumulate() first")
715
+ iouType = self.params.iouType
716
+ if iouType == "segm" or iouType == "bbox":
717
+ summarize = _summarizeDets
718
+ elif iouType == "keypoints":
719
+ summarize = _summarizeKps
720
+ self.stats = summarize()
721
+
722
+ def __str__(self):
723
+ self.summarize()
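# A minimal sketch, reusing `coco_gt` / `coco_results` from the earlier sketch: a custom
# detection cap (third entry != 100) routes evaluation through COCOevalMaxDets above.
coco_eval = _evaluate_predictions_on_coco(
    coco_gt, coco_results, "bbox", max_dets_per_image=[10, 100, 500]
)
print(coco_eval.stats[0])   # AP computed with up to 500 detections per image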
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/evaluator.py ADDED
@@ -0,0 +1,228 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/evaluator.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import datetime
7
+ import logging
8
+ import time
9
+ from collections import OrderedDict, abc
10
+ from contextlib import ExitStack, contextmanager
11
+ from typing import List, Union
12
+ import torch
13
+ from torch import nn
14
+
15
+ from annotator.oneformer.detectron2.utils.comm import get_world_size, is_main_process
16
+ from annotator.oneformer.detectron2.utils.logger import log_every_n_seconds
17
+
18
+
19
+ class DatasetEvaluator:
20
+ """
21
+ Base class for a dataset evaluator.
22
+
23
+ The function :func:`inference_on_dataset` runs the model over
24
+ all samples in the dataset, and uses a DatasetEvaluator to process the inputs/outputs.
25
+
26
+ This class will accumulate information of the inputs/outputs (by :meth:`process`),
27
+ and produce evaluation results in the end (by :meth:`evaluate`).
28
+ """
29
+
30
+ def reset(self):
31
+ """
32
+ Preparation for a new round of evaluation.
33
+ Should be called before starting a round of evaluation.
34
+ """
35
+ pass
36
+
37
+ def process(self, inputs, outputs):
38
+ """
39
+ Process the pair of inputs and outputs.
40
+ If they contain batches, the pairs can be consumed one-by-one using `zip`:
41
+
42
+ .. code-block:: python
43
+
44
+ for input_, output in zip(inputs, outputs):
45
+ # do evaluation on single input/output pair
46
+ ...
47
+
48
+ Args:
49
+ inputs (list): the inputs that are used to call the model.
50
+ outputs (list): the return value of `model(inputs)`
51
+ """
52
+ pass
53
+
54
+ def evaluate(self):
55
+ """
56
+ Evaluate/summarize the performance, after processing all input/output pairs.
57
+
58
+ Returns:
59
+ dict:
60
+ A new evaluator class can return a dict of arbitrary format
61
+ as long as the user can process the results.
62
+ In our train_net.py, we expect the following format:
63
+
64
+ * key: the name of the task (e.g., bbox)
65
+ * value: a dict of {metric name: score}, e.g.: {"AP50": 80}
66
+ """
67
+ pass
68
+
69
+
70
+ class DatasetEvaluators(DatasetEvaluator):
71
+ """
72
+ Wrapper class to combine multiple :class:`DatasetEvaluator` instances.
73
+
74
+ This class dispatches every evaluation call to
75
+ all of its :class:`DatasetEvaluator`.
76
+ """
77
+
78
+ def __init__(self, evaluators):
79
+ """
80
+ Args:
81
+ evaluators (list): the evaluators to combine.
82
+ """
83
+ super().__init__()
84
+ self._evaluators = evaluators
85
+
86
+ def reset(self):
87
+ for evaluator in self._evaluators:
88
+ evaluator.reset()
89
+
90
+ def process(self, inputs, outputs):
91
+ for evaluator in self._evaluators:
92
+ evaluator.process(inputs, outputs)
93
+
94
+ def evaluate(self):
95
+ results = OrderedDict()
96
+ for evaluator in self._evaluators:
97
+ result = evaluator.evaluate()
98
+ if is_main_process() and result is not None:
99
+ for k, v in result.items():
100
+ assert (
101
+ k not in results
102
+ ), "Different evaluators produce results with the same key {}".format(k)
103
+ results[k] = v
104
+ return results
105
+
106
+
107
+ def inference_on_dataset(
108
+ model, data_loader, evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None]
109
+ ):
110
+ """
111
+ Run model on the data_loader and evaluate the metrics with evaluator.
112
+ Also benchmark the inference speed of `model.__call__` accurately.
113
+ The model will be used in eval mode.
114
+
115
+ Args:
116
+ model (callable): a callable which takes an object from
117
+ `data_loader` and returns some outputs.
118
+
119
+ If it's an nn.Module, it will be temporarily set to `eval` mode.
120
+ If you wish to evaluate a model in `training` mode instead, you can
121
+ wrap the given model and override its behavior of `.eval()` and `.train()`.
122
+ data_loader: an iterable object with a length.
123
+ The elements it generates will be the inputs to the model.
124
+ evaluator: the evaluator(s) to run. Use `None` if you only want to benchmark,
125
+ but don't want to do any evaluation.
126
+
127
+ Returns:
128
+ The return value of `evaluator.evaluate()`
129
+ """
130
+ num_devices = get_world_size()
131
+ logger = logging.getLogger(__name__)
132
+ logger.info("Start inference on {} batches".format(len(data_loader)))
133
+
134
+ total = len(data_loader) # inference data loader must have a fixed length
135
+ if evaluator is None:
136
+ # create a no-op evaluator
137
+ evaluator = DatasetEvaluators([])
138
+ if isinstance(evaluator, abc.MutableSequence):
139
+ evaluator = DatasetEvaluators(evaluator)
140
+ evaluator.reset()
141
+
142
+ num_warmup = min(5, total - 1)
143
+ start_time = time.perf_counter()
144
+ total_data_time = 0
145
+ total_compute_time = 0
146
+ total_eval_time = 0
147
+ with ExitStack() as stack:
148
+ if isinstance(model, nn.Module):
149
+ stack.enter_context(inference_context(model))
150
+ stack.enter_context(torch.no_grad())
151
+
152
+ start_data_time = time.perf_counter()
153
+ for idx, inputs in enumerate(data_loader):
154
+ total_data_time += time.perf_counter() - start_data_time
155
+ if idx == num_warmup:
156
+ start_time = time.perf_counter()
157
+ total_data_time = 0
158
+ total_compute_time = 0
159
+ total_eval_time = 0
160
+
161
+ start_compute_time = time.perf_counter()
162
+ outputs = model(inputs)
163
+ if torch.cuda.is_available():
164
+ torch.cuda.synchronize()
165
+ total_compute_time += time.perf_counter() - start_compute_time
166
+
167
+ start_eval_time = time.perf_counter()
168
+ evaluator.process(inputs, outputs)
169
+ total_eval_time += time.perf_counter() - start_eval_time
170
+
171
+ iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
172
+ data_seconds_per_iter = total_data_time / iters_after_start
173
+ compute_seconds_per_iter = total_compute_time / iters_after_start
174
+ eval_seconds_per_iter = total_eval_time / iters_after_start
175
+ total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start
176
+ if idx >= num_warmup * 2 or compute_seconds_per_iter > 5:
177
+ eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
178
+ log_every_n_seconds(
179
+ logging.INFO,
180
+ (
181
+ f"Inference done {idx + 1}/{total}. "
182
+ f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
183
+ f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
184
+ f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
185
+ f"Total: {total_seconds_per_iter:.4f} s/iter. "
186
+ f"ETA={eta}"
187
+ ),
188
+ n=5,
189
+ )
190
+ start_data_time = time.perf_counter()
191
+
192
+ # Measure the time only for this worker (before the synchronization barrier)
193
+ total_time = time.perf_counter() - start_time
194
+ total_time_str = str(datetime.timedelta(seconds=total_time))
195
+ # NOTE this format is parsed by grep
196
+ logger.info(
197
+ "Total inference time: {} ({:.6f} s / iter per device, on {} devices)".format(
198
+ total_time_str, total_time / (total - num_warmup), num_devices
199
+ )
200
+ )
201
+ total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
202
+ logger.info(
203
+ "Total inference pure compute time: {} ({:.6f} s / iter per device, on {} devices)".format(
204
+ total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
205
+ )
206
+ )
207
+
208
+ results = evaluator.evaluate()
209
+ # An evaluator may return None when not in main process.
210
+ # Replace it by an empty dict instead to make it easier for downstream code to handle
211
+ if results is None:
212
+ results = {}
213
+ return results
214
+
215
+
216
+ @contextmanager
217
+ def inference_context(model):
218
+ """
219
+ A context where the model is temporarily changed to eval mode,
220
+ and restored to previous mode afterwards.
221
+
222
+ Args:
223
+ model: a torch Module
224
+ """
225
+ training_mode = model.training
226
+ model.eval()
227
+ yield
228
+ model.train(training_mode)
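# A minimal sketch of the evaluator protocol defined above: a trivial evaluator that
# only counts processed outputs. `model` and `data_loader` are assumed to exist
# (any callable model plus a finite-length data loader will do).
class CountingEvaluator(DatasetEvaluator):
    def reset(self):
        self._count = 0

    def process(self, inputs, outputs):
        self._count += len(outputs)

    def evaluate(self):
        return {"count": {"num_outputs": self._count}}

# results = inference_on_dataset(model, data_loader, CountingEvaluator())
# results -> {"count": {"num_outputs": <number of samples in data_loader>}}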
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/evaluation/instance_evaluation.py ADDED
@@ -0,0 +1,110 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/evaluation/instance_evaluation.py
3
+ # ------------------------------------------------------------------------------
4
+
5
+ import contextlib
6
+ import copy
7
+ import io
8
+ import itertools
9
+ import json
10
+ import logging
11
+ import numpy as np
12
+ import os
13
+ import pickle
14
+ from collections import OrderedDict
15
+ import annotator.oneformer.pycocotools.mask as mask_util
16
+ import torch
17
+ from annotator.oneformer.pycocotools.coco import COCO
18
+ from annotator.oneformer.pycocotools.cocoeval import COCOeval
19
+ from tabulate import tabulate
20
+
21
+ import annotator.oneformer.detectron2.utils.comm as comm
22
+ from annotator.oneformer.detectron2.config import CfgNode
23
+ from annotator.oneformer.detectron2.data import MetadataCatalog
24
+ from annotator.oneformer.detectron2.data.datasets.coco import convert_to_coco_json
25
+ from annotator.oneformer.detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco
26
+ from annotator.oneformer.detectron2.evaluation.fast_eval_api import COCOeval_opt
27
+ from annotator.oneformer.detectron2.structures import Boxes, BoxMode, pairwise_iou
28
+ from annotator.oneformer.detectron2.utils.file_io import PathManager
29
+ from annotator.oneformer.detectron2.utils.logger import create_small_table
30
+
31
+
32
+ # modified from COCOEvaluator for instance segmentation
33
+ class InstanceSegEvaluator(COCOEvaluator):
34
+ """
35
+ Evaluate AR for object proposals, AP for instance detection/segmentation, AP
36
+ for keypoint detection outputs using COCO's metrics.
37
+ See http://cocodataset.org/#detection-eval and
38
+ http://cocodataset.org/#keypoints-eval to understand its metrics.
39
+ The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
40
+ the metric cannot be computed (e.g. due to no predictions made).
41
+
42
+ In addition to COCO, this evaluator is able to support any bounding box detection,
43
+ instance segmentation, or keypoint detection dataset.
44
+ """
45
+
46
+ def _eval_predictions(self, predictions, img_ids=None):
47
+ """
48
+ Evaluate predictions. Fill self._results with the metrics of the tasks.
49
+ """
50
+ self._logger.info("Preparing results for COCO format ...")
51
+ coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
52
+ tasks = self._tasks or self._tasks_from_predictions(coco_results)
53
+
54
+ # unmap the category ids for COCO
55
+ if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
56
+ dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
57
+ # all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
58
+ # num_classes = len(all_contiguous_ids)
59
+ # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
60
+
61
+ reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
62
+ for result in coco_results:
63
+ category_id = result["category_id"]
64
+ # assert category_id < num_classes, (
65
+ # f"A prediction has class={category_id}, "
66
+ # f"but the dataset only has {num_classes} classes and "
67
+ # f"predicted class id should be in [0, {num_classes - 1}]."
68
+ # )
69
+ assert category_id in reverse_id_mapping, (
70
+ f"A prediction has class={category_id}, "
71
+ f"but the dataset only has class ids in {dataset_id_to_contiguous_id}."
72
+ )
73
+ result["category_id"] = reverse_id_mapping[category_id]
74
+
75
+ if self._output_dir:
76
+ file_path = os.path.join(self._output_dir, "coco_instances_results.json")
77
+ self._logger.info("Saving results to {}".format(file_path))
78
+ with PathManager.open(file_path, "w") as f:
79
+ f.write(json.dumps(coco_results))
80
+ f.flush()
81
+
82
+ if not self._do_evaluation:
83
+ self._logger.info("Annotations are not available for evaluation.")
84
+ return
85
+
86
+ self._logger.info(
87
+ "Evaluating predictions with {} COCO API...".format(
88
+ "unofficial" if self._use_fast_impl else "official"
89
+ )
90
+ )
91
+ for task in sorted(tasks):
92
+ assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
93
+ coco_eval = (
94
+ _evaluate_predictions_on_coco(
95
+ self._coco_api,
96
+ coco_results,
97
+ task,
98
+ kpt_oks_sigmas=self._kpt_oks_sigmas,
99
+ use_fast_impl=self._use_fast_impl,
100
+ img_ids=img_ids,
101
+ max_dets_per_image=self._max_dets_per_image,
102
+ )
103
+ if len(coco_results) > 0
104
+ else None # cocoapi does not handle empty results very well
105
+ )
106
+
107
+ res = self._derive_coco_results(
108
+ coco_eval, task, class_names=self._metadata.get("thing_classes")
109
+ )
110
+ self._results[task] = res
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .backbone.swin import D2SwinTransformer
2
+ from .backbone.dinat import D2DiNAT
3
+ from .pixel_decoder.fpn import BasePixelDecoder
4
+ from .pixel_decoder.msdeformattn import MSDeformAttnPixelDecoder
5
+ from .meta_arch.oneformer_head import OneFormerHead
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/backbone/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/backbone/dinat.py ADDED
@@ -0,0 +1,324 @@
1
+ # --------------------------------------------------------
2
+ # Neighborhood Attention Transformer
3
+ # Licensed under The MIT License
4
+ # Written by Ali Hassani
5
+ # --------------------------------------------------------
6
+
7
+ # Modified by Jitesh Jain
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from timm.models.layers import DropPath
12
+ from annotator.oneformer.detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
13
+
14
+ class NeighborhoodAttention(nn.Module):
15
+ """
16
+ Neighborhood Attention 2D Module
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ dim,
22
+ num_heads,
23
+ kernel_size,
24
+ dilation=1,
25
+ bias=True,
26
+ qkv_bias=True,
27
+ qk_scale=None,
28
+ attn_drop=0.0,
29
+ proj_drop=0.0,
30
+ ):
31
+ super().__init__()
32
+
33
+
34
+ def forward(self, x):
35
+
36
+ return x
37
+
38
+ def extra_repr(self) -> str:
39
+ return (
40
+ f"head_dim={self.head_dim}, num_heads={self.num_heads}, "
41
+ + f"kernel_size={self.kernel_size}, dilation={self.dilation}, "
42
+ + f"rel_pos_bias={self.rpb is not None}"
43
+ )
44
+
45
+ class ConvTokenizer(nn.Module):
46
+ def __init__(self, in_chans=3, embed_dim=96, norm_layer=None):
47
+ super().__init__()
48
+ self.proj = nn.Sequential(
49
+ nn.Conv2d(in_chans, embed_dim // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
50
+ nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
51
+ )
52
+ if norm_layer is not None:
53
+ self.norm = norm_layer(embed_dim)
54
+ else:
55
+ self.norm = None
56
+
57
+ def forward(self, x):
58
+ x = self.proj(x).permute(0, 2, 3, 1)
59
+ if self.norm is not None:
60
+ x = self.norm(x)
61
+ return x
62
+
63
+
64
+ class ConvDownsampler(nn.Module):
65
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
66
+ super().__init__()
67
+ self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
68
+ self.norm = norm_layer(2 * dim)
69
+
70
+ def forward(self, x):
71
+ x = self.reduction(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
72
+ x = self.norm(x)
73
+ return x
74
+
75
+
76
+ class Mlp(nn.Module):
77
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
78
+ super().__init__()
79
+ out_features = out_features or in_features
80
+ hidden_features = hidden_features or in_features
81
+ self.fc1 = nn.Linear(in_features, hidden_features)
82
+ self.act = act_layer()
83
+ self.fc2 = nn.Linear(hidden_features, out_features)
84
+ self.drop = nn.Dropout(drop)
85
+
86
+ def forward(self, x):
87
+ x = self.fc1(x)
88
+ x = self.act(x)
89
+ x = self.drop(x)
90
+ x = self.fc2(x)
91
+ x = self.drop(x)
92
+ return x
93
+
94
+
95
+ class NATLayer(nn.Module):
96
+ def __init__(self, dim, num_heads, kernel_size=7, dilation=None,
97
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
98
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale=None):
99
+ super().__init__()
100
+ self.dim = dim
101
+ self.num_heads = num_heads
102
+ self.mlp_ratio = mlp_ratio
103
+
104
+ self.norm1 = norm_layer(dim)
105
+ self.attn = NeighborhoodAttention(
106
+ dim, kernel_size=kernel_size, dilation=dilation, num_heads=num_heads,
107
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
108
+
109
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
110
+ self.norm2 = norm_layer(dim)
111
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
112
+ self.layer_scale = False
113
+ if layer_scale is not None and type(layer_scale) in [int, float]:
114
+ self.layer_scale = True
115
+ self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
116
+ self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
117
+
118
+ def forward(self, x):
119
+ if not self.layer_scale:
120
+ shortcut = x
121
+ x = self.norm1(x)
122
+ x = self.attn(x)
123
+ x = shortcut + self.drop_path(x)
124
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
125
+ return x
126
+ shortcut = x
127
+ x = self.norm1(x)
128
+ x = self.attn(x)
129
+ x = shortcut + self.drop_path(self.gamma1 * x)
130
+ x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
131
+ return x
132
+
133
+
134
+
135
+ class NATBlock(nn.Module):
136
+ def __init__(self, dim, depth, num_heads, kernel_size, dilations=None,
137
+ downsample=True,
138
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
139
+ drop_path=0., norm_layer=nn.LayerNorm, layer_scale=None):
140
+ super().__init__()
141
+ self.dim = dim
142
+ self.depth = depth
143
+
144
+ self.blocks = nn.ModuleList([
145
+ NATLayer(dim=dim,
146
+ num_heads=num_heads,
147
+ kernel_size=kernel_size,
148
+ dilation=None if dilations is None else dilations[i],
149
+ mlp_ratio=mlp_ratio,
150
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
151
+ drop=drop, attn_drop=attn_drop,
152
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
153
+ norm_layer=norm_layer,
154
+ layer_scale=layer_scale)
155
+ for i in range(depth)])
156
+
157
+ self.downsample = None if not downsample else ConvDownsampler(dim=dim, norm_layer=norm_layer)
158
+
159
+ def forward(self, x):
160
+ for blk in self.blocks:
161
+ x = blk(x)
162
+ if self.downsample is None:
163
+ return x, x
164
+ return self.downsample(x), x
165
+
166
+
167
+ class DiNAT(nn.Module):
168
+ def __init__(self,
169
+ embed_dim,
170
+ mlp_ratio,
171
+ depths,
172
+ num_heads,
173
+ drop_path_rate=0.2,
174
+ in_chans=3,
175
+ kernel_size=7,
176
+ dilations=None,
177
+ out_indices=(0, 1, 2, 3),
178
+ qkv_bias=True,
179
+ qk_scale=None,
180
+ drop_rate=0.,
181
+ attn_drop_rate=0.,
182
+ norm_layer=nn.LayerNorm,
183
+ frozen_stages=-1,
184
+ layer_scale=None,
185
+ **kwargs):
186
+ super().__init__()
187
+ self.num_levels = len(depths)
188
+ self.embed_dim = embed_dim
189
+ self.num_features = [int(embed_dim * 2 ** i) for i in range(self.num_levels)]
190
+ self.mlp_ratio = mlp_ratio
191
+
192
+ self.patch_embed = ConvTokenizer(in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer)
193
+
194
+ self.pos_drop = nn.Dropout(p=drop_rate)
195
+
196
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
197
+ self.levels = nn.ModuleList()
198
+ for i in range(self.num_levels):
199
+ level = NATBlock(dim=int(embed_dim * 2 ** i),
200
+ depth=depths[i],
201
+ num_heads=num_heads[i],
202
+ kernel_size=kernel_size,
203
+ dilations=None if dilations is None else dilations[i],
204
+ mlp_ratio=self.mlp_ratio,
205
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
206
+ drop=drop_rate, attn_drop=attn_drop_rate,
207
+ drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
208
+ norm_layer=norm_layer,
209
+ downsample=(i < self.num_levels - 1),
210
+ layer_scale=layer_scale)
211
+ self.levels.append(level)
212
+
213
+ # add a norm layer for each output
214
+ self.out_indices = out_indices
215
+ for i_layer in self.out_indices:
216
+ layer = norm_layer(self.num_features[i_layer])
217
+ layer_name = f'norm{i_layer}'
218
+ self.add_module(layer_name, layer)
219
+
220
+ self.frozen_stages = frozen_stages
221
+
222
+ def _freeze_stages(self):
223
+ if self.frozen_stages >= 0:
224
+ self.patch_embed.eval()
225
+ for param in self.patch_embed.parameters():
226
+ param.requires_grad = False
227
+
228
+ if self.frozen_stages >= 2:
229
+ for i in range(0, self.frozen_stages - 1):
230
+ m = self.network[i]
231
+ m.eval()
232
+ for param in m.parameters():
233
+ param.requires_grad = False
234
+
235
+ def train(self, mode=True):
236
+ super(DiNAT, self).train(mode)
237
+ self._freeze_stages()
238
+
239
+ def forward_embeddings(self, x):
240
+ x = self.patch_embed(x)
241
+ return x
242
+
243
+ def forward_tokens(self, x):
244
+ outs = {}
245
+ for idx, level in enumerate(self.levels):
246
+ x, xo = level(x)
247
+ if idx in self.out_indices:
248
+ norm_layer = getattr(self, f'norm{idx}')
249
+ x_out = norm_layer(xo)
250
+ outs["res{}".format(idx + 2)] = x_out.permute(0, 3, 1, 2).contiguous()
251
+ return outs
252
+
253
+ def forward(self, x):
254
+ x = self.forward_embeddings(x)
255
+ return self.forward_tokens(x)
256
+
257
+
258
+ @BACKBONE_REGISTRY.register()
259
+ class D2DiNAT(DiNAT, Backbone):
260
+ def __init__(self, cfg, input_shape):
261
+
262
+ embed_dim = cfg.MODEL.DiNAT.EMBED_DIM
263
+ mlp_ratio = cfg.MODEL.DiNAT.MLP_RATIO
264
+ depths = cfg.MODEL.DiNAT.DEPTHS
265
+ num_heads = cfg.MODEL.DiNAT.NUM_HEADS
266
+ drop_path_rate = cfg.MODEL.DiNAT.DROP_PATH_RATE
267
+ kernel_size = cfg.MODEL.DiNAT.KERNEL_SIZE
268
+ out_indices = cfg.MODEL.DiNAT.OUT_INDICES
269
+ dilations = cfg.MODEL.DiNAT.DILATIONS
270
+
271
+ super().__init__(
272
+ embed_dim=embed_dim,
273
+ mlp_ratio=mlp_ratio,
274
+ depths=depths,
275
+ num_heads=num_heads,
276
+ drop_path_rate=drop_path_rate,
277
+ kernel_size=kernel_size,
278
+ out_indices=out_indices,
279
+ dilations=dilations,
280
+ )
281
+
282
+ self._out_features = cfg.MODEL.DiNAT.OUT_FEATURES
283
+
284
+ self._out_feature_strides = {
285
+ "res2": 4,
286
+ "res3": 8,
287
+ "res4": 16,
288
+ "res5": 32,
289
+ }
290
+ self._out_feature_channels = {
291
+ "res2": self.num_features[0],
292
+ "res3": self.num_features[1],
293
+ "res4": self.num_features[2],
294
+ "res5": self.num_features[3],
295
+ }
296
+
297
+ def forward(self, x):
298
+ """
299
+ Args:
300
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
301
+ Returns:
302
+ dict[str->Tensor]: names and the corresponding features
303
+ """
304
+ assert (
305
+ x.dim() == 4
306
+ ), f"DiNAT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
307
+ outputs = {}
308
+ y = super().forward(x)
309
+ for k in y.keys():
310
+ if k in self._out_features:
311
+ outputs[k] = y[k]
312
+ return outputs
313
+
314
+ def output_shape(self):
315
+ return {
316
+ name: ShapeSpec(
317
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
318
+ )
319
+ for name in self._out_features
320
+ }
321
+
322
+ @property
323
+ def size_divisibility(self):
324
+ return 32
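# A minimal sketch of the DiNAT feature-pyramid contract above. NeighborhoodAttention
# in this vendored copy is a pass-through stub, so this runs without the natten package;
# channels follow embed_dim * 2**i at strides 4, 8, 16, 32.
import torch

tiny = DiNAT(embed_dim=32, mlp_ratio=4.0, depths=[1, 1, 1, 1], num_heads=[1, 1, 1, 1])
tiny.eval()
feats = tiny(torch.randn(1, 3, 64, 64))
for name, f in feats.items():
    print(name, tuple(f.shape))   # res2 (1, 32, 16, 16) ... res5 (1, 256, 2, 2)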
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/backbone/swin.py ADDED
@@ -0,0 +1,771 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ # ------------------------------------------------------------------------------
9
+ # Reference: https://github.com/facebookresearch/Mask2Former
10
+ # ------------------------------------------------------------------------------
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import torch.utils.checkpoint as checkpoint
17
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18
+
19
+ from annotator.oneformer.detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
20
+
21
+
22
+ class Mlp(nn.Module):
23
+ """Multilayer perceptron."""
24
+
25
+ def __init__(
26
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
27
+ ):
28
+ super().__init__()
29
+ out_features = out_features or in_features
30
+ hidden_features = hidden_features or in_features
31
+ self.fc1 = nn.Linear(in_features, hidden_features)
32
+ self.act = act_layer()
33
+ self.fc2 = nn.Linear(hidden_features, out_features)
34
+ self.drop = nn.Dropout(drop)
35
+
36
+ def forward(self, x):
37
+ x = self.fc1(x)
38
+ x = self.act(x)
39
+ x = self.drop(x)
40
+ x = self.fc2(x)
41
+ x = self.drop(x)
42
+ return x
43
+
44
+
45
+ def window_partition(x, window_size):
46
+ """
47
+ Args:
48
+ x: (B, H, W, C)
49
+ window_size (int): window size
50
+ Returns:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ """
53
+ B, H, W, C = x.shape
54
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
55
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
56
+ return windows
57
+
58
+
59
+ def window_reverse(windows, window_size, H, W):
60
+ """
61
+ Args:
62
+ windows: (num_windows*B, window_size, window_size, C)
63
+ window_size (int): Window size
64
+ H (int): Height of image
65
+ W (int): Width of image
66
+ Returns:
67
+ x: (B, H, W, C)
68
+ """
69
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
70
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
71
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
72
+ return x
73
+
74
+
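# A quick sanity sketch of the two helpers above: window_reverse undoes
# window_partition whenever H and W are multiples of window_size.
_x = torch.randn(2, 14, 14, 32)            # (B, H, W, C)
_windows = window_partition(_x, 7)         # (8, 7, 7, 32) == (num_windows*B, ws, ws, C)
assert torch.equal(window_reverse(_windows, 7, 14, 14), _x)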
75
+ class WindowAttention(nn.Module):
76
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
77
+ It supports both of shifted and non-shifted window.
78
+ Args:
79
+ dim (int): Number of input channels.
80
+ window_size (tuple[int]): The height and width of the window.
81
+ num_heads (int): Number of attention heads.
82
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
83
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
84
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
85
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ dim,
91
+ window_size,
92
+ num_heads,
93
+ qkv_bias=True,
94
+ qk_scale=None,
95
+ attn_drop=0.0,
96
+ proj_drop=0.0,
97
+ ):
98
+
99
+ super().__init__()
100
+ self.dim = dim
101
+ self.window_size = window_size # Wh, Ww
102
+ self.num_heads = num_heads
103
+ head_dim = dim // num_heads
104
+ self.scale = qk_scale or head_dim ** -0.5
105
+
106
+ # define a parameter table of relative position bias
107
+ self.relative_position_bias_table = nn.Parameter(
108
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
109
+ ) # 2*Wh-1 * 2*Ww-1, nH
110
+
111
+ # get pair-wise relative position index for each token inside the window
112
+ coords_h = torch.arange(self.window_size[0])
113
+ coords_w = torch.arange(self.window_size[1])
114
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
115
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
116
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
117
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
118
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
119
+ relative_coords[:, :, 1] += self.window_size[1] - 1
120
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
121
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
122
+ self.register_buffer("relative_position_index", relative_position_index)
123
+
124
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
125
+ self.attn_drop = nn.Dropout(attn_drop)
126
+ self.proj = nn.Linear(dim, dim)
127
+ self.proj_drop = nn.Dropout(proj_drop)
128
+
129
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
130
+ self.softmax = nn.Softmax(dim=-1)
131
+
132
+ def forward(self, x, mask=None):
133
+ """Forward function.
134
+ Args:
135
+ x: input features with shape of (num_windows*B, N, C)
136
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
137
+ """
138
+ B_, N, C = x.shape
139
+ qkv = (
140
+ self.qkv(x)
141
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
142
+ .permute(2, 0, 3, 1, 4)
143
+ )
144
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
145
+
146
+ q = q * self.scale
147
+ attn = q @ k.transpose(-2, -1)
148
+
149
+ relative_position_bias = self.relative_position_bias_table[
150
+ self.relative_position_index.view(-1)
151
+ ].view(
152
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
153
+ ) # Wh*Ww,Wh*Ww,nH
154
+ relative_position_bias = relative_position_bias.permute(
155
+ 2, 0, 1
156
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
157
+ attn = attn + relative_position_bias.unsqueeze(0)
158
+
159
+ if mask is not None:
160
+ nW = mask.shape[0]
161
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
162
+ attn = attn.view(-1, self.num_heads, N, N)
163
+ attn = self.softmax(attn)
164
+ else:
165
+ attn = self.softmax(attn)
166
+
167
+ attn = self.attn_drop(attn)
168
+
169
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
170
+ x = self.proj(x)
171
+ x = self.proj_drop(x)
172
+ return x
173
+
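# A small sketch of the relative-position-index trick above for a 2x2 window: each of
# the 4*4 token pairs maps into one of (2*2-1)**2 = 9 bias slots, and a token paired
# with itself always lands in the centre slot (index 4).
attn_2x2 = WindowAttention(dim=8, window_size=(2, 2), num_heads=2)
print(attn_2x2.relative_position_index)          # 4x4 tensor with values in [0, 8]
print(attn_2x2.relative_position_index.diag())   # tensor([4, 4, 4, 4])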
174
+
175
+ class SwinTransformerBlock(nn.Module):
176
+ """Swin Transformer Block.
177
+ Args:
178
+ dim (int): Number of input channels.
179
+ num_heads (int): Number of attention heads.
180
+ window_size (int): Window size.
181
+ shift_size (int): Shift size for SW-MSA.
182
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
183
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
184
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
185
+ drop (float, optional): Dropout rate. Default: 0.0
186
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
187
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
188
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
189
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ dim,
195
+ num_heads,
196
+ window_size=7,
197
+ shift_size=0,
198
+ mlp_ratio=4.0,
199
+ qkv_bias=True,
200
+ qk_scale=None,
201
+ drop=0.0,
202
+ attn_drop=0.0,
203
+ drop_path=0.0,
204
+ act_layer=nn.GELU,
205
+ norm_layer=nn.LayerNorm,
206
+ ):
207
+ super().__init__()
208
+ self.dim = dim
209
+ self.num_heads = num_heads
210
+ self.window_size = window_size
211
+ self.shift_size = shift_size
212
+ self.mlp_ratio = mlp_ratio
213
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
214
+
215
+ self.norm1 = norm_layer(dim)
216
+ self.attn = WindowAttention(
217
+ dim,
218
+ window_size=to_2tuple(self.window_size),
219
+ num_heads=num_heads,
220
+ qkv_bias=qkv_bias,
221
+ qk_scale=qk_scale,
222
+ attn_drop=attn_drop,
223
+ proj_drop=drop,
224
+ )
225
+
226
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
227
+ self.norm2 = norm_layer(dim)
228
+ mlp_hidden_dim = int(dim * mlp_ratio)
229
+ self.mlp = Mlp(
230
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
231
+ )
232
+
233
+ self.H = None
234
+ self.W = None
235
+
236
+ def forward(self, x, mask_matrix):
237
+ """Forward function.
238
+ Args:
239
+ x: Input feature, tensor size (B, H*W, C).
240
+ H, W: Spatial resolution of the input feature.
241
+ mask_matrix: Attention mask for cyclic shift.
242
+ """
243
+ B, L, C = x.shape
244
+ H, W = self.H, self.W
245
+ assert L == H * W, "input feature has wrong size"
246
+
247
+ shortcut = x
248
+ x = self.norm1(x)
249
+ x = x.view(B, H, W, C)
250
+
251
+ # pad feature maps to multiples of window size
252
+ pad_l = pad_t = 0
253
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
254
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
255
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
256
+ _, Hp, Wp, _ = x.shape
257
+
258
+ # cyclic shift
259
+ if self.shift_size > 0:
260
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
261
+ attn_mask = mask_matrix
262
+ else:
263
+ shifted_x = x
264
+ attn_mask = None
265
+
266
+ # partition windows
267
+ x_windows = window_partition(
268
+ shifted_x, self.window_size
269
+ ) # nW*B, window_size, window_size, C
270
+ x_windows = x_windows.view(
271
+ -1, self.window_size * self.window_size, C
272
+ ) # nW*B, window_size*window_size, C
273
+
274
+ # W-MSA/SW-MSA
275
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
276
+
277
+ # merge windows
278
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
279
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
280
+
281
+ # reverse cyclic shift
282
+ if self.shift_size > 0:
283
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
284
+ else:
285
+ x = shifted_x
286
+
287
+ if pad_r > 0 or pad_b > 0:
288
+ x = x[:, :H, :W, :].contiguous()
289
+
290
+ x = x.view(B, H * W, C)
291
+
292
+ # FFN
293
+ x = shortcut + self.drop_path(x)
294
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
295
+
296
+ return x
297
+
298
+
299
+ class PatchMerging(nn.Module):
300
+ """Patch Merging Layer
301
+ Args:
302
+ dim (int): Number of input channels.
303
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
304
+ """
305
+
306
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
307
+ super().__init__()
308
+ self.dim = dim
309
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
310
+ self.norm = norm_layer(4 * dim)
311
+
312
+ def forward(self, x, H, W):
313
+ """Forward function.
314
+ Args:
315
+ x: Input feature, tensor size (B, H*W, C).
316
+ H, W: Spatial resolution of the input feature.
317
+ """
318
+ B, L, C = x.shape
319
+ assert L == H * W, "input feature has wrong size"
320
+
321
+ x = x.view(B, H, W, C)
322
+
323
+ # padding
324
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
325
+ if pad_input:
326
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
327
+
328
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
329
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
330
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
331
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
332
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
333
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
334
+
335
+ x = self.norm(x)
336
+ x = self.reduction(x)
337
+
338
+ return x
339
+
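# A shape sketch of PatchMerging above: 2x2 neighborhoods are concatenated (C -> 4C)
# and linearly reduced to 2C, halving the spatial resolution.
pm = PatchMerging(dim=32)
out = pm(torch.randn(2, 14 * 14, 32), H=14, W=14)
print(out.shape)   # torch.Size([2, 49, 64])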
340
+
341
+ class BasicLayer(nn.Module):
342
+ """A basic Swin Transformer layer for one stage.
343
+ Args:
344
+ dim (int): Number of feature channels
345
+ depth (int): Depths of this stage.
346
+ num_heads (int): Number of attention head.
347
+ window_size (int): Local window size. Default: 7.
348
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
349
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
350
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
351
+ drop (float, optional): Dropout rate. Default: 0.0
352
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
353
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
354
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
355
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
356
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
357
+ """
358
+
359
+ def __init__(
360
+ self,
361
+ dim,
362
+ depth,
363
+ num_heads,
364
+ window_size=7,
365
+ mlp_ratio=4.0,
366
+ qkv_bias=True,
367
+ qk_scale=None,
368
+ drop=0.0,
369
+ attn_drop=0.0,
370
+ drop_path=0.0,
371
+ norm_layer=nn.LayerNorm,
372
+ downsample=None,
373
+ use_checkpoint=False,
374
+ ):
375
+ super().__init__()
376
+ self.window_size = window_size
377
+ self.shift_size = window_size // 2
378
+ self.depth = depth
379
+ self.use_checkpoint = use_checkpoint
380
+
381
+ # build blocks
382
+ self.blocks = nn.ModuleList(
383
+ [
384
+ SwinTransformerBlock(
385
+ dim=dim,
386
+ num_heads=num_heads,
387
+ window_size=window_size,
388
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
389
+ mlp_ratio=mlp_ratio,
390
+ qkv_bias=qkv_bias,
391
+ qk_scale=qk_scale,
392
+ drop=drop,
393
+ attn_drop=attn_drop,
394
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
395
+ norm_layer=norm_layer,
396
+ )
397
+ for i in range(depth)
398
+ ]
399
+ )
400
+
401
+ # patch merging layer
402
+ if downsample is not None:
403
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
404
+ else:
405
+ self.downsample = None
406
+
407
+ def forward(self, x, H, W):
408
+ """Forward function.
409
+ Args:
410
+ x: Input feature, tensor size (B, H*W, C).
411
+ H, W: Spatial resolution of the input feature.
412
+ """
413
+
414
+ # calculate attention mask for SW-MSA
415
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
416
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
417
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
418
+ h_slices = (
419
+ slice(0, -self.window_size),
420
+ slice(-self.window_size, -self.shift_size),
421
+ slice(-self.shift_size, None),
422
+ )
423
+ w_slices = (
424
+ slice(0, -self.window_size),
425
+ slice(-self.window_size, -self.shift_size),
426
+ slice(-self.shift_size, None),
427
+ )
428
+ cnt = 0
429
+ for h in h_slices:
430
+ for w in w_slices:
431
+ img_mask[:, h, w, :] = cnt
432
+ cnt += 1
433
+
434
+ mask_windows = window_partition(
435
+ img_mask, self.window_size
436
+ ) # nW, window_size, window_size, 1
437
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
438
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
439
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
440
+ attn_mask == 0, float(0.0)
441
+ )
442
+
443
+ for blk in self.blocks:
444
+ blk.H, blk.W = H, W
445
+ if self.use_checkpoint:
446
+ x = checkpoint.checkpoint(blk, x, attn_mask)
447
+ else:
448
+ x = blk(x, attn_mask)
449
+ if self.downsample is not None:
450
+ x_down = self.downsample(x, H, W)
451
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
452
+ return x, H, W, x_down, Wh, Ww
453
+ else:
454
+ return x, H, W, x, H, W
455
+
456
+
457
+ class PatchEmbed(nn.Module):
458
+ """Image to Patch Embedding
459
+ Args:
460
+ patch_size (int): Patch token size. Default: 4.
461
+ in_chans (int): Number of input image channels. Default: 3.
462
+ embed_dim (int): Number of linear projection output channels. Default: 96.
463
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
464
+ """
465
+
466
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
467
+ super().__init__()
468
+ patch_size = to_2tuple(patch_size)
469
+ self.patch_size = patch_size
470
+
471
+ self.in_chans = in_chans
472
+ self.embed_dim = embed_dim
473
+
474
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
475
+ if norm_layer is not None:
476
+ self.norm = norm_layer(embed_dim)
477
+ else:
478
+ self.norm = None
479
+
480
+ def forward(self, x):
481
+ """Forward function."""
482
+ # padding
483
+ _, _, H, W = x.size()
484
+ if W % self.patch_size[1] != 0:
485
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
486
+ if H % self.patch_size[0] != 0:
487
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
488
+
489
+ x = self.proj(x) # B C Wh Ww
490
+ if self.norm is not None:
491
+ Wh, Ww = x.size(2), x.size(3)
492
+ x = x.flatten(2).transpose(1, 2)
493
+ x = self.norm(x)
494
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
495
+
496
+ return x
497
+
498
+
499
+ class SwinTransformer(nn.Module):
500
+ """Swin Transformer backbone.
501
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
502
+ https://arxiv.org/pdf/2103.14030
503
+ Args:
504
+ pretrain_img_size (int): Input image size for training the pretrained model,
505
+ used in absolute position embedding. Default 224.
506
+ patch_size (int | tuple(int)): Patch size. Default: 4.
507
+ in_chans (int): Number of input image channels. Default: 3.
508
+ embed_dim (int): Number of linear projection output channels. Default: 96.
509
+ depths (tuple[int]): Depths of each Swin Transformer stage.
510
+ num_heads (tuple[int]): Number of attention heads in each stage.
511
+ window_size (int): Window size. Default: 7.
512
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
513
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
514
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
515
+ drop_rate (float): Dropout rate.
516
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
517
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
518
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
519
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
520
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
521
+ out_indices (Sequence[int]): Output from which stages.
522
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
523
+ -1 means not freezing any parameters.
524
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
525
+ """
526
+
527
+ def __init__(
528
+ self,
529
+ pretrain_img_size=224,
530
+ patch_size=4,
531
+ in_chans=3,
532
+ embed_dim=96,
533
+ depths=[2, 2, 6, 2],
534
+ num_heads=[3, 6, 12, 24],
535
+ window_size=7,
536
+ mlp_ratio=4.0,
537
+ qkv_bias=True,
538
+ qk_scale=None,
539
+ drop_rate=0.0,
540
+ attn_drop_rate=0.0,
541
+ drop_path_rate=0.2,
542
+ norm_layer=nn.LayerNorm,
543
+ ape=False,
544
+ patch_norm=True,
545
+ out_indices=(0, 1, 2, 3),
546
+ frozen_stages=-1,
547
+ use_checkpoint=False,
548
+ ):
549
+ super().__init__()
550
+
551
+ self.pretrain_img_size = pretrain_img_size
552
+ self.num_layers = len(depths)
553
+ self.embed_dim = embed_dim
554
+ self.ape = ape
555
+ self.patch_norm = patch_norm
556
+ self.out_indices = out_indices
557
+ self.frozen_stages = frozen_stages
558
+
559
+ # split image into non-overlapping patches
560
+ self.patch_embed = PatchEmbed(
561
+ patch_size=patch_size,
562
+ in_chans=in_chans,
563
+ embed_dim=embed_dim,
564
+ norm_layer=norm_layer if self.patch_norm else None,
565
+ )
566
+
567
+ # absolute position embedding
568
+ if self.ape:
569
+ pretrain_img_size = to_2tuple(pretrain_img_size)
570
+ patch_size = to_2tuple(patch_size)
571
+ patches_resolution = [
572
+ pretrain_img_size[0] // patch_size[0],
573
+ pretrain_img_size[1] // patch_size[1],
574
+ ]
575
+
576
+ self.absolute_pos_embed = nn.Parameter(
577
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
578
+ )
579
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
580
+
581
+ self.pos_drop = nn.Dropout(p=drop_rate)
582
+
583
+ # stochastic depth
584
+ dpr = [
585
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
586
+ ] # stochastic depth decay rule
587
+
588
+ # build layers
589
+ self.layers = nn.ModuleList()
590
+ for i_layer in range(self.num_layers):
591
+ layer = BasicLayer(
592
+ dim=int(embed_dim * 2 ** i_layer),
593
+ depth=depths[i_layer],
594
+ num_heads=num_heads[i_layer],
595
+ window_size=window_size,
596
+ mlp_ratio=mlp_ratio,
597
+ qkv_bias=qkv_bias,
598
+ qk_scale=qk_scale,
599
+ drop=drop_rate,
600
+ attn_drop=attn_drop_rate,
601
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
602
+ norm_layer=norm_layer,
603
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
604
+ use_checkpoint=use_checkpoint,
605
+ )
606
+ self.layers.append(layer)
607
+
608
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
609
+ self.num_features = num_features
610
+
611
+ # add a norm layer for each output
612
+ for i_layer in out_indices:
613
+ layer = norm_layer(num_features[i_layer])
614
+ layer_name = f"norm{i_layer}"
615
+ self.add_module(layer_name, layer)
616
+
617
+ self._freeze_stages()
618
+
619
+ def _freeze_stages(self):
620
+ if self.frozen_stages >= 0:
621
+ self.patch_embed.eval()
622
+ for param in self.patch_embed.parameters():
623
+ param.requires_grad = False
624
+
625
+ if self.frozen_stages >= 1 and self.ape:
626
+ self.absolute_pos_embed.requires_grad = False
627
+
628
+ if self.frozen_stages >= 2:
629
+ self.pos_drop.eval()
630
+ for i in range(0, self.frozen_stages - 1):
631
+ m = self.layers[i]
632
+ m.eval()
633
+ for param in m.parameters():
634
+ param.requires_grad = False
635
+
636
+ def init_weights(self, pretrained=None):
637
+ """Initialize the weights in backbone.
638
+ Args:
639
+ pretrained (str, optional): Path to pre-trained weights.
640
+ Defaults to None.
641
+ """
642
+
643
+ def _init_weights(m):
644
+ if isinstance(m, nn.Linear):
645
+ trunc_normal_(m.weight, std=0.02)
646
+ if isinstance(m, nn.Linear) and m.bias is not None:
647
+ nn.init.constant_(m.bias, 0)
648
+ elif isinstance(m, nn.LayerNorm):
649
+ nn.init.constant_(m.bias, 0)
650
+ nn.init.constant_(m.weight, 1.0)
651
+
652
+ def forward(self, x):
653
+ """Forward function."""
654
+ x = self.patch_embed(x)
655
+
656
+ Wh, Ww = x.size(2), x.size(3)
657
+ if self.ape:
658
+ # interpolate the position embedding to the corresponding size
659
+ absolute_pos_embed = F.interpolate(
660
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
661
+ )
662
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
663
+ else:
664
+ x = x.flatten(2).transpose(1, 2)
665
+ x = self.pos_drop(x)
666
+
667
+ outs = {}
668
+ for i in range(self.num_layers):
669
+ layer = self.layers[i]
670
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
671
+
672
+ if i in self.out_indices:
673
+ norm_layer = getattr(self, f"norm{i}")
674
+ x_out = norm_layer(x_out)
675
+
676
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
677
+ outs["res{}".format(i + 2)] = out
678
+
679
+ return outs
680
+
681
+ def train(self, mode=True):
682
+ """Convert the model into training mode while keeping the frozen stages frozen."""
683
+ super(SwinTransformer, self).train(mode)
684
+ self._freeze_stages()
685
+
686
+
687
+ @BACKBONE_REGISTRY.register()
688
+ class D2SwinTransformer(SwinTransformer, Backbone):
689
+ def __init__(self, cfg, input_shape):
690
+
691
+ pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
692
+ patch_size = cfg.MODEL.SWIN.PATCH_SIZE
693
+ in_chans = 3
694
+ embed_dim = cfg.MODEL.SWIN.EMBED_DIM
695
+ depths = cfg.MODEL.SWIN.DEPTHS
696
+ num_heads = cfg.MODEL.SWIN.NUM_HEADS
697
+ window_size = cfg.MODEL.SWIN.WINDOW_SIZE
698
+ mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
699
+ qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
700
+ qk_scale = cfg.MODEL.SWIN.QK_SCALE
701
+ drop_rate = cfg.MODEL.SWIN.DROP_RATE
702
+ attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
703
+ drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
704
+ norm_layer = nn.LayerNorm
705
+ ape = cfg.MODEL.SWIN.APE
706
+ patch_norm = cfg.MODEL.SWIN.PATCH_NORM
707
+ use_checkpoint = cfg.MODEL.SWIN.USE_CHECKPOINT
708
+
709
+ super().__init__(
710
+ pretrain_img_size,
711
+ patch_size,
712
+ in_chans,
713
+ embed_dim,
714
+ depths,
715
+ num_heads,
716
+ window_size,
717
+ mlp_ratio,
718
+ qkv_bias,
719
+ qk_scale,
720
+ drop_rate,
721
+ attn_drop_rate,
722
+ drop_path_rate,
723
+ norm_layer,
724
+ ape,
725
+ patch_norm,
726
+ use_checkpoint=use_checkpoint,
727
+ )
728
+
729
+ self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
730
+
731
+ self._out_feature_strides = {
732
+ "res2": 4,
733
+ "res3": 8,
734
+ "res4": 16,
735
+ "res5": 32,
736
+ }
737
+ self._out_feature_channels = {
738
+ "res2": self.num_features[0],
739
+ "res3": self.num_features[1],
740
+ "res4": self.num_features[2],
741
+ "res5": self.num_features[3],
742
+ }
743
+
744
+ def forward(self, x):
745
+ """
746
+ Args:
747
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
748
+ Returns:
749
+ dict[str->Tensor]: names and the corresponding features
750
+ """
751
+ assert (
752
+ x.dim() == 4
753
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
754
+ outputs = {}
755
+ y = super().forward(x)
756
+ for k in y.keys():
757
+ if k in self._out_features:
758
+ outputs[k] = y[k]
759
+ return outputs
760
+
761
+ def output_shape(self):
762
+ return {
763
+ name: ShapeSpec(
764
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
765
+ )
766
+ for name in self._out_features
767
+ }
768
+
769
+ @property
770
+ def size_divisibility(self):
771
+ return 32
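Not part of the commit itself: a minimal smoke-test sketch for the plain SwinTransformer backbone defined above, assuming the imports at the top of swin.py (torch, to_2tuple, trunc_normal_) are available. All sizes are illustrative; it only checks that the backbone emits the expected res2–res5 pyramid.

import torch

# Hypothetical usage sketch (values are illustrative, not from the diff).
model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
for name, f in feats.items():
    # expected: res2 (1, 96, 56, 56), res3 (1, 192, 28, 28),
    #           res4 (1, 384, 14, 14), res5 (1, 768, 7, 7)
    print(name, tuple(f.shape))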
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/matcher.py ADDED
@@ -0,0 +1,212 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ """
7
+ Modules to compute the matching cost and solve the corresponding LSAP.
8
+ """
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from scipy.optimize import linear_sum_assignment
12
+ from torch import nn
13
+ from torch.cuda.amp import autocast
14
+ import numpy as np
15
+
16
+ # from annotator.oneformer.detectron2.projects.point_rend.point_features import point_sample
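+ # NOTE: `point_sample` is still referenced in `memory_efficient_forward` below; the import above is left
+ # disabled here because the matcher is only exercised when computing the training loss, not at inference time.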
17
+
18
+
19
+ def linear_sum_assignment_with_nan(cost_matrix):
20
+ cost_matrix = np.asarray(cost_matrix)
21
+ nan = np.isnan(cost_matrix).any()
22
+ nan_all = np.isnan(cost_matrix).all()
23
+ empty = cost_matrix.size == 0
24
+
25
+ if not empty:
26
+ if nan_all:
27
+ print('Matrix contains all NaN values!')
28
+ elif nan:
29
+ print('Matrix contains NaN values!')
30
+
31
+ if nan_all:
32
+ cost_matrix = np.empty(shape=(0, 0))
33
+ elif nan:
34
+ cost_matrix[np.isnan(cost_matrix)] = 100
35
+
36
+ return linear_sum_assignment(cost_matrix)
37
+
38
+ def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
39
+ """
40
+ Compute the DICE loss, similar to generalized IOU for masks
41
+ Args:
42
+ inputs: A float tensor of arbitrary shape.
43
+ The predictions for each example.
44
+ targets: A float tensor with the same shape as inputs. Stores the binary
45
+ classification label for each element in inputs
46
+ (0 for the negative class and 1 for the positive class).
47
+ """
48
+ inputs = inputs.sigmoid()
49
+ inputs = inputs.flatten(1)
50
+ numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
51
+ denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
52
+ loss = 1 - (numerator + 1) / (denominator + 1)
53
+ return loss
54
+
55
+
56
+ batch_dice_loss_jit = torch.jit.script(
57
+ batch_dice_loss
58
+ ) # type: torch.jit.ScriptModule
59
+
60
+
61
+ def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
62
+ """
63
+ Args:
64
+ inputs: A float tensor of arbitrary shape.
65
+ The predictions for each example.
66
+ targets: A float tensor with the same shape as inputs. Stores the binary
67
+ classification label for each element in inputs
68
+ (0 for the negative class and 1 for the positive class).
69
+ Returns:
70
+ Loss tensor
71
+ """
72
+ hw = inputs.shape[1]
73
+
74
+ pos = F.binary_cross_entropy_with_logits(
75
+ inputs, torch.ones_like(inputs), reduction="none"
76
+ )
77
+ neg = F.binary_cross_entropy_with_logits(
78
+ inputs, torch.zeros_like(inputs), reduction="none"
79
+ )
80
+
81
+ loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
82
+ "nc,mc->nm", neg, (1 - targets)
83
+ )
84
+
85
+ return loss / hw
86
+
87
+
88
+ batch_sigmoid_ce_loss_jit = torch.jit.script(
89
+ batch_sigmoid_ce_loss
90
+ ) # type: torch.jit.ScriptModule
91
+
92
+
93
+ class HungarianMatcher(nn.Module):
94
+ """This class computes an assignment between the targets and the predictions of the network
95
+
96
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
97
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
98
+ while the others are un-matched (and thus treated as non-objects).
99
+ """
100
+
101
+ def __init__(self, cost_class: float = 1, cost_mask: float = 1,
102
+ cost_dice: float = 1, num_points: int = 0):
103
+ """Creates the matcher
104
+
105
+ Params:
106
+ cost_class: This is the relative weight of the classification error in the matching cost
107
+ cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
108
+ cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
109
+ """
110
+ super().__init__()
111
+ self.cost_class = cost_class
112
+ self.cost_mask = cost_mask
113
+ self.cost_dice = cost_dice
114
+
115
+ assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs can't be 0"
116
+
117
+ self.num_points = num_points
118
+
119
+ @torch.no_grad()
120
+ def memory_efficient_forward(self, outputs, targets):
121
+ """More memory-friendly matching"""
122
+ bs, num_queries = outputs["pred_logits"].shape[:2]
123
+
124
+ indices = []
125
+
126
+ # Iterate through batch size
127
+ for b in range(bs):
128
+ out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes]
129
+ tgt_ids = targets[b]["labels"]
130
+
131
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
132
+ # but approximate it in 1 - proba[target class].
133
+ # The 1 is a constant that doesn't change the matching, so it can be omitted.
134
+ cost_class = -out_prob[:, tgt_ids]
135
+
136
+ out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
137
+ # gt masks are already padded when preparing target
138
+ tgt_mask = targets[b]["masks"].to(out_mask)
139
+
140
+ out_mask = out_mask[:, None]
141
+ tgt_mask = tgt_mask[:, None]
142
+ # all masks share the same set of points for efficient matching!
143
+ point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device)
144
+ # get gt labels
145
+ tgt_mask = point_sample(
146
+ tgt_mask,
147
+ point_coords.repeat(tgt_mask.shape[0], 1, 1),
148
+ align_corners=False,
149
+ ).squeeze(1)
150
+
151
+ out_mask = point_sample(
152
+ out_mask,
153
+ point_coords.repeat(out_mask.shape[0], 1, 1),
154
+ align_corners=False,
155
+ ).squeeze(1)
156
+
157
+ with autocast(enabled=False):
158
+ out_mask = out_mask.float()
159
+ tgt_mask = tgt_mask.float()
160
+ # Compute the focal loss between masks
161
+ cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
162
+ # Compute the dice loss between masks
163
+ cost_dice = batch_dice_loss(out_mask, tgt_mask)
164
+
165
+ # Final cost matrix
166
+ C = (
167
+ self.cost_mask * cost_mask
168
+ + self.cost_class * cost_class
169
+ + self.cost_dice * cost_dice
170
+ )
171
+ C = C.reshape(num_queries, -1).cpu()
172
+
173
+ indices.append(linear_sum_assignment_with_nan(C))
174
+
175
+ return [
176
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
177
+ for i, j in indices
178
+ ]
179
+
180
+ @torch.no_grad()
181
+ def forward(self, outputs, targets):
182
+ """Performs the matching
183
+
184
+ Params:
185
+ outputs: This is a dict that contains at least these entries:
186
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
187
+ "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
188
+
189
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
190
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
191
+ objects in the target) containing the class labels
192
+ "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
193
+
194
+ Returns:
195
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
196
+ - index_i is the indices of the selected predictions (in order)
197
+ - index_j is the indices of the corresponding selected targets (in order)
198
+ For each batch element, it holds:
199
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
200
+ """
201
+
202
+ return self.memory_efficient_forward(outputs, targets)
203
+
204
+ def __repr__(self, _repr_indent=4):
205
+ head = "Matcher " + self.__class__.__name__
206
+ body = [
207
+ "cost_class: {}".format(self.cost_class),
208
+ "cost_mask: {}".format(self.cost_mask),
209
+ "cost_dice: {}".format(self.cost_dice),
210
+ ]
211
+ lines = [head] + [" " * _repr_indent + line for line in body]
212
+ return "\n".join(lines)
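Not part of the commit: a small sketch of how the pairwise costs above combine into the Hungarian assignment, using the helpers defined in this file. The 2.0/5.0/5.0 weights and all tensor sizes are illustrative placeholders, not OneFormer's configured values.

import torch

# toy problem: 5 query predictions vs. 2 ground-truth masks, flattened to 16 points each
pred_logits = torch.randn(5, 4)                       # [num_queries, num_classes]
pred_masks = torch.randn(5, 16)                       # mask logits per query
tgt_labels = torch.tensor([1, 3])
tgt_masks = (torch.rand(2, 16) > 0.5).float()

cost_class = -pred_logits.softmax(-1)[:, tgt_labels]       # [5, 2]
cost_mask = batch_sigmoid_ce_loss(pred_masks, tgt_masks)   # [5, 2]
cost_dice = batch_dice_loss(pred_masks, tgt_masks)         # [5, 2]
C = 2.0 * cost_class + 5.0 * cost_mask + 5.0 * cost_dice
rows, cols = linear_sum_assignment_with_nan(C)             # query rows[k] is matched to target cols[k]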
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/meta_arch/__init__.py ADDED
@@ -0,0 +1 @@
1
+
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/meta_arch/oneformer_head.py ADDED
@@ -0,0 +1,135 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/meta_arch/mask_former_head.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import logging
7
+ from copy import deepcopy
8
+ from typing import Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import fvcore.nn.weight_init as weight_init
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+
14
+ from annotator.oneformer.detectron2.config import configurable
15
+ from annotator.oneformer.detectron2.layers import Conv2d, ShapeSpec, get_norm
16
+ from annotator.oneformer.detectron2.modeling import SEM_SEG_HEADS_REGISTRY
17
+ from ..pixel_decoder.fpn import build_pixel_decoder
18
+ from ..transformer_decoder.oneformer_transformer_decoder import build_transformer_decoder
19
+
20
+ @SEM_SEG_HEADS_REGISTRY.register()
21
+ class OneFormerHead(nn.Module):
22
+
23
+ _version = 2
24
+
25
+ def _load_from_state_dict(
26
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
27
+ ):
28
+ version = local_metadata.get("version", None)
29
+ if version is None or version < 2:
30
+ # Do not warn if training from scratch
31
+ scratch = True
32
+ logger = logging.getLogger(__name__)
33
+ for k in list(state_dict.keys()):
34
+ newk = k
35
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
36
+ newk = k.replace(prefix, prefix + "pixel_decoder.")
37
+ # logger.debug(f"{k} ==> {newk}")
38
+ if newk != k:
39
+ state_dict[newk] = state_dict[k]
40
+ del state_dict[k]
41
+ scratch = False
42
+
43
+ if not scratch:
44
+ logger.warning(
45
+ f"Weight format of {self.__class__.__name__} has changed! "
46
+ "Please upgrade your models. Applying automatic conversion now ..."
47
+ )
48
+
49
+ @configurable
50
+ def __init__(
51
+ self,
52
+ input_shape: Dict[str, ShapeSpec],
53
+ *,
54
+ num_classes: int,
55
+ pixel_decoder: nn.Module,
56
+ loss_weight: float = 1.0,
57
+ ignore_value: int = -1,
58
+ # extra parameters
59
+ transformer_predictor: nn.Module,
60
+ transformer_in_feature: str,
61
+ ):
62
+ """
63
+ NOTE: this interface is experimental.
64
+ Args:
65
+ input_shape: shapes (channels and stride) of the input features
66
+ num_classes: number of classes to predict
67
+ pixel_decoder: the pixel decoder module
68
+ loss_weight: loss weight
69
+ ignore_value: category id to be ignored during training.
70
+ transformer_predictor: the transformer decoder that makes prediction
71
+ transformer_in_feature: input feature name to the transformer_predictor
72
+ """
73
+ super().__init__()
74
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
75
+ self.in_features = [k for k, v in input_shape]
76
+ feature_strides = [v.stride for k, v in input_shape]
77
+ feature_channels = [v.channels for k, v in input_shape]
78
+
79
+ self.ignore_value = ignore_value
80
+ self.common_stride = 4
81
+ self.loss_weight = loss_weight
82
+
83
+ self.pixel_decoder = pixel_decoder
84
+ self.predictor = transformer_predictor
85
+ self.transformer_in_feature = transformer_in_feature
86
+
87
+ self.num_classes = num_classes
88
+
89
+ @classmethod
90
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
91
+ # figure out in_channels to transformer predictor
92
+ if cfg.MODEL.ONE_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder":
93
+ transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
94
+ elif cfg.MODEL.ONE_FORMER.TRANSFORMER_IN_FEATURE == "pixel_embedding":
95
+ transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
96
+ elif cfg.MODEL.ONE_FORMER.TRANSFORMER_IN_FEATURE == "multi_scale_pixel_decoder":
97
+ transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
98
+ else:
99
+ transformer_predictor_in_channels = input_shape[cfg.MODEL.ONE_FORMER.TRANSFORMER_IN_FEATURE].channels
100
+
101
+ return {
102
+ "input_shape": {
103
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
104
+ },
105
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
106
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
107
+ "pixel_decoder": build_pixel_decoder(cfg, input_shape),
108
+ "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
109
+ "transformer_in_feature": cfg.MODEL.ONE_FORMER.TRANSFORMER_IN_FEATURE,
110
+ "transformer_predictor": build_transformer_decoder(
111
+ cfg,
112
+ transformer_predictor_in_channels,
113
+ mask_classification=True,
114
+ ),
115
+ }
116
+
117
+ def forward(self, features, tasks, mask=None):
118
+ return self.layers(features, tasks, mask)
119
+
120
+ def layers(self, features, tasks, mask=None):
121
+ mask_features, transformer_encoder_features, multi_scale_features, _, _ = self.pixel_decoder.forward_features(features)
122
+
123
+ if self.transformer_in_feature == "multi_scale_pixel_decoder":
124
+ predictions = self.predictor(multi_scale_features, mask_features, tasks, mask)
125
+ else:
126
+ if self.transformer_in_feature == "transformer_encoder":
127
+ assert (
128
+ transformer_encoder_features is not None
129
+ ), "Please use the TransformerEncoderPixelDecoder."
130
+ predictions = self.predictor(transformer_encoder_features, mask_features, mask)
131
+ elif self.transformer_in_feature == "pixel_embedding":
132
+ predictions = self.predictor(mask_features, mask_features, mask)
133
+ else:
134
+ predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask)
135
+ return predictions
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/fpn.py ADDED
@@ -0,0 +1,312 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import numpy as np
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import fvcore.nn.weight_init as weight_init
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
11
+ from torch.cuda.amp import autocast
12
+
13
+ from annotator.oneformer.detectron2.config import configurable
14
+ from annotator.oneformer.detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm
15
+ from annotator.oneformer.detectron2.modeling import SEM_SEG_HEADS_REGISTRY
16
+
17
+ from ..transformer_decoder.position_encoding import PositionEmbeddingSine
18
+ from ..transformer_decoder.transformer import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn
19
+
20
+
21
+ def build_pixel_decoder(cfg, input_shape):
22
+ """
23
+ Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`.
24
+ """
25
+ name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
26
+ model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
27
+ forward_features = getattr(model, "forward_features", None)
28
+ if not callable(forward_features):
29
+ raise ValueError(
30
+ "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
31
+ f"Please implement forward_features for {name} to only return mask features."
32
+ )
33
+ return model
34
+
35
+
36
+ # This is a modified FPN decoder.
37
+ @SEM_SEG_HEADS_REGISTRY.register()
38
+ class BasePixelDecoder(nn.Module):
39
+ @configurable
40
+ def __init__(
41
+ self,
42
+ input_shape: Dict[str, ShapeSpec],
43
+ *,
44
+ conv_dim: int,
45
+ mask_dim: int,
46
+ norm: Optional[Union[str, Callable]] = None,
47
+ ):
48
+ """
49
+ NOTE: this interface is experimental.
50
+ Args:
51
+ input_shape: shapes (channels and stride) of the input features
52
+ conv_dim: number of output channels for the intermediate conv layers.
53
+ mask_dim: number of output channels for the final conv layer.
54
+ norm (str or callable): normalization for all conv layers
55
+ """
56
+ super().__init__()
57
+
58
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
59
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
60
+ feature_channels = [v.channels for k, v in input_shape]
61
+
62
+ lateral_convs = []
63
+ output_convs = []
64
+
65
+ use_bias = norm == ""
66
+ for idx, in_channels in enumerate(feature_channels):
67
+ if idx == len(self.in_features) - 1:
68
+ output_norm = get_norm(norm, conv_dim)
69
+ output_conv = Conv2d(
70
+ in_channels,
71
+ conv_dim,
72
+ kernel_size=3,
73
+ stride=1,
74
+ padding=1,
75
+ bias=use_bias,
76
+ norm=output_norm,
77
+ activation=F.relu,
78
+ )
79
+ weight_init.c2_xavier_fill(output_conv)
80
+ self.add_module("layer_{}".format(idx + 1), output_conv)
81
+
82
+ lateral_convs.append(None)
83
+ output_convs.append(output_conv)
84
+ else:
85
+ lateral_norm = get_norm(norm, conv_dim)
86
+ output_norm = get_norm(norm, conv_dim)
87
+
88
+ lateral_conv = Conv2d(
89
+ in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
90
+ )
91
+ output_conv = Conv2d(
92
+ conv_dim,
93
+ conv_dim,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1,
97
+ bias=use_bias,
98
+ norm=output_norm,
99
+ activation=F.relu,
100
+ )
101
+ weight_init.c2_xavier_fill(lateral_conv)
102
+ weight_init.c2_xavier_fill(output_conv)
103
+ self.add_module("adapter_{}".format(idx + 1), lateral_conv)
104
+ self.add_module("layer_{}".format(idx + 1), output_conv)
105
+
106
+ lateral_convs.append(lateral_conv)
107
+ output_convs.append(output_conv)
108
+ # Place convs into top-down order (from low to high resolution)
109
+ # to make the top-down computation in forward clearer.
110
+ self.lateral_convs = lateral_convs[::-1]
111
+ self.output_convs = output_convs[::-1]
112
+
113
+ self.mask_dim = mask_dim
114
+ self.mask_features = Conv2d(
115
+ conv_dim,
116
+ mask_dim,
117
+ kernel_size=3,
118
+ stride=1,
119
+ padding=1,
120
+ )
121
+ weight_init.c2_xavier_fill(self.mask_features)
122
+
123
+ self.oneformer_num_feature_levels = 3 # always use 3 scales
124
+
125
+ @classmethod
126
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
127
+ ret = {}
128
+ ret["input_shape"] = {
129
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
130
+ }
131
+ ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
132
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
133
+ ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
134
+ return ret
135
+
136
+ def forward_features(self, features):
137
+ multi_scale_features = []
138
+ num_cur_levels = 0
139
+ # Reverse feature maps into top-down order (from low to high resolution)
140
+ for idx, f in enumerate(self.in_features[::-1]):
141
+ x = features[f]
142
+ lateral_conv = self.lateral_convs[idx]
143
+ output_conv = self.output_convs[idx]
144
+ if lateral_conv is None:
145
+ y = output_conv(x)
146
+ else:
147
+ cur_fpn = lateral_conv(x)
148
+ # Following FPN implementation, we use nearest upsampling here
149
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
150
+ y = output_conv(y)
151
+ if num_cur_levels < self.oneformer_num_feature_levels:
152
+ multi_scale_features.append(y)
153
+ num_cur_levels += 1
154
+ return self.mask_features(y), None, multi_scale_features
155
+
156
+ def forward(self, features, targets=None):
157
+ logger = logging.getLogger(__name__)
158
+ logger.warning("Calling forward() may cause unexpected behavior of the PixelDecoder module.")
159
+ return self.forward_features(features)
160
+
161
+
162
+ class TransformerEncoderOnly(nn.Module):
163
+ def __init__(
164
+ self,
165
+ d_model=512,
166
+ nhead=8,
167
+ num_encoder_layers=6,
168
+ dim_feedforward=2048,
169
+ dropout=0.1,
170
+ activation="relu",
171
+ normalize_before=False,
172
+ ):
173
+ super().__init__()
174
+
175
+ encoder_layer = TransformerEncoderLayer(
176
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
177
+ )
178
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
179
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
180
+
181
+ self._reset_parameters()
182
+
183
+ self.d_model = d_model
184
+ self.nhead = nhead
185
+
186
+ def _reset_parameters(self):
187
+ for p in self.parameters():
188
+ if p.dim() > 1:
189
+ nn.init.xavier_uniform_(p)
190
+
191
+ def forward(self, src, mask, pos_embed):
192
+ # flatten NxCxHxW to HWxNxC
193
+ bs, c, h, w = src.shape
194
+ src = src.flatten(2).permute(2, 0, 1)
195
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
196
+ if mask is not None:
197
+ mask = mask.flatten(1)
198
+
199
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
200
+ return memory.permute(1, 2, 0).view(bs, c, h, w)
201
+
202
+
203
+ # This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map.
204
+ @SEM_SEG_HEADS_REGISTRY.register()
205
+ class TransformerEncoderPixelDecoder(BasePixelDecoder):
206
+ @configurable
207
+ def __init__(
208
+ self,
209
+ input_shape: Dict[str, ShapeSpec],
210
+ *,
211
+ transformer_dropout: float,
212
+ transformer_nheads: int,
213
+ transformer_dim_feedforward: int,
214
+ transformer_enc_layers: int,
215
+ transformer_pre_norm: bool,
216
+ conv_dim: int,
217
+ mask_dim: int,
218
+ norm: Optional[Union[str, Callable]] = None,
219
+ ):
220
+ """
221
+ NOTE: this interface is experimental.
222
+ Args:
223
+ input_shape: shapes (channels and stride) of the input features
224
+ transformer_dropout: dropout probability in transformer
225
+ transformer_nheads: number of heads in transformer
226
+ transformer_dim_feedforward: dimension of feedforward network
227
+ transformer_enc_layers: number of transformer encoder layers
228
+ transformer_pre_norm: whether to use pre-layernorm or not
229
+ conv_dim: number of output channels for the intermediate conv layers.
230
+ mask_dim: number of output channels for the final conv layer.
231
+ norm (str or callable): normalization for all conv layers
232
+ """
233
+ super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm)
234
+
235
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
236
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
237
+ feature_strides = [v.stride for k, v in input_shape]
238
+ feature_channels = [v.channels for k, v in input_shape]
239
+
240
+ in_channels = feature_channels[len(self.in_features) - 1]
241
+ self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
242
+ weight_init.c2_xavier_fill(self.input_proj)
243
+ self.transformer = TransformerEncoderOnly(
244
+ d_model=conv_dim,
245
+ dropout=transformer_dropout,
246
+ nhead=transformer_nheads,
247
+ dim_feedforward=transformer_dim_feedforward,
248
+ num_encoder_layers=transformer_enc_layers,
249
+ normalize_before=transformer_pre_norm,
250
+ )
251
+ N_steps = conv_dim // 2
252
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
253
+
254
+ # update layer
255
+ use_bias = norm == ""
256
+ output_norm = get_norm(norm, conv_dim)
257
+ output_conv = Conv2d(
258
+ conv_dim,
259
+ conv_dim,
260
+ kernel_size=3,
261
+ stride=1,
262
+ padding=1,
263
+ bias=use_bias,
264
+ norm=output_norm,
265
+ activation=F.relu,
266
+ )
267
+ weight_init.c2_xavier_fill(output_conv)
268
+ delattr(self, "layer_{}".format(len(self.in_features)))
269
+ self.add_module("layer_{}".format(len(self.in_features)), output_conv)
270
+ self.output_convs[0] = output_conv
271
+
272
+ @classmethod
273
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
274
+ ret = super().from_config(cfg, input_shape)
275
+ ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
276
+ ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
277
+ ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
278
+ ret[
279
+ "transformer_enc_layers"
280
+ ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config
281
+ ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
282
+ return ret
283
+
284
+ def forward_features(self, features):
285
+ multi_scale_features = []
286
+ num_cur_levels = 0
287
+ # Reverse feature maps into top-down order (from low to high resolution)
288
+ for idx, f in enumerate(self.in_features[::-1]):
289
+ x = features[f]
290
+ lateral_conv = self.lateral_convs[idx]
291
+ output_conv = self.output_convs[idx]
292
+ if lateral_conv is None:
293
+ transformer = self.input_proj(x)
294
+ pos = self.pe_layer(x)
295
+ transformer = self.transformer(transformer, None, pos)
296
+ y = output_conv(transformer)
297
+ # save intermediate feature as input to Transformer decoder
298
+ transformer_encoder_features = transformer
299
+ else:
300
+ cur_fpn = lateral_conv(x)
301
+ # Following FPN implementation, we use nearest upsampling here
302
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
303
+ y = output_conv(y)
304
+ if num_cur_levels < self.oneformer_num_feature_levels:
305
+ multi_scale_features.append(y)
306
+ num_cur_levels += 1
307
+ return self.mask_features(y), transformer_encoder_features, multi_scale_features
308
+
309
+ def forward(self, features, targets=None):
310
+ logger = logging.getLogger(__name__)
311
+ logger.warning("Calling forward() may cause unexpected behavior of the PixelDecoder module.")
312
+ return self.forward_features(features)
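Not part of the commit: a plain-PyTorch sketch of the top-down FPN fusion that forward_features performs above (lateral 1x1 conv, nearest-neighbour upsample, 3x3 output conv). It uses bare nn.Conv2d layers instead of detectron2's Conv2d/norm wrappers, and the channel counts are illustrative.

import torch
import torch.nn.functional as F
from torch import nn

conv_dim = 256
res5 = torch.randn(1, 2048, 8, 8)     # coarsest feature (stride 32)
res4 = torch.randn(1, 1024, 16, 16)   # next feature (stride 16)

output_conv5 = nn.Conv2d(2048, conv_dim, kernel_size=3, padding=1)  # coarsest level has no lateral conv
lateral4 = nn.Conv2d(1024, conv_dim, kernel_size=1)
output_conv4 = nn.Conv2d(conv_dim, conv_dim, kernel_size=3, padding=1)

y = F.relu(output_conv5(res5))
cur_fpn = lateral4(res4)
y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")  # top-down sum
y = F.relu(output_conv4(y))   # -> (1, 256, 16, 16), one of the multi_scale_features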
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/msdeformattn.py ADDED
@@ -0,0 +1,358 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import numpy as np
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import fvcore.nn.weight_init as weight_init
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
11
+ from torch.cuda.amp import autocast
12
+
13
+ from annotator.oneformer.detectron2.config import configurable
14
+ from annotator.oneformer.detectron2.layers import Conv2d, ShapeSpec, get_norm
15
+ from annotator.oneformer.detectron2.modeling import SEM_SEG_HEADS_REGISTRY
16
+
17
+ from ..transformer_decoder.position_encoding import PositionEmbeddingSine
18
+ from ..transformer_decoder.transformer import _get_clones, _get_activation_fn
19
+ from .ops.modules import MSDeformAttn
20
+
21
+
22
+ # MSDeformAttn Transformer encoder in deformable detr
23
+ class MSDeformAttnTransformerEncoderOnly(nn.Module):
24
+ def __init__(self, d_model=256, nhead=8,
25
+ num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
26
+ activation="relu",
27
+ num_feature_levels=4, enc_n_points=4,
28
+ ):
29
+ super().__init__()
30
+
31
+ self.d_model = d_model
32
+ self.nhead = nhead
33
+
34
+ encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward,
35
+ dropout, activation,
36
+ num_feature_levels, nhead, enc_n_points)
37
+ self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers)
38
+
39
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
40
+
41
+ self._reset_parameters()
42
+
43
+ def _reset_parameters(self):
44
+ for p in self.parameters():
45
+ if p.dim() > 1:
46
+ nn.init.xavier_uniform_(p)
47
+ for m in self.modules():
48
+ if isinstance(m, MSDeformAttn):
49
+ m._reset_parameters()
50
+ normal_(self.level_embed)
51
+
52
+ def get_valid_ratio(self, mask):
53
+ _, H, W = mask.shape
54
+ valid_H = torch.sum(~mask[:, :, 0], 1)
55
+ valid_W = torch.sum(~mask[:, 0, :], 1)
56
+ valid_ratio_h = valid_H.float() / H
57
+ valid_ratio_w = valid_W.float() / W
58
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
59
+ return valid_ratio
60
+
61
+ def forward(self, srcs, pos_embeds):
62
+ masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]
63
+ # prepare input for encoder
64
+ src_flatten = []
65
+ mask_flatten = []
66
+ lvl_pos_embed_flatten = []
67
+ spatial_shapes = []
68
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
69
+ bs, c, h, w = src.shape
70
+ spatial_shape = (h, w)
71
+ spatial_shapes.append(spatial_shape)
72
+ src = src.flatten(2).transpose(1, 2)
73
+ mask = mask.flatten(1)
74
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
75
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
76
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
77
+ src_flatten.append(src)
78
+ mask_flatten.append(mask)
79
+ src_flatten = torch.cat(src_flatten, 1)
80
+ mask_flatten = torch.cat(mask_flatten, 1)
81
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
82
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
83
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
84
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
85
+
86
+ # encoder
87
+ memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
88
+
89
+ return memory, spatial_shapes, level_start_index, valid_ratios
90
+
91
+
92
+ class MSDeformAttnTransformerEncoderLayer(nn.Module):
93
+ def __init__(self,
94
+ d_model=256, d_ffn=1024,
95
+ dropout=0.1, activation="relu",
96
+ n_levels=4, n_heads=8, n_points=4):
97
+ super().__init__()
98
+
99
+ # self attention
100
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
101
+ self.dropout1 = nn.Dropout(dropout)
102
+ self.norm1 = nn.LayerNorm(d_model)
103
+
104
+ # ffn
105
+ self.linear1 = nn.Linear(d_model, d_ffn)
106
+ self.activation = _get_activation_fn(activation)
107
+ self.dropout2 = nn.Dropout(dropout)
108
+ self.linear2 = nn.Linear(d_ffn, d_model)
109
+ self.dropout3 = nn.Dropout(dropout)
110
+ self.norm2 = nn.LayerNorm(d_model)
111
+
112
+ @staticmethod
113
+ def with_pos_embed(tensor, pos):
114
+ return tensor if pos is None else tensor + pos
115
+
116
+ def forward_ffn(self, src):
117
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
118
+ src = src + self.dropout3(src2)
119
+ src = self.norm2(src)
120
+ return src
121
+
122
+ def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
123
+ # self attention
124
+ src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
125
+ src = src + self.dropout1(src2)
126
+ src = self.norm1(src)
127
+
128
+ # ffn
129
+ src = self.forward_ffn(src)
130
+
131
+ return src
132
+
133
+
134
+ class MSDeformAttnTransformerEncoder(nn.Module):
135
+ def __init__(self, encoder_layer, num_layers):
136
+ super().__init__()
137
+ self.layers = _get_clones(encoder_layer, num_layers)
138
+ self.num_layers = num_layers
139
+
140
+ @staticmethod
141
+ def get_reference_points(spatial_shapes, valid_ratios, device):
142
+ reference_points_list = []
143
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
144
+
145
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
146
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
147
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
148
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
149
+ ref = torch.stack((ref_x, ref_y), -1)
150
+ reference_points_list.append(ref)
151
+ reference_points = torch.cat(reference_points_list, 1)
152
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
153
+ return reference_points
154
+
155
+ def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
156
+ output = src
157
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
158
+ for _, layer in enumerate(self.layers):
159
+ output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
160
+
161
+ return output
162
+
163
+
164
+ @SEM_SEG_HEADS_REGISTRY.register()
165
+ class MSDeformAttnPixelDecoder(nn.Module):
166
+ @configurable
167
+ def __init__(
168
+ self,
169
+ input_shape: Dict[str, ShapeSpec],
170
+ *,
171
+ transformer_dropout: float,
172
+ transformer_nheads: int,
173
+ transformer_dim_feedforward: int,
174
+ transformer_enc_layers: int,
175
+ conv_dim: int,
176
+ mask_dim: int,
177
+ norm: Optional[Union[str, Callable]] = None,
178
+ # deformable transformer encoder args
179
+ transformer_in_features: List[str],
180
+ common_stride: int,
181
+ ):
182
+ """
183
+ NOTE: this interface is experimental.
184
+ Args:
185
+ input_shape: shapes (channels and stride) of the input features
186
+ transformer_dropout: dropout probability in transformer
187
+ transformer_nheads: number of heads in transformer
188
+ transformer_dim_feedforward: dimension of feedforward network
189
+ transformer_enc_layers: number of transformer encoder layers
190
+ conv_dims: number of output channels for the intermediate conv layers.
191
+ mask_dim: number of output channels for the final conv layer.
192
+ norm (str or callable): normalization for all conv layers
193
+ """
194
+ super().__init__()
195
+ transformer_input_shape = {
196
+ k: v for k, v in input_shape.items() if k in transformer_in_features
197
+ }
198
+
199
+ # this is the input shape of pixel decoder
200
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
201
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
202
+ self.feature_strides = [v.stride for k, v in input_shape]
203
+ self.feature_channels = [v.channels for k, v in input_shape]
204
+
205
+ # this is the input shape of the transformer encoder (it may use fewer features than the pixel decoder)
206
+ transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride)
207
+ self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5"
208
+ transformer_in_channels = [v.channels for k, v in transformer_input_shape]
209
+ self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers
210
+
211
+ self.transformer_num_feature_levels = len(self.transformer_in_features)
212
+ if self.transformer_num_feature_levels > 1:
213
+ input_proj_list = []
214
+ # from low resolution to high resolution (res5 -> res2)
215
+ for in_channels in transformer_in_channels[::-1]:
216
+ input_proj_list.append(nn.Sequential(
217
+ nn.Conv2d(in_channels, conv_dim, kernel_size=1),
218
+ nn.GroupNorm(32, conv_dim),
219
+ ))
220
+ self.input_proj = nn.ModuleList(input_proj_list)
221
+ else:
222
+ self.input_proj = nn.ModuleList([
223
+ nn.Sequential(
224
+ nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1),
225
+ nn.GroupNorm(32, conv_dim),
226
+ )])
227
+
228
+ for proj in self.input_proj:
229
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
230
+ nn.init.constant_(proj[0].bias, 0)
231
+
232
+ self.transformer = MSDeformAttnTransformerEncoderOnly(
233
+ d_model=conv_dim,
234
+ dropout=transformer_dropout,
235
+ nhead=transformer_nheads,
236
+ dim_feedforward=transformer_dim_feedforward,
237
+ num_encoder_layers=transformer_enc_layers,
238
+ num_feature_levels=self.transformer_num_feature_levels,
239
+ )
240
+ N_steps = conv_dim // 2
241
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
242
+
243
+ self.mask_dim = mask_dim
244
+ # use 1x1 conv instead
245
+ self.mask_features = Conv2d(
246
+ conv_dim,
247
+ mask_dim,
248
+ kernel_size=1,
249
+ stride=1,
250
+ padding=0,
251
+ )
252
+ weight_init.c2_xavier_fill(self.mask_features)
253
+
254
+ self.oneformer_num_feature_levels = 3 # always use 3 scales
255
+ self.common_stride = common_stride
256
+
257
+ # extra fpn levels
258
+ stride = min(self.transformer_feature_strides)
259
+ self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride))
260
+
261
+ lateral_convs = []
262
+ output_convs = []
263
+
264
+ use_bias = norm == ""
265
+ for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):
266
+ lateral_norm = get_norm(norm, conv_dim)
267
+ output_norm = get_norm(norm, conv_dim)
268
+
269
+ lateral_conv = Conv2d(
270
+ in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
271
+ )
272
+ output_conv = Conv2d(
273
+ conv_dim,
274
+ conv_dim,
275
+ kernel_size=3,
276
+ stride=1,
277
+ padding=1,
278
+ bias=use_bias,
279
+ norm=output_norm,
280
+ activation=F.relu,
281
+ )
282
+ weight_init.c2_xavier_fill(lateral_conv)
283
+ weight_init.c2_xavier_fill(output_conv)
284
+ self.add_module("adapter_{}".format(idx + 1), lateral_conv)
285
+ self.add_module("layer_{}".format(idx + 1), output_conv)
286
+
287
+ lateral_convs.append(lateral_conv)
288
+ output_convs.append(output_conv)
289
+ # Place convs into top-down order (from low to high resolution)
290
+ # to make the top-down computation in forward clearer.
291
+ self.lateral_convs = lateral_convs[::-1]
292
+ self.output_convs = output_convs[::-1]
293
+
294
+ @classmethod
295
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
296
+ ret = {}
297
+ ret["input_shape"] = {
298
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
299
+ }
300
+ ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
301
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
302
+ ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
303
+ ret["transformer_dropout"] = cfg.MODEL.ONE_FORMER.DROPOUT
304
+ ret["transformer_nheads"] = cfg.MODEL.ONE_FORMER.NHEADS
305
+ # ret["transformer_dim_feedforward"] = cfg.MODEL.ONE_FORMER.DIM_FEEDFORWARD
306
+ ret["transformer_dim_feedforward"] = 1024 # use 1024 for deformable transformer encoder
307
+ ret[
308
+ "transformer_enc_layers"
309
+ ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config
310
+ ret["transformer_in_features"] = cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES
311
+ ret["common_stride"] = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE
312
+ return ret
313
+
314
+ @autocast(enabled=False)
315
+ def forward_features(self, features):
316
+ srcs = []
317
+ pos = []
318
+ # Reverse feature maps into top-down order (from low to high resolution)
319
+ for idx, f in enumerate(self.transformer_in_features[::-1]):
320
+ x = features[f].float() # deformable detr does not support half precision
321
+ srcs.append(self.input_proj[idx](x))
322
+ pos.append(self.pe_layer(x))
323
+
324
+ y, spatial_shapes, level_start_index, valid_ratios = self.transformer(srcs, pos)
325
+ bs = y.shape[0]
326
+
327
+ split_size_or_sections = [None] * self.transformer_num_feature_levels
328
+ for i in range(self.transformer_num_feature_levels):
329
+ if i < self.transformer_num_feature_levels - 1:
330
+ split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
331
+ else:
332
+ split_size_or_sections[i] = y.shape[1] - level_start_index[i]
333
+ y = torch.split(y, split_size_or_sections, dim=1)
334
+
335
+ out = []
336
+ multi_scale_features = []
337
+ num_cur_levels = 0
338
+ for i, z in enumerate(y):
339
+ out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))
340
+
341
+ # append `out` with extra FPN levels
342
+ # Reverse feature maps into top-down order (from low to high resolution)
343
+ for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
344
+ x = features[f].float()
345
+ lateral_conv = self.lateral_convs[idx]
346
+ output_conv = self.output_convs[idx]
347
+ cur_fpn = lateral_conv(x)
348
+ # Following FPN implementation, we use nearest upsampling here
349
+ y = cur_fpn + F.interpolate(out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
350
+ y = output_conv(y)
351
+ out.append(y)
352
+
353
+ for o in out:
354
+ if num_cur_levels < self.oneformer_num_feature_levels:
355
+ multi_scale_features.append(o)
356
+ num_cur_levels += 1
357
+
358
+ return self.mask_features(out[-1]), out[0], multi_scale_features, spatial_shapes, level_start_index
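Not part of the commit: a shape-only sketch of the torch.split step in forward_features above, which turns the flattened encoder memory back into per-level feature maps. All sizes are illustrative.

import torch

bs, conv_dim = 1, 256
spatial_shapes = torch.tensor([[8, 8], [4, 4]])                   # two levels: 8x8 and 4x4 tokens
y = torch.randn(bs, int(spatial_shapes.prod(1).sum()), conv_dim)  # (1, 80, 256) flattened memory

splits = [int(h * w) for h, w in spatial_shapes]
maps = [z.transpose(1, 2).view(bs, conv_dim, int(h), int(w))
        for z, (h, w) in zip(torch.split(y, splits, dim=1), spatial_shapes)]
# maps[0]: (1, 256, 8, 8), maps[1]: (1, 256, 4, 4)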
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/functions/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from .ms_deform_attn_func import MSDeformAttnFunction
13
+
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/functions/ms_deform_attn_func.py ADDED
@@ -0,0 +1,77 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+
13
+ from __future__ import absolute_import
14
+ from __future__ import print_function
15
+ from __future__ import division
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch.autograd import Function
20
+ from torch.autograd.function import once_differentiable
21
+
22
+ # if torch.cuda.is_available():
23
+ # try:
24
+ # import MultiScaleDeformableAttention as MSDA
25
+ # except ModuleNotFoundError as e:
26
+ # info_string = (
27
+ # "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
28
+ # "\t`cd oneformer/modeling/pixel_decoder/ops`\n"
29
+ # "\t`sh make.sh`\n"
30
+ # )
31
+ # raise ModuleNotFoundError(info_string)
32
+ # else:
33
+ # MultiScaleDeformableAttention = None
34
+
35
+
36
+
37
+ class MSDeformAttnFunction(Function):
38
+ @staticmethod
39
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
40
+ # ctx.im2col_step = im2col_step
41
+ output = ms_deform_attn_core_pytorch(
42
+ value, value_spatial_shapes, sampling_locations, attention_weights)
43
+ # ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
44
+ return output
45
+
46
+ # @staticmethod
47
+ # @once_differentiable
48
+ # def backward(ctx, grad_output):
49
+ # value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
50
+ # grad_value, grad_sampling_loc, grad_attn_weight = \
51
+ # MSDA.ms_deform_attn_backward(
52
+ # value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
53
+ #
54
+ # return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
55
+
56
+
57
+ def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
58
+ # for debug and test only,
59
+ # need to use cuda version instead
60
+ N_, S_, M_, D_ = value.shape
61
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
62
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
63
+ sampling_grids = 2 * sampling_locations - 1
64
+ sampling_value_list = []
65
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
66
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
67
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
68
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
69
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
70
+ # N_*M_, D_, Lq_, P_
71
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
72
+ mode='bilinear', padding_mode='zeros', align_corners=False)
73
+ sampling_value_list.append(sampling_value_l_)
74
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
75
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
76
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
77
+ return output.transpose(1, 2).contiguous()
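Because the CUDA op import is commented out above, MSDeformAttnFunction.forward falls through to ms_deform_attn_core_pytorch. Below is a minimal smoke test for that fallback; the sizes are illustrative assumptions, only the tensor shapes follow the in-code comments, and the import path is assumed from the file location (it depends on how the extension lands on sys.path):

```python
import torch

# Assumed import path, derived from the file location above.
from annotator.oneformer.oneformer.modeling.pixel_decoder.ops.functions.ms_deform_attn_func import (
    ms_deform_attn_core_pytorch,
)

N, M, D = 2, 8, 32                                   # batch, heads, channels per head
spatial_shapes = [(16, 16), (8, 8)]                  # (H_l, W_l) per level
L = len(spatial_shapes)
S = sum(h * w for h, w in spatial_shapes)            # total number of value tokens
Lq, P = 100, 4                                       # queries, sampling points per level

value = torch.randn(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)   # normalized to [0, 1]
attention_weights = torch.softmax(torch.randn(N, Lq, M, L * P), -1).view(N, Lq, M, L, P)

out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)                                     # torch.Size([2, 100, 256]) == (N, Lq, M * D)
```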
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/make.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/usr/bin/env bash
2
+ # ------------------------------------------------------------------------------------------------
3
+ # Deformable DETR
4
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------------------------------
7
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ # ------------------------------------------------------------------------------------------------
9
+
10
+ # Copyright (c) Facebook, Inc. and its affiliates.
11
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
12
+
13
+ FORCE_CUDA=1 python setup.py build install
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/modules/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from .ms_deform_attn import MSDeformAttn
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/modules/ms_deform_attn.py ADDED
@@ -0,0 +1,120 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import warnings
17
+ import math
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.nn.functional as F
22
+ from torch.nn.init import xavier_uniform_, constant_
23
+
24
+ MSDeformAttnFunction = None
25
+ from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch
26
+
27
+
28
+ def _is_power_of_2(n):
29
+ if (not isinstance(n, int)) or (n < 0):
30
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
31
+ return (n & (n-1) == 0) and n != 0
32
+
33
+
34
+ class MSDeformAttn(nn.Module):
35
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
36
+ """
37
+ Multi-Scale Deformable Attention Module
38
+ :param d_model hidden dimension
39
+ :param n_levels number of feature levels
40
+ :param n_heads number of attention heads
41
+ :param n_points number of sampling points per attention head per feature level
42
+ """
43
+ super().__init__()
44
+ if d_model % n_heads != 0:
45
+ raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
46
+ _d_per_head = d_model // n_heads
47
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
48
+ if not _is_power_of_2(_d_per_head):
49
+ warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
50
+ "which is more efficient in our CUDA implementation.")
51
+
52
+ self.im2col_step = 128
53
+
54
+ self.d_model = d_model
55
+ self.n_levels = n_levels
56
+ self.n_heads = n_heads
57
+ self.n_points = n_points
58
+
59
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
60
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
61
+ self.value_proj = nn.Linear(d_model, d_model)
62
+ self.output_proj = nn.Linear(d_model, d_model)
63
+
64
+ self._reset_parameters()
65
+
66
+ def _reset_parameters(self):
67
+ constant_(self.sampling_offsets.weight.data, 0.)
68
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
69
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
70
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
71
+ for i in range(self.n_points):
72
+ grid_init[:, :, i, :] *= i + 1
73
+ with torch.no_grad():
74
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
75
+ constant_(self.attention_weights.weight.data, 0.)
76
+ constant_(self.attention_weights.bias.data, 0.)
77
+ xavier_uniform_(self.value_proj.weight.data)
78
+ constant_(self.value_proj.bias.data, 0.)
79
+ xavier_uniform_(self.output_proj.weight.data)
80
+ constant_(self.output_proj.bias.data, 0.)
81
+
82
+ def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
83
+ """
84
+ :param query (N, Length_{query}, C)
85
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
86
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
87
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
88
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
89
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
90
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
91
+ :return output (N, Length_{query}, C)
92
+ """
93
+ N, Len_q, _ = query.shape
94
+ N, Len_in, _ = input_flatten.shape
95
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
96
+
97
+ value = self.value_proj(input_flatten)
98
+ if input_padding_mask is not None:
99
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
100
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
101
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
102
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
103
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
104
+ # N, Len_q, n_heads, n_levels, n_points, 2
105
+ if reference_points.shape[-1] == 2:
106
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
107
+ sampling_locations = reference_points[:, :, None, :, None, :] \
108
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
109
+ elif reference_points.shape[-1] == 4:
110
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
111
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
112
+ else:
113
+ raise ValueError(
114
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
115
+ # try:
116
+ output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
117
+ # # For FLOPs calculation only
118
+ # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
119
+ output = self.output_proj(output)
120
+ return output
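A hedged usage sketch for the MSDeformAttn module above: all sizes are invented for illustration, only the argument shapes follow the forward docstring, the level_start_index construction is the usual cumulative-sum bookkeeping, and the import path is again assumed from the file location.

```python
import torch

# Assumed import path, derived from the file location above.
from annotator.oneformer.oneformer.modeling.pixel_decoder.ops.modules.ms_deform_attn import MSDeformAttn

d_model, n_levels, n_heads, n_points = 256, 2, 8, 4
attn = MSDeformAttn(d_model=d_model, n_levels=n_levels, n_heads=n_heads, n_points=n_points)

N, Len_q = 2, 100
spatial_shapes = torch.as_tensor([[16, 16], [8, 8]], dtype=torch.long)   # (n_levels, 2)
level_start_index = torch.cat(                                           # [0, H_0*W_0, ...]
    (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
Len_in = int(spatial_shapes.prod(1).sum())                               # sum_l H_l*W_l = 320

query = torch.randn(N, Len_q, d_model)
input_flatten = torch.randn(N, Len_in, d_model)
reference_points = torch.rand(N, Len_q, n_levels, 2)                     # normalized to [0, 1]

out = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
print(out.shape)  # torch.Size([2, 100, 256])
```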
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/setup.py ADDED
@@ -0,0 +1,78 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ import os
13
+ import glob
14
+
15
+ import torch
16
+
17
+ from torch.utils.cpp_extension import CUDA_HOME
18
+ from torch.utils.cpp_extension import CppExtension
19
+ from torch.utils.cpp_extension import CUDAExtension
20
+
21
+ from setuptools import find_packages
22
+ from setuptools import setup
23
+
24
+ requirements = ["torch", "torchvision"]
25
+
26
+ def get_extensions():
27
+ this_dir = os.path.dirname(os.path.abspath(__file__))
28
+ extensions_dir = os.path.join(this_dir, "src")
29
+
30
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
31
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
32
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
33
+
34
+ sources = main_file + source_cpu
35
+ extension = CppExtension
36
+ extra_compile_args = {"cxx": []}
37
+ define_macros = []
38
+
39
+ # Force cuda since torch ask for a device, not if cuda is in fact available.
40
+ if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:
41
+ extension = CUDAExtension
42
+ sources += source_cuda
43
+ define_macros += [("WITH_CUDA", None)]
44
+ extra_compile_args["nvcc"] = [
45
+ "-DCUDA_HAS_FP16=1",
46
+ "-D__CUDA_NO_HALF_OPERATORS__",
47
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
48
+ "-D__CUDA_NO_HALF2_OPERATORS__",
49
+ ]
50
+ else:
51
+ if CUDA_HOME is None:
52
+ raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')
53
+ else:
54
+ raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().')
55
+
56
+ sources = [os.path.join(extensions_dir, s) for s in sources]
57
+ include_dirs = [extensions_dir]
58
+ ext_modules = [
59
+ extension(
60
+ "MultiScaleDeformableAttention",
61
+ sources,
62
+ include_dirs=include_dirs,
63
+ define_macros=define_macros,
64
+ extra_compile_args=extra_compile_args,
65
+ )
66
+ ]
67
+ return ext_modules
68
+
69
+ setup(
70
+ name="MultiScaleDeformableAttention",
71
+ version="1.0",
72
+ author="Weijie Su",
73
+ url="https://github.com/fundamentalvision/Deformable-DETR",
74
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
75
+ packages=find_packages(exclude=("configs", "tests",)),
76
+ ext_modules=get_extensions(),
77
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
78
+ )
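setup.py builds the MultiScaleDeformableAttention extension (make.sh simply runs it with FORCE_CUDA=1). The sketch below shows the kind of import guard that the commented-out block in ms_deform_attn_func.py would reinstate once the op is compiled; it is a hedged illustration, not part of the commit.

```python
import torch

MSDA = None
if torch.cuda.is_available():
    try:
        # Extension name as declared in setup() above.
        import MultiScaleDeformableAttention as MSDA
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            "Please compile MultiScaleDeformableAttention first:\n"
            "\tcd oneformer/modeling/pixel_decoder/ops\n"
            "\tsh make.sh"
        )
```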
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,46 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include <vector>
17
+
18
+ #include <ATen/ATen.h>
19
+ #include <ATen/cuda/CUDAContext.h>
20
+
21
+
22
+ at::Tensor
23
+ ms_deform_attn_cpu_forward(
24
+ const at::Tensor &value,
25
+ const at::Tensor &spatial_shapes,
26
+ const at::Tensor &level_start_index,
27
+ const at::Tensor &sampling_loc,
28
+ const at::Tensor &attn_weight,
29
+ const int im2col_step)
30
+ {
31
+ AT_ERROR("Not implement on cpu");
32
+ }
33
+
34
+ std::vector<at::Tensor>
35
+ ms_deform_attn_cpu_backward(
36
+ const at::Tensor &value,
37
+ const at::Tensor &spatial_shapes,
38
+ const at::Tensor &level_start_index,
39
+ const at::Tensor &sampling_loc,
40
+ const at::Tensor &attn_weight,
41
+ const at::Tensor &grad_output,
42
+ const int im2col_step)
43
+ {
44
+ AT_ERROR("Not implement on cpu");
45
+ }
46
+
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,38 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+ #include <torch/extension.h>
18
+
19
+ at::Tensor
20
+ ms_deform_attn_cpu_forward(
21
+ const at::Tensor &value,
22
+ const at::Tensor &spatial_shapes,
23
+ const at::Tensor &level_start_index,
24
+ const at::Tensor &sampling_loc,
25
+ const at::Tensor &attn_weight,
26
+ const int im2col_step);
27
+
28
+ std::vector<at::Tensor>
29
+ ms_deform_attn_cpu_backward(
30
+ const at::Tensor &value,
31
+ const at::Tensor &spatial_shapes,
32
+ const at::Tensor &level_start_index,
33
+ const at::Tensor &sampling_loc,
34
+ const at::Tensor &attn_weight,
35
+ const at::Tensor &grad_output,
36
+ const int im2col_step);
37
+
38
+
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,158 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include <vector>
17
+ #include "cuda/ms_deform_im2col_cuda.cuh"
18
+
19
+ #include <ATen/ATen.h>
20
+ #include <ATen/cuda/CUDAContext.h>
21
+ #include <cuda.h>
22
+ #include <cuda_runtime.h>
23
+
24
+
25
+ at::Tensor ms_deform_attn_cuda_forward(
26
+ const at::Tensor &value,
27
+ const at::Tensor &spatial_shapes,
28
+ const at::Tensor &level_start_index,
29
+ const at::Tensor &sampling_loc,
30
+ const at::Tensor &attn_weight,
31
+ const int im2col_step)
32
+ {
33
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
34
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
35
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
36
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
37
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
38
+
39
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
40
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
41
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
42
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
43
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
44
+
45
+ const int batch = value.size(0);
46
+ const int spatial_size = value.size(1);
47
+ const int num_heads = value.size(2);
48
+ const int channels = value.size(3);
49
+
50
+ const int num_levels = spatial_shapes.size(0);
51
+
52
+ const int num_query = sampling_loc.size(1);
53
+ const int num_point = sampling_loc.size(4);
54
+
55
+ const int im2col_step_ = std::min(batch, im2col_step);
56
+
57
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
58
+
59
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
60
+
61
+ const int batch_n = im2col_step_;
62
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
63
+ auto per_value_size = spatial_size * num_heads * channels;
64
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
65
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
66
+ for (int n = 0; n < batch/im2col_step_; ++n)
67
+ {
68
+ auto columns = output_n.select(0, n);
69
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
70
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
71
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
72
+ spatial_shapes.data<int64_t>(),
73
+ level_start_index.data<int64_t>(),
74
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
75
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
76
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
77
+ columns.data<scalar_t>());
78
+
79
+ }));
80
+ }
81
+
82
+ output = output.view({batch, num_query, num_heads*channels});
83
+
84
+ return output;
85
+ }
86
+
87
+
88
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
89
+ const at::Tensor &value,
90
+ const at::Tensor &spatial_shapes,
91
+ const at::Tensor &level_start_index,
92
+ const at::Tensor &sampling_loc,
93
+ const at::Tensor &attn_weight,
94
+ const at::Tensor &grad_output,
95
+ const int im2col_step)
96
+ {
97
+
98
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
99
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
100
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
101
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
102
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
103
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
104
+
105
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
106
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
107
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
108
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
109
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
110
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
111
+
112
+ const int batch = value.size(0);
113
+ const int spatial_size = value.size(1);
114
+ const int num_heads = value.size(2);
115
+ const int channels = value.size(3);
116
+
117
+ const int num_levels = spatial_shapes.size(0);
118
+
119
+ const int num_query = sampling_loc.size(1);
120
+ const int num_point = sampling_loc.size(4);
121
+
122
+ const int im2col_step_ = std::min(batch, im2col_step);
123
+
124
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
125
+
126
+ auto grad_value = at::zeros_like(value);
127
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
128
+ auto grad_attn_weight = at::zeros_like(attn_weight);
129
+
130
+ const int batch_n = im2col_step_;
131
+ auto per_value_size = spatial_size * num_heads * channels;
132
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
133
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
134
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
135
+
136
+ for (int n = 0; n < batch/im2col_step_; ++n)
137
+ {
138
+ auto grad_output_g = grad_output_n.select(0, n);
139
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
140
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
141
+ grad_output_g.data<scalar_t>(),
142
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
143
+ spatial_shapes.data<int64_t>(),
144
+ level_start_index.data<int64_t>(),
145
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
146
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
147
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
148
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
149
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
150
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
151
+
152
+ }));
153
+ }
154
+
155
+ return {
156
+ grad_value, grad_sampling_loc, grad_attn_weight
157
+ };
158
+ }
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,35 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+ #include <torch/extension.h>
18
+
19
+ at::Tensor ms_deform_attn_cuda_forward(
20
+ const at::Tensor &value,
21
+ const at::Tensor &spatial_shapes,
22
+ const at::Tensor &level_start_index,
23
+ const at::Tensor &sampling_loc,
24
+ const at::Tensor &attn_weight,
25
+ const int im2col_step);
26
+
27
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
28
+ const at::Tensor &value,
29
+ const at::Tensor &spatial_shapes,
30
+ const at::Tensor &level_start_index,
31
+ const at::Tensor &sampling_loc,
32
+ const at::Tensor &attn_weight,
33
+ const at::Tensor &grad_output,
34
+ const int im2col_step);
35
+
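These declarations match the entry points that the commented-out backward in ms_deform_attn_func.py calls as MSDA.ms_deform_attn_backward. A rough sketch of how the compiled op would be wired back into an autograd Function follows; the ms_deform_attn_forward binding name is assumed by analogy with that comment, and this is an illustration rather than a file from the commit.

```python
from torch.autograd import Function
from torch.autograd.function import once_differentiable

import MultiScaleDeformableAttention as MSDA  # requires the built extension


class MSDeformAttnFunctionCUDA(Function):
    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(                 # assumed binding name
            value, value_spatial_shapes, value_level_start_index,
            sampling_locations, attention_weights, ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index,
                              sampling_locations, attention_weights)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        value, value_spatial_shapes, value_level_start_index, \
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
            value, value_spatial_shapes, value_level_start_index,
            sampling_locations, attention_weights, grad_output, ctx.im2col_step)
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
```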
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/cuda/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1332 @@
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ /*!
13
+ * Copyright (c) Facebook, Inc. and its affiliates.
14
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
15
+ */
16
+
17
+ #include <cstdio>
18
+ #include <algorithm>
19
+ #include <cstring>
20
+
21
+ #include <ATen/ATen.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+
24
+ #include <THC/THCAtomics.cuh>
25
+
26
+ #define CUDA_KERNEL_LOOP(i, n) \
27
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
28
+ i < (n); \
29
+ i += blockDim.x * gridDim.x)
30
+
31
+ const int CUDA_NUM_THREADS = 1024;
32
+ inline int GET_BLOCKS(const int N, const int num_threads)
33
+ {
34
+ return (N + num_threads - 1) / num_threads;
35
+ }
36
+
37
+
38
+ template <typename scalar_t>
39
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
40
+ const int &height, const int &width, const int &nheads, const int &channels,
41
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
42
+ {
43
+ const int h_low = floor(h);
44
+ const int w_low = floor(w);
45
+ const int h_high = h_low + 1;
46
+ const int w_high = w_low + 1;
47
+
48
+ const scalar_t lh = h - h_low;
49
+ const scalar_t lw = w - w_low;
50
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
51
+
52
+ const int w_stride = nheads * channels;
53
+ const int h_stride = width * w_stride;
54
+ const int h_low_ptr_offset = h_low * h_stride;
55
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
56
+ const int w_low_ptr_offset = w_low * w_stride;
57
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
58
+ const int base_ptr = m * channels + c;
59
+
60
+ scalar_t v1 = 0;
61
+ if (h_low >= 0 && w_low >= 0)
62
+ {
63
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
64
+ v1 = bottom_data[ptr1];
65
+ }
66
+ scalar_t v2 = 0;
67
+ if (h_low >= 0 && w_high <= width - 1)
68
+ {
69
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
70
+ v2 = bottom_data[ptr2];
71
+ }
72
+ scalar_t v3 = 0;
73
+ if (h_high <= height - 1 && w_low >= 0)
74
+ {
75
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
76
+ v3 = bottom_data[ptr3];
77
+ }
78
+ scalar_t v4 = 0;
79
+ if (h_high <= height - 1 && w_high <= width - 1)
80
+ {
81
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
82
+ v4 = bottom_data[ptr4];
83
+ }
84
+
85
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
86
+
87
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
88
+ return val;
89
+ }
90
+
91
+
92
+ template <typename scalar_t>
93
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
94
+ const int &height, const int &width, const int &nheads, const int &channels,
95
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
96
+ const scalar_t &top_grad,
97
+ const scalar_t &attn_weight,
98
+ scalar_t* &grad_value,
99
+ scalar_t* grad_sampling_loc,
100
+ scalar_t* grad_attn_weight)
101
+ {
102
+ const int h_low = floor(h);
103
+ const int w_low = floor(w);
104
+ const int h_high = h_low + 1;
105
+ const int w_high = w_low + 1;
106
+
107
+ const scalar_t lh = h - h_low;
108
+ const scalar_t lw = w - w_low;
109
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
110
+
111
+ const int w_stride = nheads * channels;
112
+ const int h_stride = width * w_stride;
113
+ const int h_low_ptr_offset = h_low * h_stride;
114
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
115
+ const int w_low_ptr_offset = w_low * w_stride;
116
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
117
+ const int base_ptr = m * channels + c;
118
+
119
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
120
+ const scalar_t top_grad_value = top_grad * attn_weight;
121
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
122
+
123
+ scalar_t v1 = 0;
124
+ if (h_low >= 0 && w_low >= 0)
125
+ {
126
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
127
+ v1 = bottom_data[ptr1];
128
+ grad_h_weight -= hw * v1;
129
+ grad_w_weight -= hh * v1;
130
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
131
+ }
132
+ scalar_t v2 = 0;
133
+ if (h_low >= 0 && w_high <= width - 1)
134
+ {
135
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
136
+ v2 = bottom_data[ptr2];
137
+ grad_h_weight -= lw * v2;
138
+ grad_w_weight += hh * v2;
139
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
140
+ }
141
+ scalar_t v3 = 0;
142
+ if (h_high <= height - 1 && w_low >= 0)
143
+ {
144
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
145
+ v3 = bottom_data[ptr3];
146
+ grad_h_weight += hw * v3;
147
+ grad_w_weight -= lh * v3;
148
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
149
+ }
150
+ scalar_t v4 = 0;
151
+ if (h_high <= height - 1 && w_high <= width - 1)
152
+ {
153
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
154
+ v4 = bottom_data[ptr4];
155
+ grad_h_weight += lw * v4;
156
+ grad_w_weight += lh * v4;
157
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
158
+ }
159
+
160
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
161
+ *grad_attn_weight = top_grad * val;
162
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
163
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
164
+ }
165
+
166
+
167
+ template <typename scalar_t>
168
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
169
+ const int &height, const int &width, const int &nheads, const int &channels,
170
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
171
+ const scalar_t &top_grad,
172
+ const scalar_t &attn_weight,
173
+ scalar_t* &grad_value,
174
+ scalar_t* grad_sampling_loc,
175
+ scalar_t* grad_attn_weight)
176
+ {
177
+ const int h_low = floor(h);
178
+ const int w_low = floor(w);
179
+ const int h_high = h_low + 1;
180
+ const int w_high = w_low + 1;
181
+
182
+ const scalar_t lh = h - h_low;
183
+ const scalar_t lw = w - w_low;
184
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
185
+
186
+ const int w_stride = nheads * channels;
187
+ const int h_stride = width * w_stride;
188
+ const int h_low_ptr_offset = h_low * h_stride;
189
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
190
+ const int w_low_ptr_offset = w_low * w_stride;
191
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
192
+ const int base_ptr = m * channels + c;
193
+
194
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
195
+ const scalar_t top_grad_value = top_grad * attn_weight;
196
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
197
+
198
+ scalar_t v1 = 0;
199
+ if (h_low >= 0 && w_low >= 0)
200
+ {
201
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
202
+ v1 = bottom_data[ptr1];
203
+ grad_h_weight -= hw * v1;
204
+ grad_w_weight -= hh * v1;
205
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
206
+ }
207
+ scalar_t v2 = 0;
208
+ if (h_low >= 0 && w_high <= width - 1)
209
+ {
210
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
211
+ v2 = bottom_data[ptr2];
212
+ grad_h_weight -= lw * v2;
213
+ grad_w_weight += hh * v2;
214
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
215
+ }
216
+ scalar_t v3 = 0;
217
+ if (h_high <= height - 1 && w_low >= 0)
218
+ {
219
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
220
+ v3 = bottom_data[ptr3];
221
+ grad_h_weight += hw * v3;
222
+ grad_w_weight -= lh * v3;
223
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
224
+ }
225
+ scalar_t v4 = 0;
226
+ if (h_high <= height - 1 && w_high <= width - 1)
227
+ {
228
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
229
+ v4 = bottom_data[ptr4];
230
+ grad_h_weight += lw * v4;
231
+ grad_w_weight += lh * v4;
232
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
233
+ }
234
+
235
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
236
+ atomicAdd(grad_attn_weight, top_grad * val);
237
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
238
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
239
+ }
240
+
241
+
242
+ template <typename scalar_t>
243
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
244
+ const scalar_t *data_value,
245
+ const int64_t *data_spatial_shapes,
246
+ const int64_t *data_level_start_index,
247
+ const scalar_t *data_sampling_loc,
248
+ const scalar_t *data_attn_weight,
249
+ const int batch_size,
250
+ const int spatial_size,
251
+ const int num_heads,
252
+ const int channels,
253
+ const int num_levels,
254
+ const int num_query,
255
+ const int num_point,
256
+ scalar_t *data_col)
257
+ {
258
+ CUDA_KERNEL_LOOP(index, n)
259
+ {
260
+ int _temp = index;
261
+ const int c_col = _temp % channels;
262
+ _temp /= channels;
263
+ const int sampling_index = _temp;
264
+ const int m_col = _temp % num_heads;
265
+ _temp /= num_heads;
266
+ const int q_col = _temp % num_query;
267
+ _temp /= num_query;
268
+ const int b_col = _temp;
269
+
270
+ scalar_t *data_col_ptr = data_col + index;
271
+ int data_weight_ptr = sampling_index * num_levels * num_point;
272
+ int data_loc_w_ptr = data_weight_ptr << 1;
273
+ const int qid_stride = num_heads * channels;
274
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
275
+ scalar_t col = 0;
276
+
277
+ for (int l_col=0; l_col < num_levels; ++l_col)
278
+ {
279
+ const int level_start_id = data_level_start_index[l_col];
280
+ const int spatial_h_ptr = l_col << 1;
281
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
282
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
283
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
284
+ for (int p_col=0; p_col < num_point; ++p_col)
285
+ {
286
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
287
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
288
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
289
+
290
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
291
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
292
+
293
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
294
+ {
295
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
296
+ }
297
+
298
+ data_weight_ptr += 1;
299
+ data_loc_w_ptr += 2;
300
+ }
301
+ }
302
+ *data_col_ptr = col;
303
+ }
304
+ }
305
+
306
+ template <typename scalar_t, unsigned int blockSize>
307
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
308
+ const scalar_t *grad_col,
309
+ const scalar_t *data_value,
310
+ const int64_t *data_spatial_shapes,
311
+ const int64_t *data_level_start_index,
312
+ const scalar_t *data_sampling_loc,
313
+ const scalar_t *data_attn_weight,
314
+ const int batch_size,
315
+ const int spatial_size,
316
+ const int num_heads,
317
+ const int channels,
318
+ const int num_levels,
319
+ const int num_query,
320
+ const int num_point,
321
+ scalar_t *grad_value,
322
+ scalar_t *grad_sampling_loc,
323
+ scalar_t *grad_attn_weight)
324
+ {
325
+ CUDA_KERNEL_LOOP(index, n)
326
+ {
327
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
328
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
329
+ unsigned int tid = threadIdx.x;
330
+ int _temp = index;
331
+ const int c_col = _temp % channels;
332
+ _temp /= channels;
333
+ const int sampling_index = _temp;
334
+ const int m_col = _temp % num_heads;
335
+ _temp /= num_heads;
336
+ const int q_col = _temp % num_query;
337
+ _temp /= num_query;
338
+ const int b_col = _temp;
339
+
340
+ const scalar_t top_grad = grad_col[index];
341
+
342
+ int data_weight_ptr = sampling_index * num_levels * num_point;
343
+ int data_loc_w_ptr = data_weight_ptr << 1;
344
+ const int grad_sampling_ptr = data_weight_ptr;
345
+ grad_sampling_loc += grad_sampling_ptr << 1;
346
+ grad_attn_weight += grad_sampling_ptr;
347
+ const int grad_weight_stride = 1;
348
+ const int grad_loc_stride = 2;
349
+ const int qid_stride = num_heads * channels;
350
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
351
+
352
+ for (int l_col=0; l_col < num_levels; ++l_col)
353
+ {
354
+ const int level_start_id = data_level_start_index[l_col];
355
+ const int spatial_h_ptr = l_col << 1;
356
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
357
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
358
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
359
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
360
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
361
+
362
+ for (int p_col=0; p_col < num_point; ++p_col)
363
+ {
364
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
365
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
366
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
367
+
368
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
369
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
370
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
371
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
372
+ *(cache_grad_attn_weight+threadIdx.x)=0;
373
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
374
+ {
375
+ ms_deform_attn_col2im_bilinear(
376
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
377
+ top_grad, weight, grad_value_ptr,
378
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
379
+ }
380
+
381
+ __syncthreads();
382
+ if (tid == 0)
383
+ {
384
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
385
+ int sid=2;
386
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
387
+ {
388
+ _grad_w += cache_grad_sampling_loc[sid];
389
+ _grad_h += cache_grad_sampling_loc[sid + 1];
390
+ _grad_a += cache_grad_attn_weight[tid];
391
+ sid += 2;
392
+ }
393
+
394
+
395
+ *grad_sampling_loc = _grad_w;
396
+ *(grad_sampling_loc + 1) = _grad_h;
397
+ *grad_attn_weight = _grad_a;
398
+ }
399
+ __syncthreads();
400
+
401
+ data_weight_ptr += 1;
402
+ data_loc_w_ptr += 2;
403
+ grad_attn_weight += grad_weight_stride;
404
+ grad_sampling_loc += grad_loc_stride;
405
+ }
406
+ }
407
+ }
408
+ }
409
+
410
+
411
+ template <typename scalar_t, unsigned int blockSize>
412
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
413
+ const scalar_t *grad_col,
414
+ const scalar_t *data_value,
415
+ const int64_t *data_spatial_shapes,
416
+ const int64_t *data_level_start_index,
417
+ const scalar_t *data_sampling_loc,
418
+ const scalar_t *data_attn_weight,
419
+ const int batch_size,
420
+ const int spatial_size,
421
+ const int num_heads,
422
+ const int channels,
423
+ const int num_levels,
424
+ const int num_query,
425
+ const int num_point,
426
+ scalar_t *grad_value,
427
+ scalar_t *grad_sampling_loc,
428
+ scalar_t *grad_attn_weight)
429
+ {
430
+ CUDA_KERNEL_LOOP(index, n)
431
+ {
432
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
433
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
434
+ unsigned int tid = threadIdx.x;
435
+ int _temp = index;
436
+ const int c_col = _temp % channels;
437
+ _temp /= channels;
438
+ const int sampling_index = _temp;
439
+ const int m_col = _temp % num_heads;
440
+ _temp /= num_heads;
441
+ const int q_col = _temp % num_query;
442
+ _temp /= num_query;
443
+ const int b_col = _temp;
444
+
445
+ const scalar_t top_grad = grad_col[index];
446
+
447
+ int data_weight_ptr = sampling_index * num_levels * num_point;
448
+ int data_loc_w_ptr = data_weight_ptr << 1;
449
+ const int grad_sampling_ptr = data_weight_ptr;
450
+ grad_sampling_loc += grad_sampling_ptr << 1;
451
+ grad_attn_weight += grad_sampling_ptr;
452
+ const int grad_weight_stride = 1;
453
+ const int grad_loc_stride = 2;
454
+ const int qid_stride = num_heads * channels;
455
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
456
+
457
+ for (int l_col=0; l_col < num_levels; ++l_col)
458
+ {
459
+ const int level_start_id = data_level_start_index[l_col];
460
+ const int spatial_h_ptr = l_col << 1;
461
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
462
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
463
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
464
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
465
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
466
+
467
+ for (int p_col=0; p_col < num_point; ++p_col)
468
+ {
469
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
470
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
471
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
472
+
473
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
474
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
475
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
476
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
477
+ *(cache_grad_attn_weight+threadIdx.x)=0;
478
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
479
+ {
480
+ ms_deform_attn_col2im_bilinear(
481
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
482
+ top_grad, weight, grad_value_ptr,
483
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
484
+ }
485
+
486
+ __syncthreads();
487
+
488
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
489
+ {
490
+ if (tid < s) {
491
+ const unsigned int xid1 = tid << 1;
492
+ const unsigned int xid2 = (tid + s) << 1;
493
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
494
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
495
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
496
+ }
497
+ __syncthreads();
498
+ }
499
+
500
+ if (tid == 0)
501
+ {
502
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
503
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
504
+ *grad_attn_weight = cache_grad_attn_weight[0];
505
+ }
506
+ __syncthreads();
507
+
508
+ data_weight_ptr += 1;
509
+ data_loc_w_ptr += 2;
510
+ grad_attn_weight += grad_weight_stride;
511
+ grad_sampling_loc += grad_loc_stride;
512
+ }
513
+ }
514
+ }
515
+ }
516
+
517
+
518
+ template <typename scalar_t>
519
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
520
+ const scalar_t *grad_col,
521
+ const scalar_t *data_value,
522
+ const int64_t *data_spatial_shapes,
523
+ const int64_t *data_level_start_index,
524
+ const scalar_t *data_sampling_loc,
525
+ const scalar_t *data_attn_weight,
526
+ const int batch_size,
527
+ const int spatial_size,
528
+ const int num_heads,
529
+ const int channels,
530
+ const int num_levels,
531
+ const int num_query,
532
+ const int num_point,
533
+ scalar_t *grad_value,
534
+ scalar_t *grad_sampling_loc,
535
+ scalar_t *grad_attn_weight)
536
+ {
537
+ CUDA_KERNEL_LOOP(index, n)
538
+ {
539
+ extern __shared__ int _s[];
540
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
541
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
542
+ unsigned int tid = threadIdx.x;
543
+ int _temp = index;
544
+ const int c_col = _temp % channels;
545
+ _temp /= channels;
546
+ const int sampling_index = _temp;
547
+ const int m_col = _temp % num_heads;
548
+ _temp /= num_heads;
549
+ const int q_col = _temp % num_query;
550
+ _temp /= num_query;
551
+ const int b_col = _temp;
552
+
553
+ const scalar_t top_grad = grad_col[index];
554
+
555
+ int data_weight_ptr = sampling_index * num_levels * num_point;
556
+ int data_loc_w_ptr = data_weight_ptr << 1;
557
+ const int grad_sampling_ptr = data_weight_ptr;
558
+ grad_sampling_loc += grad_sampling_ptr << 1;
559
+ grad_attn_weight += grad_sampling_ptr;
560
+ const int grad_weight_stride = 1;
561
+ const int grad_loc_stride = 2;
562
+ const int qid_stride = num_heads * channels;
563
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
564
+
565
+ for (int l_col=0; l_col < num_levels; ++l_col)
566
+ {
567
+ const int level_start_id = data_level_start_index[l_col];
568
+ const int spatial_h_ptr = l_col << 1;
569
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
570
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
571
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
572
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
573
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
574
+
575
+ for (int p_col=0; p_col < num_point; ++p_col)
576
+ {
577
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
578
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
579
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
580
+
581
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
582
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
583
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
584
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
585
+ *(cache_grad_attn_weight+threadIdx.x)=0;
586
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
587
+ {
588
+ ms_deform_attn_col2im_bilinear(
589
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
590
+ top_grad, weight, grad_value_ptr,
591
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
592
+ }
593
+
594
+ __syncthreads();
595
+ if (tid == 0)
596
+ {
597
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
598
+ int sid=2;
599
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
600
+ {
601
+ _grad_w += cache_grad_sampling_loc[sid];
602
+ _grad_h += cache_grad_sampling_loc[sid + 1];
603
+ _grad_a += cache_grad_attn_weight[tid];
604
+ sid += 2;
605
+ }
606
+
607
+
608
+ *grad_sampling_loc = _grad_w;
609
+ *(grad_sampling_loc + 1) = _grad_h;
610
+ *grad_attn_weight = _grad_a;
611
+ }
612
+ __syncthreads();
613
+
614
+ data_weight_ptr += 1;
615
+ data_loc_w_ptr += 2;
616
+ grad_attn_weight += grad_weight_stride;
617
+ grad_sampling_loc += grad_loc_stride;
618
+ }
619
+ }
620
+ }
621
+ }
622
+
623
+ template <typename scalar_t>
624
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
625
+ const scalar_t *grad_col,
626
+ const scalar_t *data_value,
627
+ const int64_t *data_spatial_shapes,
628
+ const int64_t *data_level_start_index,
629
+ const scalar_t *data_sampling_loc,
630
+ const scalar_t *data_attn_weight,
631
+ const int batch_size,
632
+ const int spatial_size,
633
+ const int num_heads,
634
+ const int channels,
635
+ const int num_levels,
636
+ const int num_query,
637
+ const int num_point,
638
+ scalar_t *grad_value,
639
+ scalar_t *grad_sampling_loc,
640
+ scalar_t *grad_attn_weight)
641
+ {
642
+ CUDA_KERNEL_LOOP(index, n)
643
+ {
644
+ extern __shared__ int _s[];
645
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
646
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
647
+ unsigned int tid = threadIdx.x;
648
+ int _temp = index;
649
+ const int c_col = _temp % channels;
650
+ _temp /= channels;
651
+ const int sampling_index = _temp;
652
+ const int m_col = _temp % num_heads;
653
+ _temp /= num_heads;
654
+ const int q_col = _temp % num_query;
655
+ _temp /= num_query;
656
+ const int b_col = _temp;
657
+
658
+ const scalar_t top_grad = grad_col[index];
659
+
660
+ int data_weight_ptr = sampling_index * num_levels * num_point;
661
+ int data_loc_w_ptr = data_weight_ptr << 1;
662
+ const int grad_sampling_ptr = data_weight_ptr;
663
+ grad_sampling_loc += grad_sampling_ptr << 1;
664
+ grad_attn_weight += grad_sampling_ptr;
665
+ const int grad_weight_stride = 1;
666
+ const int grad_loc_stride = 2;
667
+ const int qid_stride = num_heads * channels;
668
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
669
+
670
+ for (int l_col=0; l_col < num_levels; ++l_col)
671
+ {
672
+ const int level_start_id = data_level_start_index[l_col];
673
+ const int spatial_h_ptr = l_col << 1;
674
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
675
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
676
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
677
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
678
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
679
+
680
+ for (int p_col=0; p_col < num_point; ++p_col)
681
+ {
682
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
683
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
684
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
685
+
686
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
687
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
688
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
689
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
690
+ *(cache_grad_attn_weight+threadIdx.x)=0;
691
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
692
+ {
693
+ ms_deform_attn_col2im_bilinear(
694
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
695
+ top_grad, weight, grad_value_ptr,
696
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
697
+ }
698
+
699
+ __syncthreads();
700
+
701
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
702
+ {
703
+ if (tid < s) {
704
+ const unsigned int xid1 = tid << 1;
705
+ const unsigned int xid2 = (tid + s) << 1;
706
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
707
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
708
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
709
+ if (tid + (s << 1) < spre)
710
+ {
711
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
712
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
713
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
714
+ }
715
+ }
716
+ __syncthreads();
717
+ }
718
+
719
+ if (tid == 0)
720
+ {
721
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
722
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
723
+ *grad_attn_weight = cache_grad_attn_weight[0];
724
+ }
725
+ __syncthreads();
726
+
727
+ data_weight_ptr += 1;
728
+ data_loc_w_ptr += 2;
729
+ grad_attn_weight += grad_weight_stride;
730
+ grad_sampling_loc += grad_loc_stride;
731
+ }
732
+ }
733
+ }
734
+ }
735
+
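The reduction loop in this kernel (and in the blocksize-aware variants) collapses each thread's partial gradients in shared memory by repeatedly halving the active width `s`; the extra `tid + (s << 1) < spre` branch folds in the leftover element whenever the width being reduced is odd, so no partial sum is dropped. A minimal Python sketch of the same scheme, where `block_tree_reduce` is a hypothetical helper name used only for illustration:

```python
import numpy as np

def block_tree_reduce(vals):
    """Emulate the kernel's shared-memory reduction over blockDim.x partial values."""
    cache = np.asarray(vals, dtype=np.float64).copy()
    s, spre = len(cache) // 2, len(cache)
    while s > 0:
        for tid in range(s):                   # threads with tid < s stay active
            cache[tid] += cache[tid + s]
            if tid + (s << 1) < spre:          # odd leftover from this round
                cache[tid] += cache[tid + (s << 1)]
        s, spre = s >> 1, spre >> 1            # mirrors s >>= 1, spre >>= 1
    return cache[0]                            # thread 0 holds the block-wide sum

for n in (1, 2, 5, 7, 64, 100):
    v = np.random.rand(n)
    assert np.isclose(block_tree_reduce(v), v.sum())
```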
736
+ template <typename scalar_t>
737
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
738
+ const scalar_t *grad_col,
739
+ const scalar_t *data_value,
740
+ const int64_t *data_spatial_shapes,
741
+ const int64_t *data_level_start_index,
742
+ const scalar_t *data_sampling_loc,
743
+ const scalar_t *data_attn_weight,
744
+ const int batch_size,
745
+ const int spatial_size,
746
+ const int num_heads,
747
+ const int channels,
748
+ const int num_levels,
749
+ const int num_query,
750
+ const int num_point,
751
+ scalar_t *grad_value,
752
+ scalar_t *grad_sampling_loc,
753
+ scalar_t *grad_attn_weight)
754
+ {
755
+ CUDA_KERNEL_LOOP(index, n)
756
+ {
757
+ extern __shared__ int _s[];
758
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
759
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
760
+ unsigned int tid = threadIdx.x;
761
+ int _temp = index;
762
+ const int c_col = _temp % channels;
763
+ _temp /= channels;
764
+ const int sampling_index = _temp;
765
+ const int m_col = _temp % num_heads;
766
+ _temp /= num_heads;
767
+ const int q_col = _temp % num_query;
768
+ _temp /= num_query;
769
+ const int b_col = _temp;
770
+
771
+ const scalar_t top_grad = grad_col[index];
772
+
773
+ int data_weight_ptr = sampling_index * num_levels * num_point;
774
+ int data_loc_w_ptr = data_weight_ptr << 1;
775
+ const int grad_sampling_ptr = data_weight_ptr;
776
+ grad_sampling_loc += grad_sampling_ptr << 1;
777
+ grad_attn_weight += grad_sampling_ptr;
778
+ const int grad_weight_stride = 1;
779
+ const int grad_loc_stride = 2;
780
+ const int qid_stride = num_heads * channels;
781
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
782
+
783
+ for (int l_col=0; l_col < num_levels; ++l_col)
784
+ {
785
+ const int level_start_id = data_level_start_index[l_col];
786
+ const int spatial_h_ptr = l_col << 1;
787
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
788
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
789
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
790
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
791
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
792
+
793
+ for (int p_col=0; p_col < num_point; ++p_col)
794
+ {
795
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
796
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
797
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
798
+
799
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
800
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
801
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
802
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
803
+ *(cache_grad_attn_weight+threadIdx.x)=0;
804
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
805
+ {
806
+ ms_deform_attn_col2im_bilinear(
807
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
808
+ top_grad, weight, grad_value_ptr,
809
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
810
+ }
811
+
812
+ __syncthreads();
813
+
814
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
815
+ {
816
+ if (tid < s) {
817
+ const unsigned int xid1 = tid << 1;
818
+ const unsigned int xid2 = (tid + s) << 1;
819
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
820
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
821
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
822
+ if (tid + (s << 1) < spre)
823
+ {
824
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
825
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
826
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
827
+ }
828
+ }
829
+ __syncthreads();
830
+ }
831
+
832
+ if (tid == 0)
833
+ {
834
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
835
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
836
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
837
+ }
838
+ __syncthreads();
839
+
840
+ data_weight_ptr += 1;
841
+ data_loc_w_ptr += 2;
842
+ grad_attn_weight += grad_weight_stride;
843
+ grad_sampling_loc += grad_loc_stride;
844
+ }
845
+ }
846
+ }
847
+ }
848
+
849
+
850
+ template <typename scalar_t>
851
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
852
+ const scalar_t *grad_col,
853
+ const scalar_t *data_value,
854
+ const int64_t *data_spatial_shapes,
855
+ const int64_t *data_level_start_index,
856
+ const scalar_t *data_sampling_loc,
857
+ const scalar_t *data_attn_weight,
858
+ const int batch_size,
859
+ const int spatial_size,
860
+ const int num_heads,
861
+ const int channels,
862
+ const int num_levels,
863
+ const int num_query,
864
+ const int num_point,
865
+ scalar_t *grad_value,
866
+ scalar_t *grad_sampling_loc,
867
+ scalar_t *grad_attn_weight)
868
+ {
869
+ CUDA_KERNEL_LOOP(index, n)
870
+ {
871
+ int _temp = index;
872
+ const int c_col = _temp % channels;
873
+ _temp /= channels;
874
+ const int sampling_index = _temp;
875
+ const int m_col = _temp % num_heads;
876
+ _temp /= num_heads;
877
+ const int q_col = _temp % num_query;
878
+ _temp /= num_query;
879
+ const int b_col = _temp;
880
+
881
+ const scalar_t top_grad = grad_col[index];
882
+
883
+ int data_weight_ptr = sampling_index * num_levels * num_point;
884
+ int data_loc_w_ptr = data_weight_ptr << 1;
885
+ const int grad_sampling_ptr = data_weight_ptr;
886
+ grad_sampling_loc += grad_sampling_ptr << 1;
887
+ grad_attn_weight += grad_sampling_ptr;
888
+ const int grad_weight_stride = 1;
889
+ const int grad_loc_stride = 2;
890
+ const int qid_stride = num_heads * channels;
891
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
892
+
893
+ for (int l_col=0; l_col < num_levels; ++l_col)
894
+ {
895
+ const int level_start_id = data_level_start_index[l_col];
896
+ const int spatial_h_ptr = l_col << 1;
897
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
898
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
899
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
900
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
901
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
902
+
903
+ for (int p_col=0; p_col < num_point; ++p_col)
904
+ {
905
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
906
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
907
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
908
+
909
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
910
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
911
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
912
+ {
913
+ ms_deform_attn_col2im_bilinear_gm(
914
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
915
+ top_grad, weight, grad_value_ptr,
916
+ grad_sampling_loc, grad_attn_weight);
917
+ }
918
+ data_weight_ptr += 1;
919
+ data_loc_w_ptr += 2;
920
+ grad_attn_weight += grad_weight_stride;
921
+ grad_sampling_loc += grad_loc_stride;
922
+ }
923
+ }
924
+ }
925
+ }
926
+
927
+
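All of the col2im kernels above unpack the flat thread index into (batch, query, head, channel) with the same chain of modulo/divide steps, then derive the starting offsets into the sampling-location and attention-weight gradient buffers from it. A small sketch of that bookkeeping; `decode_col2im_index` is a hypothetical name, not part of the extension:

```python
def decode_col2im_index(index, channels, num_heads, num_query, num_levels, num_point):
    """Mirror the kernels' unpacking of one flat index (one thread = one output channel)."""
    c_col = index % channels
    tmp = index // channels
    sampling_index = tmp                       # one (b, q, m) triple per channel group
    m_col = tmp % num_heads
    tmp //= num_heads
    q_col = tmp % num_query
    b_col = tmp // num_query
    data_weight_ptr = sampling_index * num_levels * num_point
    data_loc_w_ptr = data_weight_ptr * 2       # (x, y) pairs, hence the << 1 in the kernel
    return b_col, q_col, m_col, c_col, data_weight_ptr, data_loc_w_ptr

# round-trip check against the layout index = ((b*Q + q)*M + m)*C + c
B, Q, M, C, L, P = 2, 3, 4, 5, 2, 4
for b in range(B):
    for q in range(Q):
        for m in range(M):
            for c in range(C):
                idx = ((b * Q + q) * M + m) * C + c
                assert decode_col2im_index(idx, C, M, Q, L, P)[:4] == (b, q, m, c)
```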
928
+ template <typename scalar_t>
929
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
930
+ const scalar_t* data_value,
931
+ const int64_t* data_spatial_shapes,
932
+ const int64_t* data_level_start_index,
933
+ const scalar_t* data_sampling_loc,
934
+ const scalar_t* data_attn_weight,
935
+ const int batch_size,
936
+ const int spatial_size,
937
+ const int num_heads,
938
+ const int channels,
939
+ const int num_levels,
940
+ const int num_query,
941
+ const int num_point,
942
+ scalar_t* data_col)
943
+ {
944
+ const int num_kernels = batch_size * num_query * num_heads * channels;
945
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
946
+ const int num_threads = CUDA_NUM_THREADS;
947
+ ms_deformable_im2col_gpu_kernel<scalar_t>
948
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
949
+ 0, stream>>>(
950
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
951
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
952
+
953
+ cudaError_t err = cudaGetLastError();
954
+ if (err != cudaSuccess)
955
+ {
956
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
957
+ }
958
+
959
+ }
960
+
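The launch configuration here is a plain ceiling division of the number of kernel instances by the block size. `CUDA_NUM_THREADS` and `GET_BLOCKS` are defined in the extension's CUDA helpers rather than in this file, so the values below are an assumption (1024 threads per block is the usual choice in this family of kernels):

```python
CUDA_NUM_THREADS = 1024  # assumed value; the real constant lives in the .cuh helpers

def get_blocks(n, num_threads=CUDA_NUM_THREADS):
    # ceiling division: enough blocks so that blocks * threads >= n kernel instances
    return (n + num_threads - 1) // num_threads

# e.g. batch 2, 100 queries, 8 heads, 256 channels -> 409600 instances -> 400 blocks
print(get_blocks(2 * 100 * 8 * 256))
```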
961
+ template <typename scalar_t>
962
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
963
+ const scalar_t* grad_col,
964
+ const scalar_t* data_value,
965
+ const int64_t * data_spatial_shapes,
966
+ const int64_t * data_level_start_index,
967
+ const scalar_t * data_sampling_loc,
968
+ const scalar_t * data_attn_weight,
969
+ const int batch_size,
970
+ const int spatial_size,
971
+ const int num_heads,
972
+ const int channels,
973
+ const int num_levels,
974
+ const int num_query,
975
+ const int num_point,
976
+ scalar_t* grad_value,
977
+ scalar_t* grad_sampling_loc,
978
+ scalar_t* grad_attn_weight)
979
+ {
980
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
981
+ const int num_kernels = batch_size * num_query * num_heads * channels;
982
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
983
+ if (channels > 1024)
984
+ {
985
+ if ((channels & 1023) == 0)
986
+ {
987
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
988
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
989
+ num_threads*3*sizeof(scalar_t), stream>>>(
990
+ num_kernels,
991
+ grad_col,
992
+ data_value,
993
+ data_spatial_shapes,
994
+ data_level_start_index,
995
+ data_sampling_loc,
996
+ data_attn_weight,
997
+ batch_size,
998
+ spatial_size,
999
+ num_heads,
1000
+ channels,
1001
+ num_levels,
1002
+ num_query,
1003
+ num_point,
1004
+ grad_value,
1005
+ grad_sampling_loc,
1006
+ grad_attn_weight);
1007
+ }
1008
+ else
1009
+ {
1010
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1011
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1012
+ 0, stream>>>(
1013
+ num_kernels,
1014
+ grad_col,
1015
+ data_value,
1016
+ data_spatial_shapes,
1017
+ data_level_start_index,
1018
+ data_sampling_loc,
1019
+ data_attn_weight,
1020
+ batch_size,
1021
+ spatial_size,
1022
+ num_heads,
1023
+ channels,
1024
+ num_levels,
1025
+ num_query,
1026
+ num_point,
1027
+ grad_value,
1028
+ grad_sampling_loc,
1029
+ grad_attn_weight);
1030
+ }
1031
+ }
1032
+ else{
1033
+ switch(channels)
1034
+ {
1035
+ case 1:
1036
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1037
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1038
+ 0, stream>>>(
1039
+ num_kernels,
1040
+ grad_col,
1041
+ data_value,
1042
+ data_spatial_shapes,
1043
+ data_level_start_index,
1044
+ data_sampling_loc,
1045
+ data_attn_weight,
1046
+ batch_size,
1047
+ spatial_size,
1048
+ num_heads,
1049
+ channels,
1050
+ num_levels,
1051
+ num_query,
1052
+ num_point,
1053
+ grad_value,
1054
+ grad_sampling_loc,
1055
+ grad_attn_weight);
1056
+ break;
1057
+ case 2:
1058
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1059
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1060
+ 0, stream>>>(
1061
+ num_kernels,
1062
+ grad_col,
1063
+ data_value,
1064
+ data_spatial_shapes,
1065
+ data_level_start_index,
1066
+ data_sampling_loc,
1067
+ data_attn_weight,
1068
+ batch_size,
1069
+ spatial_size,
1070
+ num_heads,
1071
+ channels,
1072
+ num_levels,
1073
+ num_query,
1074
+ num_point,
1075
+ grad_value,
1076
+ grad_sampling_loc,
1077
+ grad_attn_weight);
1078
+ break;
1079
+ case 4:
1080
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1081
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1082
+ 0, stream>>>(
1083
+ num_kernels,
1084
+ grad_col,
1085
+ data_value,
1086
+ data_spatial_shapes,
1087
+ data_level_start_index,
1088
+ data_sampling_loc,
1089
+ data_attn_weight,
1090
+ batch_size,
1091
+ spatial_size,
1092
+ num_heads,
1093
+ channels,
1094
+ num_levels,
1095
+ num_query,
1096
+ num_point,
1097
+ grad_value,
1098
+ grad_sampling_loc,
1099
+ grad_attn_weight);
1100
+ break;
1101
+ case 8:
1102
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1103
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1104
+ 0, stream>>>(
1105
+ num_kernels,
1106
+ grad_col,
1107
+ data_value,
1108
+ data_spatial_shapes,
1109
+ data_level_start_index,
1110
+ data_sampling_loc,
1111
+ data_attn_weight,
1112
+ batch_size,
1113
+ spatial_size,
1114
+ num_heads,
1115
+ channels,
1116
+ num_levels,
1117
+ num_query,
1118
+ num_point,
1119
+ grad_value,
1120
+ grad_sampling_loc,
1121
+ grad_attn_weight);
1122
+ break;
1123
+ case 16:
1124
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1125
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1126
+ 0, stream>>>(
1127
+ num_kernels,
1128
+ grad_col,
1129
+ data_value,
1130
+ data_spatial_shapes,
1131
+ data_level_start_index,
1132
+ data_sampling_loc,
1133
+ data_attn_weight,
1134
+ batch_size,
1135
+ spatial_size,
1136
+ num_heads,
1137
+ channels,
1138
+ num_levels,
1139
+ num_query,
1140
+ num_point,
1141
+ grad_value,
1142
+ grad_sampling_loc,
1143
+ grad_attn_weight);
1144
+ break;
1145
+ case 32:
1146
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1147
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1148
+ 0, stream>>>(
1149
+ num_kernels,
1150
+ grad_col,
1151
+ data_value,
1152
+ data_spatial_shapes,
1153
+ data_level_start_index,
1154
+ data_sampling_loc,
1155
+ data_attn_weight,
1156
+ batch_size,
1157
+ spatial_size,
1158
+ num_heads,
1159
+ channels,
1160
+ num_levels,
1161
+ num_query,
1162
+ num_point,
1163
+ grad_value,
1164
+ grad_sampling_loc,
1165
+ grad_attn_weight);
1166
+ break;
1167
+ case 64:
1168
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1169
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1170
+ 0, stream>>>(
1171
+ num_kernels,
1172
+ grad_col,
1173
+ data_value,
1174
+ data_spatial_shapes,
1175
+ data_level_start_index,
1176
+ data_sampling_loc,
1177
+ data_attn_weight,
1178
+ batch_size,
1179
+ spatial_size,
1180
+ num_heads,
1181
+ channels,
1182
+ num_levels,
1183
+ num_query,
1184
+ num_point,
1185
+ grad_value,
1186
+ grad_sampling_loc,
1187
+ grad_attn_weight);
1188
+ break;
1189
+ case 128:
1190
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1191
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1192
+ 0, stream>>>(
1193
+ num_kernels,
1194
+ grad_col,
1195
+ data_value,
1196
+ data_spatial_shapes,
1197
+ data_level_start_index,
1198
+ data_sampling_loc,
1199
+ data_attn_weight,
1200
+ batch_size,
1201
+ spatial_size,
1202
+ num_heads,
1203
+ channels,
1204
+ num_levels,
1205
+ num_query,
1206
+ num_point,
1207
+ grad_value,
1208
+ grad_sampling_loc,
1209
+ grad_attn_weight);
1210
+ break;
1211
+ case 256:
1212
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1213
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1214
+ 0, stream>>>(
1215
+ num_kernels,
1216
+ grad_col,
1217
+ data_value,
1218
+ data_spatial_shapes,
1219
+ data_level_start_index,
1220
+ data_sampling_loc,
1221
+ data_attn_weight,
1222
+ batch_size,
1223
+ spatial_size,
1224
+ num_heads,
1225
+ channels,
1226
+ num_levels,
1227
+ num_query,
1228
+ num_point,
1229
+ grad_value,
1230
+ grad_sampling_loc,
1231
+ grad_attn_weight);
1232
+ break;
1233
+ case 512:
1234
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1235
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1236
+ 0, stream>>>(
1237
+ num_kernels,
1238
+ grad_col,
1239
+ data_value,
1240
+ data_spatial_shapes,
1241
+ data_level_start_index,
1242
+ data_sampling_loc,
1243
+ data_attn_weight,
1244
+ batch_size,
1245
+ spatial_size,
1246
+ num_heads,
1247
+ channels,
1248
+ num_levels,
1249
+ num_query,
1250
+ num_point,
1251
+ grad_value,
1252
+ grad_sampling_loc,
1253
+ grad_attn_weight);
1254
+ break;
1255
+ case 1024:
1256
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1257
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1258
+ 0, stream>>>(
1259
+ num_kernels,
1260
+ grad_col,
1261
+ data_value,
1262
+ data_spatial_shapes,
1263
+ data_level_start_index,
1264
+ data_sampling_loc,
1265
+ data_attn_weight,
1266
+ batch_size,
1267
+ spatial_size,
1268
+ num_heads,
1269
+ channels,
1270
+ num_levels,
1271
+ num_query,
1272
+ num_point,
1273
+ grad_value,
1274
+ grad_sampling_loc,
1275
+ grad_attn_weight);
1276
+ break;
1277
+ default:
1278
+ if (channels < 64)
1279
+ {
1280
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1281
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1282
+ num_threads*3*sizeof(scalar_t), stream>>>(
1283
+ num_kernels,
1284
+ grad_col,
1285
+ data_value,
1286
+ data_spatial_shapes,
1287
+ data_level_start_index,
1288
+ data_sampling_loc,
1289
+ data_attn_weight,
1290
+ batch_size,
1291
+ spatial_size,
1292
+ num_heads,
1293
+ channels,
1294
+ num_levels,
1295
+ num_query,
1296
+ num_point,
1297
+ grad_value,
1298
+ grad_sampling_loc,
1299
+ grad_attn_weight);
1300
+ }
1301
+ else
1302
+ {
1303
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1304
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1305
+ num_threads*3*sizeof(scalar_t), stream>>>(
1306
+ num_kernels,
1307
+ grad_col,
1308
+ data_value,
1309
+ data_spatial_shapes,
1310
+ data_level_start_index,
1311
+ data_sampling_loc,
1312
+ data_attn_weight,
1313
+ batch_size,
1314
+ spatial_size,
1315
+ num_heads,
1316
+ channels,
1317
+ num_levels,
1318
+ num_query,
1319
+ num_point,
1320
+ grad_value,
1321
+ grad_sampling_loc,
1322
+ grad_attn_weight);
1323
+ }
1324
+ }
1325
+ }
1326
+ cudaError_t err = cudaGetLastError();
1327
+ if (err != cudaSuccess)
1328
+ {
1329
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1330
+ }
1331
+
1332
+ }
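For reference, the forward pass these im2col/col2im kernels implement also has a pure-PyTorch formulation built on `F.grid_sample`; the test script later in this commit imports it as `ms_deform_attn_core_pytorch` and checks the CUDA path against it. The sketch below follows that formulation but is meant as an illustration of the math, not a verbatim copy of that function:

```python
import torch
import torch.nn.functional as F

def ms_deform_attn_reference(value, spatial_shapes, sampling_locations, attention_weights):
    # value:              (N, S, M, D)  with S = sum(H_l * W_l)
    # spatial_shapes:     (L, 2)        each row is (H_l, W_l)
    # sampling_locations: (N, Lq, M, L, P, 2) normalized to [0, 1]
    # attention_weights:  (N, Lq, M, L, P)
    N, S, M, D = value.shape
    _, Lq, _, L, P, _ = sampling_locations.shape
    value_list = value.split([int(H * W) for H, W in spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1            # grid_sample expects [-1, 1]
    sampled = []
    for lvl, (H, W) in enumerate(spatial_shapes):
        # (N, H_l*W_l, M, D) -> (N*M, D, H_l, W_l)
        value_l = value_list[lvl].flatten(2).transpose(1, 2).reshape(N * M, D, int(H), int(W))
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
        grid_l = sampling_grids[:, :, :, lvl].transpose(1, 2).flatten(0, 1)
        sampled.append(F.grid_sample(value_l, grid_l, mode="bilinear",
                                     padding_mode="zeros", align_corners=False))
    # weight the (N*M, D, Lq, L, P) samples and sum over levels and points
    attention_weights = attention_weights.transpose(1, 2).reshape(N * M, 1, Lq, L * P)
    output = (torch.stack(sampled, dim=-2).flatten(-2) * attention_weights).sum(-1)
    return output.view(N, M * D, Lq).transpose(1, 2).contiguous()   # (N, Lq, M*D)
```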
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/ms_deform_attn.h ADDED
@@ -0,0 +1,67 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+
18
+ #include "cpu/ms_deform_attn_cpu.h"
19
+
20
+ #ifdef WITH_CUDA
21
+ #include "cuda/ms_deform_attn_cuda.h"
22
+ #endif
23
+
24
+
25
+ at::Tensor
26
+ ms_deform_attn_forward(
27
+ const at::Tensor &value,
28
+ const at::Tensor &spatial_shapes,
29
+ const at::Tensor &level_start_index,
30
+ const at::Tensor &sampling_loc,
31
+ const at::Tensor &attn_weight,
32
+ const int im2col_step)
33
+ {
34
+ if (value.type().is_cuda())
35
+ {
36
+ #ifdef WITH_CUDA
37
+ return ms_deform_attn_cuda_forward(
38
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
39
+ #else
40
+ AT_ERROR("Not compiled with GPU support");
41
+ #endif
42
+ }
43
+ AT_ERROR("Not implemented on the CPU");
44
+ }
45
+
46
+ std::vector<at::Tensor>
47
+ ms_deform_attn_backward(
48
+ const at::Tensor &value,
49
+ const at::Tensor &spatial_shapes,
50
+ const at::Tensor &level_start_index,
51
+ const at::Tensor &sampling_loc,
52
+ const at::Tensor &attn_weight,
53
+ const at::Tensor &grad_output,
54
+ const int im2col_step)
55
+ {
56
+ if (value.type().is_cuda())
57
+ {
58
+ #ifdef WITH_CUDA
59
+ return ms_deform_attn_cuda_backward(
60
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
61
+ #else
62
+ AT_ERROR("Not compiled with GPU support");
63
+ #endif
64
+ }
65
+ AT_ERROR("Not implemented on the CPU");
66
+ }
67
+
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/src/vision.cpp ADDED
@@ -0,0 +1,21 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include "ms_deform_attn.h"
17
+
18
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
19
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
20
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
21
+ }
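The module above only exposes the two C++ entry points; `ms_deform_attn.h` routes them to the CUDA implementations when the `WITH_CUDA` macro is defined at compile time. One way to build the sources ad hoc is `torch.utils.cpp_extension.load`; the paths and defines in this sketch are assumptions based on the directory layout of this commit, not the packaged build script:

```python
import glob
import os
import torch
from torch.utils.cpp_extension import load

src_dir = ("extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/"
           "modeling/pixel_decoder/ops/src")
sources = [os.path.join(src_dir, "vision.cpp"),
           os.path.join(src_dir, "cpu", "ms_deform_attn_cpu.cpp")]
extra_cflags = ["-O2"]
extra_cuda_cflags = []
if torch.cuda.is_available():
    sources += glob.glob(os.path.join(src_dir, "cuda", "*.cu"))
    extra_cflags.append("-DWITH_CUDA")         # gates the CUDA branch in ms_deform_attn.h
    extra_cuda_cflags = ["-DWITH_CUDA"]

MultiScaleDeformableAttention = load(
    name="MultiScaleDeformableAttention",
    sources=sources,
    extra_cflags=extra_cflags,
    extra_cuda_cflags=extra_cuda_cflags,
    verbose=True,
)
# exposes MultiScaleDeformableAttention.ms_deform_attn_forward / ms_deform_attn_backward
```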
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/pixel_decoder/ops/test.py ADDED
@@ -0,0 +1,92 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import time
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.autograd import gradcheck
20
+
21
+ from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
22
+
23
+
24
+ N, M, D = 1, 2, 2
25
+ Lq, L, P = 2, 2, 2
26
+ shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
27
+ level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
28
+ S = sum([(H*W).item() for H, W in shapes])
29
+
30
+
31
+ torch.manual_seed(3)
32
+
33
+
34
+ @torch.no_grad()
35
+ def check_forward_equal_with_pytorch_double():
36
+ value = torch.rand(N, S, M, D).cuda() * 0.01
37
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
38
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
39
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
40
+ im2col_step = 2
41
+ output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
42
+ output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
43
+ fwdok = torch.allclose(output_cuda, output_pytorch)
44
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
45
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
46
+
47
+ print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
48
+
49
+
50
+ @torch.no_grad()
51
+ def check_forward_equal_with_pytorch_float():
52
+ value = torch.rand(N, S, M, D).cuda() * 0.01
53
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
54
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
55
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
56
+ im2col_step = 2
57
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
58
+ output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
59
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
60
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
61
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
62
+
63
+ print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
64
+
65
+
66
+ def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
67
+
68
+ value = torch.rand(N, S, M, channels).cuda() * 0.01
69
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
70
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
71
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
72
+ im2col_step = 2
73
+ func = MSDeformAttnFunction.apply
74
+
75
+ value.requires_grad = grad_value
76
+ sampling_locations.requires_grad = grad_sampling_loc
77
+ attention_weights.requires_grad = grad_attn_weight
78
+
79
+ gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
80
+
81
+ print(f'* {gradok} check_gradient_numerical(D={channels})')
82
+
83
+
84
+ if __name__ == '__main__':
85
+ check_forward_equal_with_pytorch_double()
86
+ check_forward_equal_with_pytorch_float()
87
+
88
+ for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
89
+ check_gradient_numerical(channels, True, True, True)
90
+
91
+
92
+
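The two derived quantities at the top of this script are easy to misread: `level_start_index` is the running offset of each feature level inside the flattened spatial dimension, and `S` is that dimension's total length. Spelled out for the shapes used here:

```python
import torch

shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
lengths = shapes.prod(1)                       # tensor([24, 6]) flattened size per level
level_start_index = torch.cat((lengths.new_zeros(1), lengths.cumsum(0)[:-1]))
S = int(lengths.sum())

print(level_start_index.tolist(), S)           # [0, 24] 30
```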
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/transformer_decoder/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .oneformer_transformer_decoder import ContrastiveMultiScaleMaskedTransformerDecoder
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/modeling/transformer_decoder/oneformer_transformer_decoder.py ADDED
@@ -0,0 +1,528 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/mask2former_transformer_decoder.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import logging
7
+ import fvcore.nn.weight_init as weight_init
8
+ from typing import Optional
9
+ import torch
10
+ from torch import nn, Tensor
11
+ from torch.nn import functional as F
12
+
13
+ from annotator.oneformer.detectron2.config import configurable
14
+ from annotator.oneformer.detectron2.layers import Conv2d
15
+
16
+ from .position_encoding import PositionEmbeddingSine
17
+ from .transformer import Transformer
18
+
19
+ from annotator.oneformer.detectron2.utils.registry import Registry
20
+
21
+
22
+ TRANSFORMER_DECODER_REGISTRY = Registry("TRANSFORMER_MODULE")
23
+ TRANSFORMER_DECODER_REGISTRY.__doc__ = """
24
+ Registry for transformer module in OneFormer.
25
+ """
26
+
27
+
28
+ def build_transformer_decoder(cfg, in_channels, mask_classification=True):
29
+ """
30
+ Build a transformer decoder from `cfg.MODEL.ONE_FORMER.TRANSFORMER_DECODER_NAME`.
31
+ """
32
+ name = cfg.MODEL.ONE_FORMER.TRANSFORMER_DECODER_NAME
33
+ return TRANSFORMER_DECODER_REGISTRY.get(name)(cfg, in_channels, mask_classification)
34
+
35
+
36
+ class SelfAttentionLayer(nn.Module):
37
+
38
+ def __init__(self, d_model, nhead, dropout=0.0,
39
+ activation="relu", normalize_before=False):
40
+ super().__init__()
41
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
42
+
43
+ self.norm = nn.LayerNorm(d_model)
44
+ self.dropout = nn.Dropout(dropout)
45
+
46
+ self.activation = _get_activation_fn(activation)
47
+ self.normalize_before = normalize_before
48
+
49
+ self._reset_parameters()
50
+
51
+ def _reset_parameters(self):
52
+ for p in self.parameters():
53
+ if p.dim() > 1:
54
+ nn.init.xavier_uniform_(p)
55
+
56
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
57
+ return tensor if pos is None else tensor + pos
58
+
59
+ def forward_post(self, tgt,
60
+ tgt_mask: Optional[Tensor] = None,
61
+ tgt_key_padding_mask: Optional[Tensor] = None,
62
+ query_pos: Optional[Tensor] = None):
63
+ q = k = self.with_pos_embed(tgt, query_pos)
64
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
65
+ key_padding_mask=tgt_key_padding_mask)[0]
66
+ tgt = tgt + self.dropout(tgt2)
67
+ tgt = self.norm(tgt)
68
+
69
+ return tgt
70
+
71
+ def forward_pre(self, tgt,
72
+ tgt_mask: Optional[Tensor] = None,
73
+ tgt_key_padding_mask: Optional[Tensor] = None,
74
+ query_pos: Optional[Tensor] = None):
75
+ tgt2 = self.norm(tgt)
76
+ q = k = self.with_pos_embed(tgt2, query_pos)
77
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
78
+ key_padding_mask=tgt_key_padding_mask)[0]
79
+ tgt = tgt + self.dropout(tgt2)
80
+
81
+ return tgt
82
+
83
+ def forward(self, tgt,
84
+ tgt_mask: Optional[Tensor] = None,
85
+ tgt_key_padding_mask: Optional[Tensor] = None,
86
+ query_pos: Optional[Tensor] = None):
87
+ if self.normalize_before:
88
+ return self.forward_pre(tgt, tgt_mask,
89
+ tgt_key_padding_mask, query_pos)
90
+ return self.forward_post(tgt, tgt_mask,
91
+ tgt_key_padding_mask, query_pos)
92
+
93
+
94
+ class CrossAttentionLayer(nn.Module):
95
+
96
+ def __init__(self, d_model, nhead, dropout=0.0,
97
+ activation="relu", normalize_before=False):
98
+ super().__init__()
99
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
100
+
101
+ self.norm = nn.LayerNorm(d_model)
102
+ self.dropout = nn.Dropout(dropout)
103
+
104
+ self.activation = _get_activation_fn(activation)
105
+ self.normalize_before = normalize_before
106
+
107
+ self._reset_parameters()
108
+
109
+ def _reset_parameters(self):
110
+ for p in self.parameters():
111
+ if p.dim() > 1:
112
+ nn.init.xavier_uniform_(p)
113
+
114
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
115
+ return tensor if pos is None else tensor + pos
116
+
117
+ def forward_post(self, tgt, memory,
118
+ memory_mask: Optional[Tensor] = None,
119
+ memory_key_padding_mask: Optional[Tensor] = None,
120
+ pos: Optional[Tensor] = None,
121
+ query_pos: Optional[Tensor] = None):
122
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
123
+ key=self.with_pos_embed(memory, pos),
124
+ value=memory, attn_mask=memory_mask,
125
+ key_padding_mask=memory_key_padding_mask)[0]
126
+ tgt = tgt + self.dropout(tgt2)
127
+ tgt = self.norm(tgt)
128
+
129
+ return tgt
130
+
131
+ def forward_pre(self, tgt, memory,
132
+ memory_mask: Optional[Tensor] = None,
133
+ memory_key_padding_mask: Optional[Tensor] = None,
134
+ pos: Optional[Tensor] = None,
135
+ query_pos: Optional[Tensor] = None):
136
+ tgt2 = self.norm(tgt)
137
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
138
+ key=self.with_pos_embed(memory, pos),
139
+ value=memory, attn_mask=memory_mask,
140
+ key_padding_mask=memory_key_padding_mask)[0]
141
+ tgt = tgt + self.dropout(tgt2)
142
+
143
+ return tgt
144
+
145
+ def forward(self, tgt, memory,
146
+ memory_mask: Optional[Tensor] = None,
147
+ memory_key_padding_mask: Optional[Tensor] = None,
148
+ pos: Optional[Tensor] = None,
149
+ query_pos: Optional[Tensor] = None):
150
+ if self.normalize_before:
151
+ return self.forward_pre(tgt, memory, memory_mask,
152
+ memory_key_padding_mask, pos, query_pos)
153
+ return self.forward_post(tgt, memory, memory_mask,
154
+ memory_key_padding_mask, pos, query_pos)
155
+
156
+
157
+ class FFNLayer(nn.Module):
158
+
159
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
160
+ activation="relu", normalize_before=False):
161
+ super().__init__()
162
+ # Implementation of Feedforward model
163
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
164
+ self.dropout = nn.Dropout(dropout)
165
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
166
+
167
+ self.norm = nn.LayerNorm(d_model)
168
+
169
+ self.activation = _get_activation_fn(activation)
170
+ self.normalize_before = normalize_before
171
+
172
+ self._reset_parameters()
173
+
174
+ def _reset_parameters(self):
175
+ for p in self.parameters():
176
+ if p.dim() > 1:
177
+ nn.init.xavier_uniform_(p)
178
+
179
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
180
+ return tensor if pos is None else tensor + pos
181
+
182
+ def forward_post(self, tgt):
183
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
184
+ tgt = tgt + self.dropout(tgt2)
185
+ tgt = self.norm(tgt)
186
+ return tgt
187
+
188
+ def forward_pre(self, tgt):
189
+ tgt2 = self.norm(tgt)
190
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
191
+ tgt = tgt + self.dropout(tgt2)
192
+ return tgt
193
+
194
+ def forward(self, tgt):
195
+ if self.normalize_before:
196
+ return self.forward_pre(tgt)
197
+ return self.forward_post(tgt)
198
+
199
+
200
+ def _get_activation_fn(activation):
201
+ """Return an activation function given a string"""
202
+ if activation == "relu":
203
+ return F.relu
204
+ if activation == "gelu":
205
+ return F.gelu
206
+ if activation == "glu":
207
+ return F.glu
208
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
209
+
210
+
211
+ class MLP(nn.Module):
212
+ """ Very simple multi-layer perceptron (also called FFN)"""
213
+
214
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
215
+ super().__init__()
216
+ self.num_layers = num_layers
217
+ h = [hidden_dim] * (num_layers - 1)
218
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
219
+
220
+ def forward(self, x):
221
+ for i, layer in enumerate(self.layers):
222
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
223
+ return x
224
+
225
+
226
+ @TRANSFORMER_DECODER_REGISTRY.register()
227
+ class ContrastiveMultiScaleMaskedTransformerDecoder(nn.Module):
228
+
229
+ _version = 2
230
+
231
+ def _load_from_state_dict(
232
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
233
+ ):
234
+ version = local_metadata.get("version", None)
235
+ if version is None or version < 2:
236
+ # Do not warn if training from scratch
237
+ scratch = True
238
+ logger = logging.getLogger(__name__)
239
+ for k in list(state_dict.keys()):
240
+ newk = k
241
+ if "static_query" in k:
242
+ newk = k.replace("static_query", "query_feat")
243
+ if newk != k:
244
+ state_dict[newk] = state_dict[k]
245
+ del state_dict[k]
246
+ scratch = False
247
+
248
+ if not scratch:
249
+ logger.warning(
250
+ f"Weight format of {self.__class__.__name__} have changed! "
251
+ "Please upgrade your models. Applying automatic conversion now ..."
252
+ )
253
+
254
+ @configurable
255
+ def __init__(
256
+ self,
257
+ in_channels,
258
+ mask_classification=True,
259
+ *,
260
+ num_classes: int,
261
+ hidden_dim: int,
262
+ num_queries: int,
263
+ nheads: int,
264
+ dropout: float,
265
+ dim_feedforward: int,
266
+ enc_layers: int,
267
+ is_train: bool,
268
+ dec_layers: int,
269
+ class_dec_layers: int,
270
+ pre_norm: bool,
271
+ mask_dim: int,
272
+ enforce_input_project: bool,
273
+ use_task_norm: bool,
274
+ ):
275
+ """
276
+ NOTE: this interface is experimental.
277
+ Args:
278
+ in_channels: channels of the input features
279
+ mask_classification: whether to add mask classifier or not
280
+ num_classes: number of classes
281
+ hidden_dim: Transformer feature dimension
282
+ num_queries: number of queries
283
+ nheads: number of heads
284
+ dim_feedforward: feature dimension in feedforward network
285
+ enc_layers: number of Transformer encoder layers
286
+ dec_layers: number of Transformer decoder layers
287
+ pre_norm: whether to use pre-LayerNorm or not
288
+ mask_dim: mask feature dimension
289
+ enforce_input_project: add input project 1x1 conv even if input
290
+ channels and hidden dim are identical
291
+ """
292
+ super().__init__()
293
+
294
+ assert mask_classification, "Only support mask classification model"
295
+ self.mask_classification = mask_classification
296
+ self.is_train = is_train
297
+ self.use_task_norm = use_task_norm
298
+
299
+ # positional encoding
300
+ N_steps = hidden_dim // 2
301
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
302
+
303
+ self.class_transformer = Transformer(
304
+ d_model=hidden_dim,
305
+ dropout=dropout,
306
+ nhead=nheads,
307
+ dim_feedforward=dim_feedforward,
308
+ num_encoder_layers=enc_layers,
309
+ num_decoder_layers=class_dec_layers,
310
+ normalize_before=pre_norm,
311
+ return_intermediate_dec=False,
312
+ )
313
+
314
+ # define Transformer decoder here
315
+ self.num_heads = nheads
316
+ self.num_layers = dec_layers
317
+ self.transformer_self_attention_layers = nn.ModuleList()
318
+ self.transformer_cross_attention_layers = nn.ModuleList()
319
+ self.transformer_ffn_layers = nn.ModuleList()
320
+
321
+ for _ in range(self.num_layers):
322
+ self.transformer_self_attention_layers.append(
323
+ SelfAttentionLayer(
324
+ d_model=hidden_dim,
325
+ nhead=nheads,
326
+ dropout=0.0,
327
+ normalize_before=pre_norm,
328
+ )
329
+ )
330
+
331
+ self.transformer_cross_attention_layers.append(
332
+ CrossAttentionLayer(
333
+ d_model=hidden_dim,
334
+ nhead=nheads,
335
+ dropout=0.0,
336
+ normalize_before=pre_norm,
337
+ )
338
+ )
339
+
340
+ self.transformer_ffn_layers.append(
341
+ FFNLayer(
342
+ d_model=hidden_dim,
343
+ dim_feedforward=dim_feedforward,
344
+ dropout=0.0,
345
+ normalize_before=pre_norm,
346
+ )
347
+ )
348
+
349
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
350
+
351
+ self.num_queries = num_queries
352
+ # learnable query p.e.
353
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
354
+
355
+ # level embedding (we always use 3 scales)
356
+ self.num_feature_levels = 3
357
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
358
+ self.input_proj = nn.ModuleList()
359
+ for _ in range(self.num_feature_levels):
360
+ if in_channels != hidden_dim or enforce_input_project:
361
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
362
+ weight_init.c2_xavier_fill(self.input_proj[-1])
363
+ else:
364
+ self.input_proj.append(nn.Sequential())
365
+
366
+ self.class_input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1)
367
+ weight_init.c2_xavier_fill(self.class_input_proj)
368
+
369
+ # output FFNs
370
+ if self.mask_classification:
371
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
372
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
373
+
374
+ @classmethod
375
+ def from_config(cls, cfg, in_channels, mask_classification):
376
+ ret = {}
377
+ ret["in_channels"] = in_channels
378
+ ret["mask_classification"] = mask_classification
379
+
380
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
381
+ ret["hidden_dim"] = cfg.MODEL.ONE_FORMER.HIDDEN_DIM
382
+ ret["num_queries"] = cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES
383
+ # Transformer parameters:
384
+ ret["nheads"] = cfg.MODEL.ONE_FORMER.NHEADS
385
+ ret["dim_feedforward"] = cfg.MODEL.ONE_FORMER.DIM_FEEDFORWARD
386
+
387
+ # NOTE: because we add learnable query features which require supervision,
388
+ # we subtract 1 from the number of decoder layers to stay consistent with our loss
389
+ # implementation: that is, the number of auxiliary losses is always
390
+ # equal to the number of decoder layers. With learnable query features, the number of
391
+ # auxiliary losses equals the number of decoder layers plus 1.
392
+ assert cfg.MODEL.ONE_FORMER.DEC_LAYERS >= 1
393
+ ret["dec_layers"] = cfg.MODEL.ONE_FORMER.DEC_LAYERS - 1
394
+ ret["class_dec_layers"] = cfg.MODEL.ONE_FORMER.CLASS_DEC_LAYERS
395
+ ret["enc_layers"] = cfg.MODEL.ONE_FORMER.ENC_LAYERS
396
+ ret["dropout"] = cfg.MODEL.ONE_FORMER.DROPOUT
397
+ ret["pre_norm"] = cfg.MODEL.ONE_FORMER.PRE_NORM
398
+ ret["enforce_input_project"] = cfg.MODEL.ONE_FORMER.ENFORCE_INPUT_PROJ
399
+ ret["is_train"] = cfg.MODEL.IS_TRAIN
400
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
401
+ ret["use_task_norm"] = cfg.MODEL.ONE_FORMER.USE_TASK_NORM
402
+
403
+ return ret
404
+
405
+ def forward(self, x, mask_features, tasks, mask = None):
406
+ # x is a list of multi-scale feature
407
+ assert len(x) == self.num_feature_levels
408
+ src = []
409
+ pos = []
410
+ size_list = []
411
+
412
+ # disable mask, it does not affect performance
413
+ del mask
414
+
415
+ for i in range(self.num_feature_levels):
416
+ size_list.append(x[i].shape[-2:])
417
+ pos.append(self.pe_layer(x[i], None).flatten(2))
418
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
419
+
420
+ # flatten NxCxHxW to HWxNxC
421
+ pos[-1] = pos[-1].permute(2, 0, 1)
422
+ src[-1] = src[-1].permute(2, 0, 1)
423
+
424
+ _, bs, _ = src[0].shape
425
+
426
+ # QxNxC
427
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
428
+ tasks = tasks.unsqueeze(0)
429
+ if self.use_task_norm:
430
+ tasks = self.decoder_norm(tasks)
431
+
432
+ feats = self.pe_layer(mask_features, None)
433
+
434
+ out_t, _ = self.class_transformer(feats, None,
435
+ self.query_embed.weight[:-1],
436
+ self.class_input_proj(mask_features),
437
+ tasks if self.use_task_norm else None)
438
+ out_t = out_t[0].permute(1, 0, 2)
439
+
440
+ out = torch.cat([out_t, tasks], dim=0)
441
+
442
+ output = out.clone()
443
+
444
+ predictions_class = []
445
+ predictions_mask = []
446
+
447
+ # prediction heads on learnable query features
448
+ outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], i=0)
449
+ predictions_class.append(outputs_class)
450
+ predictions_mask.append(outputs_mask)
451
+
452
+ for i in range(self.num_layers):
453
+ level_index = i % self.num_feature_levels
454
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
455
+ # attention: cross-attention first
456
+ output = self.transformer_cross_attention_layers[i](
457
+ output, src[level_index],
458
+ memory_mask=attn_mask,
459
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
460
+ pos=pos[level_index], query_pos=query_embed
461
+ )
462
+
463
+ output = self.transformer_self_attention_layers[i](
464
+ output, tgt_mask=None,
465
+ tgt_key_padding_mask=None,
466
+ query_pos=query_embed
467
+ )
468
+
469
+ # FFN
470
+ output = self.transformer_ffn_layers[i](
471
+ output
472
+ )
473
+
474
+ outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], i=i+1)
475
+ predictions_class.append(outputs_class)
476
+ predictions_mask.append(outputs_mask)
477
+
478
+ assert len(predictions_class) == self.num_layers + 1
479
+ if self.is_train:
480
+ query_class = out.permute(1, 0, 2)
481
+ else:
482
+ query_class = None
483
+ out = {
484
+ 'contrastive_logits': query_class,
485
+ 'pred_logits': predictions_class[-1],
486
+ 'pred_masks': predictions_mask[-1],
487
+ 'aux_outputs': self._set_aux_loss(
488
+ predictions_class if self.mask_classification else None,
489
+ predictions_mask,
490
+ )
491
+ }
492
+
493
+ return out
494
+
495
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, i):
496
+ decoder_output = self.decoder_norm(output)
497
+ decoder_output = decoder_output.transpose(0, 1)
498
+ outputs_class = self.class_embed(decoder_output)
499
+ mask_embed = self.mask_embed(decoder_output)
500
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
501
+
502
+ # NOTE: prediction is of higher-resolution
503
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
504
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
505
+
506
+ # save_attn_masks(attn_mask.sigmoid() < 0.5, fname=f'demo/maps/{i}_pre_bool')
507
+
508
+ # must use bool type
509
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
510
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
511
+ attn_mask = attn_mask.detach()
512
+
513
+ return outputs_class, outputs_mask, attn_mask
514
+
515
+ @torch.jit.unused
516
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks):
517
+ # this is a workaround to make torchscript happy, as torchscript
518
+ # doesn't support dictionary with non-homogeneous values, such
519
+ # as a dict having both a Tensor and a list.
520
+ if self.mask_classification:
521
+ aux_list = [
522
+ {"pred_logits": a, "pred_masks": b}
523
+ for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
524
+ ]
525
+ else:
526
+ aux_list = [{"pred_masks": b} for b, in outputs_seg_masks[:-1]]
527
+
528
+ return aux_list
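The boolean attention mask produced by `forward_prediction_heads` is what makes the cross-attention in the layer loop "masked": a query is blocked from a key location whenever its current mask prediction there falls below 0.5, and rows that would mask every location are re-enabled before being handed to the attention layer. A standalone sketch of that shaping on dummy tensors (all sizes here are arbitrary):

```python
import torch
import torch.nn.functional as F

B, Q, H, W, num_heads = 2, 5, 16, 16, 8
target_hw = (8, 8)                                  # resolution of the next feature level

outputs_mask = torch.randn(B, Q, H, W)              # raw mask logits, one map per query

attn_mask = F.interpolate(outputs_mask, size=target_hw, mode="bilinear", align_corners=False)

# [B, Q, h, w] -> [B*num_heads, Q, h*w]; True means "do not attend"
attn_mask = (attn_mask.sigmoid().flatten(2)         # [B, Q, h*w]
             .unsqueeze(1).repeat(1, num_heads, 1, 1)
             .flatten(0, 1) < 0.5).detach()

# rows that mask every key would give NaNs in softmax, so let them attend everywhere
attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False

print(attn_mask.shape)                              # torch.Size([16, 5, 64])
```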