509291f2031bcb75c1e4d3756a977dee17a9881d5c89f6cfa26bfccdbea5b413
Browse files
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/data/transforms/augmentation.py +380 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/data/transforms/augmentation_impl.py +736 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/data/transforms/transform.py +351 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/engine/__init__.py +12 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/engine/defaults.py +715 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/engine/hooks.py +690 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/engine/launch.py +123 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/engine/train_loop.py +469 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/__init__.py +12 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/cityscapes_evaluation.py +197 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/coco_evaluation.py +722 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/evaluator.py +224 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/fast_eval_api.py +121 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/lvis_evaluation.py +380 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/panoptic_evaluation.py +199 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/pascal_voc_evaluation.py +300 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/rotated_coco_evaluation.py +207 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/sem_seg_evaluation.py +265 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/testing.py +85 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/README.md +15 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/__init__.py +30 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/api.py +230 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/c10.py +557 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/caffe2_export.py +203 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/caffe2_inference.py +161 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/caffe2_modeling.py +419 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/caffe2_patch.py +152 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/flatten.py +330 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/shared.py +1039 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/torchscript.py +132 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/torchscript_patch.py +406 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/__init__.py +26 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/aspp.py +144 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/batch_norm.py +300 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/blocks.py +111 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/README.md +7 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h +115 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp +522 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu +443 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h +35 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp +39 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu +130 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h +370 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/cocoeval/cocoeval.cpp +507 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/cocoeval/cocoeval.h +88 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/cuda_version.cu +26 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/deformable/deform_conv.h +377 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/deformable/deform_conv_cuda.cu +1223 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu +1288 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/nms_rotated/nms_rotated.h +39 -0
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/data/transforms/augmentation.py
ADDED
@@ -0,0 +1,380 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import inspect
import numpy as np
import pprint
from typing import Any, List, Optional, Tuple, Union
from fvcore.transforms.transform import Transform, TransformList

"""
See "Data Augmentation" tutorial for an overview of the system:
https://detectron2.readthedocs.io/tutorials/augmentation.html
"""


__all__ = [
    "Augmentation",
    "AugmentationList",
    "AugInput",
    "TransformGen",
    "apply_transform_gens",
    "StandardAugInput",
    "apply_augmentations",
]


def _check_img_dtype(img):
    assert isinstance(img, np.ndarray), "[Augmentation] Needs an numpy array, but got a {}!".format(
        type(img)
    )
    assert not isinstance(img.dtype, np.integer) or (
        img.dtype == np.uint8
    ), "[Augmentation] Got image of type {}, use uint8 or floating points instead!".format(
        img.dtype
    )
    assert img.ndim in [2, 3], img.ndim


def _get_aug_input_args(aug, aug_input) -> List[Any]:
    """
    Get the arguments to be passed to ``aug.get_transform`` from the input ``aug_input``.
    """
    if aug.input_args is None:
        # Decide what attributes are needed automatically
        prms = list(inspect.signature(aug.get_transform).parameters.items())
        # The default behavior is: if there is one parameter, then its "image"
        # (work automatically for majority of use cases, and also avoid BC breaking),
        # Otherwise, use the argument names.
        if len(prms) == 1:
            names = ("image",)
        else:
            names = []
            for name, prm in prms:
                if prm.kind in (
                    inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD,
                ):
                    raise TypeError(
                        f""" \
The default implementation of `{type(aug)}.__call__` does not allow \
`{type(aug)}.get_transform` to use variable-length arguments (*args, **kwargs)! \
If arguments are unknown, reimplement `__call__` instead. \
"""
                    )
                names.append(name)
        aug.input_args = tuple(names)

    args = []
    for f in aug.input_args:
        try:
            args.append(getattr(aug_input, f))
        except AttributeError as e:
            raise AttributeError(
                f"{type(aug)}.get_transform needs input attribute '{f}', "
                f"but it is not an attribute of {type(aug_input)}!"
            ) from e
    return args


class Augmentation:
    """
    Augmentation defines (often random) policies/strategies to generate :class:`Transform`
    from data. It is often used for pre-processing of input data.

    A "policy" that generates a :class:`Transform` may, in the most general case,
    need arbitrary information from input data in order to determine what transforms
    to apply. Therefore, each :class:`Augmentation` instance defines the arguments
    needed by its :meth:`get_transform` method. When called with the positional arguments,
    the :meth:`get_transform` method executes the policy.

    Note that :class:`Augmentation` defines the policies to create a :class:`Transform`,
    but not how to execute the actual transform operations to those data.
    Its :meth:`__call__` method will use :meth:`AugInput.transform` to execute the transform.

    The returned `Transform` object is meant to describe deterministic transformation, which means
    it can be re-applied on associated data, e.g. the geometry of an image and its segmentation
    masks need to be transformed together.
    (If such re-application is not needed, then determinism is not a crucial requirement.)
    """

    input_args: Optional[Tuple[str]] = None
    """
    Stores the attribute names needed by :meth:`get_transform`, e.g. ``("image", "sem_seg")``.
    By default, it is just a tuple of argument names in :meth:`self.get_transform`, which often only
    contain "image". As long as the argument name convention is followed, there is no need for
    users to touch this attribute.
    """

    def _init(self, params=None):
        if params:
            for k, v in params.items():
                if k != "self" and not k.startswith("_"):
                    setattr(self, k, v)

    def get_transform(self, *args) -> Transform:
        """
        Execute the policy based on input data, and decide what transform to apply to inputs.

        Args:
            args: Any fixed-length positional arguments. By default, the name of the arguments
                should exist in the :class:`AugInput` to be used.

        Returns:
            Transform: Returns the deterministic transform to apply to the input.

        Examples:
        ::
            class MyAug:
                # if a policy needs to know both image and semantic segmentation
                def get_transform(image, sem_seg) -> T.Transform:
                    pass
            tfm: Transform = MyAug().get_transform(image, sem_seg)
            new_image = tfm.apply_image(image)

        Notes:
            Users can freely use arbitrary new argument names in custom
            :meth:`get_transform` method, as long as they are available in the
            input data. In detectron2 we use the following convention:

            * image: (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
              floating point in range [0, 1] or [0, 255].
            * boxes: (N,4) ndarray of float32. It represents the instance bounding boxes
              of N instances. Each is in XYXY format in unit of absolute coordinates.
            * sem_seg: (H,W) ndarray of type uint8. Each element is an integer label of pixel.

            We do not specify convention for other types and do not include builtin
            :class:`Augmentation` that uses other types in detectron2.
        """
        raise NotImplementedError

    def __call__(self, aug_input) -> Transform:
        """
        Augment the given `aug_input` **in-place**, and return the transform that's used.

        This method will be called to apply the augmentation. In most augmentation, it
        is enough to use the default implementation, which calls :meth:`get_transform`
        using the inputs. But a subclass can overwrite it to have more complicated logic.

        Args:
            aug_input (AugInput): an object that has attributes needed by this augmentation
                (defined by ``self.get_transform``). Its ``transform`` method will be called
                to in-place transform it.

        Returns:
            Transform: the transform that is applied on the input.
        """
        args = _get_aug_input_args(self, aug_input)
        tfm = self.get_transform(*args)
        assert isinstance(tfm, (Transform, TransformList)), (
            f"{type(self)}.get_transform must return an instance of Transform! "
            f"Got {type(tfm)} instead."
        )
        aug_input.transform(tfm)
        return tfm

    def _rand_range(self, low=1.0, high=None, size=None):
        """
        Uniform float random number between low and high.
        """
        if high is None:
            low, high = 0, low
        if size is None:
            size = []
        return np.random.uniform(low, high, size)

    def __repr__(self):
        """
        Produce something like:
        "MyAugmentation(field1={self.field1}, field2={self.field2})"
        """
        try:
            sig = inspect.signature(self.__init__)
            classname = type(self).__name__
            argstr = []
            for name, param in sig.parameters.items():
                assert (
                    param.kind != param.VAR_POSITIONAL and param.kind != param.VAR_KEYWORD
                ), "The default __repr__ doesn't support *args or **kwargs"
                assert hasattr(self, name), (
                    "Attribute {} not found! "
                    "Default __repr__ only works if attributes match the constructor.".format(name)
                )
                attr = getattr(self, name)
                default = param.default
                if default is attr:
                    continue
                attr_str = pprint.pformat(attr)
                if "\n" in attr_str:
                    # don't show it if pformat decides to use >1 lines
                    attr_str = "..."
                argstr.append("{}={}".format(name, attr_str))
            return "{}({})".format(classname, ", ".join(argstr))
        except AssertionError:
            return super().__repr__()

    __str__ = __repr__


class _TransformToAug(Augmentation):
    def __init__(self, tfm: Transform):
        self.tfm = tfm

    def get_transform(self, *args):
        return self.tfm

    def __repr__(self):
        return repr(self.tfm)

    __str__ = __repr__


def _transform_to_aug(tfm_or_aug):
    """
    Wrap Transform into Augmentation.
    Private, used internally to implement augmentations.
    """
    assert isinstance(tfm_or_aug, (Transform, Augmentation)), tfm_or_aug
    if isinstance(tfm_or_aug, Augmentation):
        return tfm_or_aug
    else:
        return _TransformToAug(tfm_or_aug)


class AugmentationList(Augmentation):
    """
    Apply a sequence of augmentations.

    It has ``__call__`` method to apply the augmentations.

    Note that :meth:`get_transform` method is impossible (will throw error if called)
    for :class:`AugmentationList`, because in order to apply a sequence of augmentations,
    the kth augmentation must be applied first, to provide inputs needed by the (k+1)th
    augmentation.
    """

    def __init__(self, augs):
        """
        Args:
            augs (list[Augmentation or Transform]):
        """
        super().__init__()
        self.augs = [_transform_to_aug(x) for x in augs]

    def __call__(self, aug_input) -> TransformList:
        tfms = []
        for x in self.augs:
            tfm = x(aug_input)
            tfms.append(tfm)
        return TransformList(tfms)

    def __repr__(self):
        msgs = [str(x) for x in self.augs]
        return "AugmentationList[{}]".format(", ".join(msgs))

    __str__ = __repr__


class AugInput:
    """
    Input that can be used with :meth:`Augmentation.__call__`.
    This is a standard implementation for the majority of use cases.
    This class provides the standard attributes **"image", "boxes", "sem_seg"**
    defined in :meth:`__init__` and they may be needed by different augmentations.
    Most augmentation policies do not need attributes beyond these three.

    After applying augmentations to these attributes (using :meth:`AugInput.transform`),
    the returned transforms can then be used to transform other data structures that users have.

    Examples:
    ::
        input = AugInput(image, boxes=boxes)
        tfms = augmentation(input)
        transformed_image = input.image
        transformed_boxes = input.boxes
        transformed_other_data = tfms.apply_other(other_data)

    An extended project that works with new data types may implement augmentation policies
    that need other inputs. An algorithm may need to transform inputs in a way different
    from the standard approach defined in this class. In those rare situations, users can
    implement a class similar to this class, that satify the following condition:

    * The input must provide access to these data in the form of attribute access
      (``getattr``). For example, if an :class:`Augmentation` to be applied needs "image"
      and "sem_seg" arguments, its input must have the attribute "image" and "sem_seg".
    * The input must have a ``transform(tfm: Transform) -> None`` method which
      in-place transforms all its attributes.
    """

    # TODO maybe should support more builtin data types here
    def __init__(
        self,
        image: np.ndarray,
        *,
        boxes: Optional[np.ndarray] = None,
        sem_seg: Optional[np.ndarray] = None,
    ):
        """
        Args:
            image (ndarray): (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
                floating point in range [0, 1] or [0, 255]. The meaning of C is up
                to users.
            boxes (ndarray or None): Nx4 float32 boxes in XYXY_ABS mode
            sem_seg (ndarray or None): HxW uint8 semantic segmentation mask. Each element
                is an integer label of pixel.
        """
        _check_img_dtype(image)
        self.image = image
        self.boxes = boxes
        self.sem_seg = sem_seg

    def transform(self, tfm: Transform) -> None:
        """
        In-place transform all attributes of this class.

        By "in-place", it means after calling this method, accessing an attribute such
        as ``self.image`` will return transformed data.
        """
        self.image = tfm.apply_image(self.image)
        if self.boxes is not None:
            self.boxes = tfm.apply_box(self.boxes)
        if self.sem_seg is not None:
            self.sem_seg = tfm.apply_segmentation(self.sem_seg)

    def apply_augmentations(
        self, augmentations: List[Union[Augmentation, Transform]]
    ) -> TransformList:
        """
        Equivalent of ``AugmentationList(augmentations)(self)``
        """
        return AugmentationList(augmentations)(self)


def apply_augmentations(augmentations: List[Union[Transform, Augmentation]], inputs):
    """
    Use ``T.AugmentationList(augmentations)(inputs)`` instead.
    """
    if isinstance(inputs, np.ndarray):
        # handle the common case of image-only Augmentation, also for backward compatibility
        image_only = True
        inputs = AugInput(inputs)
    else:
        image_only = False
    tfms = inputs.apply_augmentations(augmentations)
    return inputs.image if image_only else inputs, tfms


apply_transform_gens = apply_augmentations
"""
Alias for backward-compatibility.
"""

TransformGen = Augmentation
"""
Alias for Augmentation, since it is something that generates :class:`Transform`s
"""

StandardAugInput = AugInput
"""
Alias for compatibility. It's not worth the complexity to have two classes.
"""
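
A minimal usage sketch of the API above (not part of the commit): a toy policy applied through `AugInput` and `AugmentationList`. The `AlwaysHFlip` class and the synthetic image and box values are illustrative only; `HFlipTransform` comes from fvcore.

import numpy as np
from fvcore.transforms.transform import HFlipTransform

class AlwaysHFlip(Augmentation):
    """Toy policy for illustration: always flip horizontally."""

    def get_transform(self, image):
        # Only "image" is requested, so the default __call__ passes aug_input.image here.
        return HFlipTransform(image.shape[1])

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)   # synthetic data
boxes = np.array([[10.0, 20.0, 200.0, 240.0]], dtype=np.float32)

aug_input = AugInput(image, boxes=boxes)
tfms = AugmentationList([AlwaysHFlip()])(aug_input)   # transforms aug_input in place
flipped_image, flipped_boxes = aug_input.image, aug_input.boxes
# The returned TransformList can be re-applied to other data tied to the same image:
extra_points = tfms.apply_coords(np.array([[0.0, 0.0]], dtype=np.float32))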
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/data/transforms/augmentation_impl.py
ADDED
@@ -0,0 +1,736 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
"""
Implement many useful :class:`Augmentation`.
"""
import numpy as np
import sys
from numpy import random
from typing import Tuple
import torch
from fvcore.transforms.transform import (
    BlendTransform,
    CropTransform,
    HFlipTransform,
    NoOpTransform,
    PadTransform,
    Transform,
    TransformList,
    VFlipTransform,
)
from PIL import Image

from annotator.oneformer.detectron2.structures import Boxes, pairwise_iou

from .augmentation import Augmentation, _transform_to_aug
from .transform import ExtentTransform, ResizeTransform, RotationTransform

__all__ = [
    "FixedSizeCrop",
    "RandomApply",
    "RandomBrightness",
    "RandomContrast",
    "RandomCrop",
    "RandomExtent",
    "RandomFlip",
    "RandomSaturation",
    "RandomLighting",
    "RandomRotation",
    "Resize",
    "ResizeScale",
    "ResizeShortestEdge",
    "RandomCrop_CategoryAreaConstraint",
    "RandomResize",
    "MinIoURandomCrop",
]


class RandomApply(Augmentation):
    """
    Randomly apply an augmentation with a given probability.
    """

    def __init__(self, tfm_or_aug, prob=0.5):
        """
        Args:
            tfm_or_aug (Transform, Augmentation): the transform or augmentation
                to be applied. It can either be a `Transform` or `Augmentation`
                instance.
            prob (float): probability between 0.0 and 1.0 that
                the wrapper transformation is applied
        """
        super().__init__()
        self.aug = _transform_to_aug(tfm_or_aug)
        assert 0.0 <= prob <= 1.0, f"Probablity must be between 0.0 and 1.0 (given: {prob})"
        self.prob = prob

    def get_transform(self, *args):
        do = self._rand_range() < self.prob
        if do:
            return self.aug.get_transform(*args)
        else:
            return NoOpTransform()

    def __call__(self, aug_input):
        do = self._rand_range() < self.prob
        if do:
            return self.aug(aug_input)
        else:
            return NoOpTransform()


class RandomFlip(Augmentation):
    """
    Flip the image horizontally or vertically with the given probability.
    """

    def __init__(self, prob=0.5, *, horizontal=True, vertical=False):
        """
        Args:
            prob (float): probability of flip.
            horizontal (boolean): whether to apply horizontal flipping
            vertical (boolean): whether to apply vertical flipping
        """
        super().__init__()

        if horizontal and vertical:
            raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
        if not horizontal and not vertical:
            raise ValueError("At least one of horiz or vert has to be True!")
        self._init(locals())

    def get_transform(self, image):
        h, w = image.shape[:2]
        do = self._rand_range() < self.prob
        if do:
            if self.horizontal:
                return HFlipTransform(w)
            elif self.vertical:
                return VFlipTransform(h)
        else:
            return NoOpTransform()


class Resize(Augmentation):
    """Resize image to a fixed target size"""

    def __init__(self, shape, interp=Image.BILINEAR):
        """
        Args:
            shape: (h, w) tuple or a int
            interp: PIL interpolation method
        """
        if isinstance(shape, int):
            shape = (shape, shape)
        shape = tuple(shape)
        self._init(locals())

    def get_transform(self, image):
        return ResizeTransform(
            image.shape[0], image.shape[1], self.shape[0], self.shape[1], self.interp
        )


class ResizeShortestEdge(Augmentation):
    """
    Resize the image while keeping the aspect ratio unchanged.
    It attempts to scale the shorter edge to the given `short_edge_length`,
    as long as the longer edge does not exceed `max_size`.
    If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
    """

    @torch.jit.unused
    def __init__(
        self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
    ):
        """
        Args:
            short_edge_length (list[int]): If ``sample_style=="range"``,
                a [min, max] interval from which to sample the shortest edge length.
                If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
            max_size (int): maximum allowed longest edge length.
            sample_style (str): either "range" or "choice".
        """
        super().__init__()
        assert sample_style in ["range", "choice"], sample_style

        self.is_range = sample_style == "range"
        if isinstance(short_edge_length, int):
            short_edge_length = (short_edge_length, short_edge_length)
        if self.is_range:
            assert len(short_edge_length) == 2, (
                "short_edge_length must be two values using 'range' sample style."
                f" Got {short_edge_length}!"
            )
        self._init(locals())

    @torch.jit.unused
    def get_transform(self, image):
        h, w = image.shape[:2]
        if self.is_range:
            size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
        else:
            size = np.random.choice(self.short_edge_length)
        if size == 0:
            return NoOpTransform()

        newh, neww = ResizeShortestEdge.get_output_shape(h, w, size, self.max_size)
        return ResizeTransform(h, w, newh, neww, self.interp)

    @staticmethod
    def get_output_shape(
        oldh: int, oldw: int, short_edge_length: int, max_size: int
    ) -> Tuple[int, int]:
        """
        Compute the output size given input size and target short edge length.
        """
        h, w = oldh, oldw
        size = short_edge_length * 1.0
        scale = size / min(h, w)
        if h < w:
            newh, neww = size, scale * w
        else:
            newh, neww = scale * h, size
        if max(newh, neww) > max_size:
            scale = max_size * 1.0 / max(newh, neww)
            newh = newh * scale
            neww = neww * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)


class ResizeScale(Augmentation):
    """
    Takes target size as input and randomly scales the given target size between `min_scale`
    and `max_scale`. It then scales the input image such that it fits inside the scaled target
    box, keeping the aspect ratio constant.
    This implements the resize part of the Google's 'resize_and_crop' data augmentation:
    https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/input_utils.py#L127
    """

    def __init__(
        self,
        min_scale: float,
        max_scale: float,
        target_height: int,
        target_width: int,
        interp: int = Image.BILINEAR,
    ):
        """
        Args:
            min_scale: minimum image scale range.
            max_scale: maximum image scale range.
            target_height: target image height.
            target_width: target image width.
            interp: image interpolation method.
        """
        super().__init__()
        self._init(locals())

    def _get_resize(self, image: np.ndarray, scale: float) -> Transform:
        input_size = image.shape[:2]

        # Compute new target size given a scale.
        target_size = (self.target_height, self.target_width)
        target_scale_size = np.multiply(target_size, scale)

        # Compute actual rescaling applied to input image and output size.
        output_scale = np.minimum(
            target_scale_size[0] / input_size[0], target_scale_size[1] / input_size[1]
        )
        output_size = np.round(np.multiply(input_size, output_scale)).astype(int)

        return ResizeTransform(
            input_size[0], input_size[1], output_size[0], output_size[1], self.interp
        )

    def get_transform(self, image: np.ndarray) -> Transform:
        random_scale = np.random.uniform(self.min_scale, self.max_scale)
        return self._get_resize(image, random_scale)


class RandomRotation(Augmentation):
    """
    This method returns a copy of this image, rotated the given
    number of degrees counter clockwise around the given center.
    """

    def __init__(self, angle, expand=True, center=None, sample_style="range", interp=None):
        """
        Args:
            angle (list[float]): If ``sample_style=="range"``,
                a [min, max] interval from which to sample the angle (in degrees).
                If ``sample_style=="choice"``, a list of angles to sample from
            expand (bool): choose if the image should be resized to fit the whole
                rotated image (default), or simply cropped
            center (list[[float, float]]): If ``sample_style=="range"``,
                a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center,
                [0, 0] being the top left of the image and [1, 1] the bottom right.
                If ``sample_style=="choice"``, a list of centers to sample from
                Default: None, which means that the center of rotation is the center of the image
                center has no effect if expand=True because it only affects shifting
        """
        super().__init__()
        assert sample_style in ["range", "choice"], sample_style
        self.is_range = sample_style == "range"
        if isinstance(angle, (float, int)):
            angle = (angle, angle)
        if center is not None and isinstance(center[0], (float, int)):
            center = (center, center)
        self._init(locals())

    def get_transform(self, image):
        h, w = image.shape[:2]
        center = None
        if self.is_range:
            angle = np.random.uniform(self.angle[0], self.angle[1])
            if self.center is not None:
                center = (
                    np.random.uniform(self.center[0][0], self.center[1][0]),
                    np.random.uniform(self.center[0][1], self.center[1][1]),
                )
        else:
            angle = np.random.choice(self.angle)
            if self.center is not None:
                center = np.random.choice(self.center)

        if center is not None:
            center = (w * center[0], h * center[1])  # Convert to absolute coordinates

        if angle % 360 == 0:
            return NoOpTransform()

        return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)


class FixedSizeCrop(Augmentation):
    """
    If `crop_size` is smaller than the input image size, then it uses a random crop of
    the crop size. If `crop_size` is larger than the input image size, then it pads
    the right and the bottom of the image to the crop size if `pad` is True, otherwise
    it returns the smaller image.
    """

    def __init__(
        self,
        crop_size: Tuple[int],
        pad: bool = True,
        pad_value: float = 128.0,
        seg_pad_value: int = 255,
    ):
        """
        Args:
            crop_size: target image (height, width).
            pad: if True, will pad images smaller than `crop_size` up to `crop_size`
            pad_value: the padding value to the image.
            seg_pad_value: the padding value to the segmentation mask.
        """
        super().__init__()
        self._init(locals())

    def _get_crop(self, image: np.ndarray) -> Transform:
        # Compute the image scale and scaled size.
        input_size = image.shape[:2]
        output_size = self.crop_size

        # Add random crop if the image is scaled up.
        max_offset = np.subtract(input_size, output_size)
        max_offset = np.maximum(max_offset, 0)
        offset = np.multiply(max_offset, np.random.uniform(0.0, 1.0))
        offset = np.round(offset).astype(int)
        return CropTransform(
            offset[1], offset[0], output_size[1], output_size[0], input_size[1], input_size[0]
        )

    def _get_pad(self, image: np.ndarray) -> Transform:
        # Compute the image scale and scaled size.
        input_size = image.shape[:2]
        output_size = self.crop_size

        # Add padding if the image is scaled down.
        pad_size = np.subtract(output_size, input_size)
        pad_size = np.maximum(pad_size, 0)
        original_size = np.minimum(input_size, output_size)
        return PadTransform(
            0,
            0,
            pad_size[1],
            pad_size[0],
            original_size[1],
            original_size[0],
            self.pad_value,
            self.seg_pad_value,
        )

    def get_transform(self, image: np.ndarray) -> TransformList:
        transforms = [self._get_crop(image)]
        if self.pad:
            transforms.append(self._get_pad(image))
        return TransformList(transforms)


class RandomCrop(Augmentation):
    """
    Randomly crop a rectangle region out of an image.
    """

    def __init__(self, crop_type: str, crop_size):
        """
        Args:
            crop_type (str): one of "relative_range", "relative", "absolute", "absolute_range".
            crop_size (tuple[float, float]): two floats, explained below.

        - "relative": crop a (H * crop_size[0], W * crop_size[1]) region from an input image of
          size (H, W). crop size should be in (0, 1]
        - "relative_range": uniformly sample two values from [crop_size[0], 1]
          and [crop_size[1]], 1], and use them as in "relative" crop type.
        - "absolute" crop a (crop_size[0], crop_size[1]) region from input image.
          crop_size must be smaller than the input image size.
        - "absolute_range", for an input of size (H, W), uniformly sample H_crop in
          [crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])].
          Then crop a region (H_crop, W_crop).
        """
        # TODO style of relative_range and absolute_range are not consistent:
        # one takes (h, w) but another takes (min, max)
        super().__init__()
        assert crop_type in ["relative_range", "relative", "absolute", "absolute_range"]
        self._init(locals())

    def get_transform(self, image):
        h, w = image.shape[:2]
        croph, cropw = self.get_crop_size((h, w))
        assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self)
        h0 = np.random.randint(h - croph + 1)
        w0 = np.random.randint(w - cropw + 1)
        return CropTransform(w0, h0, cropw, croph)

    def get_crop_size(self, image_size):
        """
        Args:
            image_size (tuple): height, width

        Returns:
            crop_size (tuple): height, width in absolute pixels
        """
        h, w = image_size
        if self.crop_type == "relative":
            ch, cw = self.crop_size
            return int(h * ch + 0.5), int(w * cw + 0.5)
        elif self.crop_type == "relative_range":
            crop_size = np.asarray(self.crop_size, dtype=np.float32)
            ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
            return int(h * ch + 0.5), int(w * cw + 0.5)
        elif self.crop_type == "absolute":
            return (min(self.crop_size[0], h), min(self.crop_size[1], w))
        elif self.crop_type == "absolute_range":
            assert self.crop_size[0] <= self.crop_size[1]
            ch = np.random.randint(min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1)
            cw = np.random.randint(min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1)
            return ch, cw
        else:
            raise NotImplementedError("Unknown crop type {}".format(self.crop_type))


class RandomCrop_CategoryAreaConstraint(Augmentation):
    """
    Similar to :class:`RandomCrop`, but find a cropping window such that no single category
    occupies a ratio of more than `single_category_max_area` in semantic segmentation ground
    truth, which can cause unstability in training. The function attempts to find such a valid
    cropping window for at most 10 times.
    """

    def __init__(
        self,
        crop_type: str,
        crop_size,
        single_category_max_area: float = 1.0,
        ignored_category: int = None,
    ):
        """
        Args:
            crop_type, crop_size: same as in :class:`RandomCrop`
            single_category_max_area: the maximum allowed area ratio of a
                category. Set to 1.0 to disable
            ignored_category: allow this category in the semantic segmentation
                ground truth to exceed the area ratio. Usually set to the category
                that's ignored in training.
        """
        self.crop_aug = RandomCrop(crop_type, crop_size)
        self._init(locals())

    def get_transform(self, image, sem_seg):
        if self.single_category_max_area >= 1.0:
            return self.crop_aug.get_transform(image)
        else:
            h, w = sem_seg.shape
            for _ in range(10):
                crop_size = self.crop_aug.get_crop_size((h, w))
                y0 = np.random.randint(h - crop_size[0] + 1)
                x0 = np.random.randint(w - crop_size[1] + 1)
                sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
                labels, cnt = np.unique(sem_seg_temp, return_counts=True)
                if self.ignored_category is not None:
                    cnt = cnt[labels != self.ignored_category]
                if len(cnt) > 1 and np.max(cnt) < np.sum(cnt) * self.single_category_max_area:
                    break
            crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0])
            return crop_tfm


class RandomExtent(Augmentation):
    """
    Outputs an image by cropping a random "subrect" of the source image.

    The subrect can be parameterized to include pixels outside the source image,
    in which case they will be set to zeros (i.e. black). The size of the output
    image will vary with the size of the random subrect.
    """

    def __init__(self, scale_range, shift_range):
        """
        Args:
            output_size (h, w): Dimensions of output image
            scale_range (l, h): Range of input-to-output size scaling factor
            shift_range (x, y): Range of shifts of the cropped subrect. The rect
                is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)],
                where (w, h) is the (width, height) of the input image. Set each
                component to zero to crop at the image's center.
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, image):
        img_h, img_w = image.shape[:2]

        # Initialize src_rect to fit the input image.
        src_rect = np.array([-0.5 * img_w, -0.5 * img_h, 0.5 * img_w, 0.5 * img_h])

        # Apply a random scaling to the src_rect.
        src_rect *= np.random.uniform(self.scale_range[0], self.scale_range[1])

        # Apply a random shift to the coordinates origin.
        src_rect[0::2] += self.shift_range[0] * img_w * (np.random.rand() - 0.5)
        src_rect[1::2] += self.shift_range[1] * img_h * (np.random.rand() - 0.5)

        # Map src_rect coordinates into image coordinates (center at corner).
        src_rect[0::2] += 0.5 * img_w
        src_rect[1::2] += 0.5 * img_h

        return ExtentTransform(
            src_rect=(src_rect[0], src_rect[1], src_rect[2], src_rect[3]),
            output_size=(int(src_rect[3] - src_rect[1]), int(src_rect[2] - src_rect[0])),
        )


class RandomContrast(Augmentation):
    """
    Randomly transforms image contrast.

    Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
    - intensity < 1 will reduce contrast
    - intensity = 1 will preserve the input image
    - intensity > 1 will increase contrast

    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    """

    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation
            intensity_max (float): Maximum augmentation
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, image):
        w = np.random.uniform(self.intensity_min, self.intensity_max)
        return BlendTransform(src_image=image.mean(), src_weight=1 - w, dst_weight=w)


class RandomBrightness(Augmentation):
    """
    Randomly transforms image brightness.

    Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
    - intensity < 1 will reduce brightness
    - intensity = 1 will preserve the input image
    - intensity > 1 will increase brightness

    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    """

    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation
            intensity_max (float): Maximum augmentation
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, image):
        w = np.random.uniform(self.intensity_min, self.intensity_max)
        return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w)


class RandomSaturation(Augmentation):
    """
    Randomly transforms saturation of an RGB image.
    Input images are assumed to have 'RGB' channel order.

    Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
    - intensity < 1 will reduce saturation (make the image more grayscale)
    - intensity = 1 will preserve the input image
    - intensity > 1 will increase saturation

    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    """

    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation (1 preserves input).
            intensity_max (float): Maximum augmentation (1 preserves input).
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, image):
        assert image.shape[-1] == 3, "RandomSaturation only works on RGB images"
        w = np.random.uniform(self.intensity_min, self.intensity_max)
        grayscale = image.dot([0.299, 0.587, 0.114])[:, :, np.newaxis]
        return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)


class RandomLighting(Augmentation):
    """
    The "lighting" augmentation described in AlexNet, using fixed PCA over ImageNet.
    Input images are assumed to have 'RGB' channel order.

    The degree of color jittering is randomly sampled via a normal distribution,
    with standard deviation given by the scale parameter.
    """

    def __init__(self, scale):
        """
        Args:
            scale (float): Standard deviation of principal component weighting.
        """
        super().__init__()
        self._init(locals())
        self.eigen_vecs = np.array(
            [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
        )
        self.eigen_vals = np.array([0.2175, 0.0188, 0.0045])

    def get_transform(self, image):
        assert image.shape[-1] == 3, "RandomLighting only works on RGB images"
        weights = np.random.normal(scale=self.scale, size=3)
        return BlendTransform(
            src_image=self.eigen_vecs.dot(weights * self.eigen_vals), src_weight=1.0, dst_weight=1.0
        )


class RandomResize(Augmentation):
    """Randomly resize image to a target size in shape_list"""

    def __init__(self, shape_list, interp=Image.BILINEAR):
        """
        Args:
            shape_list: a list of shapes in (h, w)
            interp: PIL interpolation method
        """
        self.shape_list = shape_list
        self._init(locals())

    def get_transform(self, image):
        shape_idx = np.random.randint(low=0, high=len(self.shape_list))
        h, w = self.shape_list[shape_idx]
        return ResizeTransform(image.shape[0], image.shape[1], h, w, self.interp)


class MinIoURandomCrop(Augmentation):
    """Random crop the image & bboxes, the cropped patches have minimum IoU
    requirement with original image & bboxes, the IoU threshold is randomly
    selected from min_ious.

    Args:
        min_ious (tuple): minimum IoU threshold for all intersections with
            bounding boxes
        min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
            where a >= min_crop_size)
        mode_trials: number of trials for sampling min_ious threshold
        crop_trials: number of trials for sampling crop_size after cropping
    """

    def __init__(
        self,
        min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
        min_crop_size=0.3,
        mode_trials=1000,
        crop_trials=50,
    ):
        self.min_ious = min_ious
        self.sample_mode = (1, *min_ious, 0)
        self.min_crop_size = min_crop_size
        self.mode_trials = mode_trials
        self.crop_trials = crop_trials

    def get_transform(self, image, boxes):
        """Call function to crop images and bounding boxes with minimum IoU
        constraint.

        Args:
            boxes: ground truth boxes in (x1, y1, x2, y2) format
        """
        if boxes is None:
            return NoOpTransform()
        h, w, c = image.shape
        for _ in range(self.mode_trials):
            mode = random.choice(self.sample_mode)
            self.mode = mode
            if mode == 1:
                return NoOpTransform()

            min_iou = mode
            for _ in range(self.crop_trials):
                new_w = random.uniform(self.min_crop_size * w, w)
                new_h = random.uniform(self.min_crop_size * h, h)

                # h / w in [0.5, 2]
                if new_h / new_w < 0.5 or new_h / new_w > 2:
                    continue

                left = random.uniform(w - new_w)
                top = random.uniform(h - new_h)

                patch = np.array((int(left), int(top), int(left + new_w), int(top + new_h)))
                # Line or point crop is not allowed
                if patch[2] == patch[0] or patch[3] == patch[1]:
                    continue
                overlaps = pairwise_iou(
                    Boxes(patch.reshape(-1, 4)), Boxes(boxes.reshape(-1, 4))
                ).reshape(-1)
                if len(overlaps) > 0 and overlaps.min() < min_iou:
                    continue

                # center of boxes should inside the crop img
                # only adjust boxes and instance masks when the gt is not empty
                if len(overlaps) > 0:
                    # adjust boxes
                    def is_center_of_bboxes_in_patch(boxes, patch):
                        center = (boxes[:, :2] + boxes[:, 2:]) / 2
                        mask = (
                            (center[:, 0] > patch[0])
                            * (center[:, 1] > patch[1])
                            * (center[:, 0] < patch[2])
                            * (center[:, 1] < patch[3])
                        )
                        return mask

                    mask = is_center_of_bboxes_in_patch(boxes, patch)
                    if not mask.any():
                        continue
                    return CropTransform(int(left), int(top), int(new_w), int(new_h))
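
A short sketch (not part of the commit) of how these built-ins are typically chained, plus the `get_output_shape` arithmetic. The image is synthetic and the parameters are illustrative; `AugInput` and `AugmentationList` come from the sibling augmentation.py above.

import numpy as np
from .augmentation import AugInput, AugmentationList  # sibling module shown earlier

augs = AugmentationList(
    [
        ResizeShortestEdge(short_edge_length=(640, 800), max_size=1333, sample_style="range"),
        RandomFlip(prob=0.5, horizontal=True),
        RandomApply(RandomBrightness(0.9, 1.1), prob=0.3),
    ]
)
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)   # synthetic data
aug_input = AugInput(image)
tfms = augs(aug_input)   # aug_input.image now holds the augmented image

# get_output_shape keeps the aspect ratio: scaling a 480x640 image to short edge 800
# gives 800x1067, and a max_size of 1000 scales that back down to 750x1000.
assert ResizeShortestEdge.get_output_shape(480, 640, 800, 1000) == (750, 1000)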
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/data/transforms/transform.py
ADDED
@@ -0,0 +1,351 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

"""
See "Data Augmentation" tutorial for an overview of the system:
https://detectron2.readthedocs.io/tutorials/augmentation.html
"""

import numpy as np
import torch
import torch.nn.functional as F
from fvcore.transforms.transform import (
    CropTransform,
    HFlipTransform,
    NoOpTransform,
    Transform,
    TransformList,
)
from PIL import Image

try:
    import cv2  # noqa
except ImportError:
    # OpenCV is an optional dependency at the moment
    pass

__all__ = [
    "ExtentTransform",
    "ResizeTransform",
    "RotationTransform",
    "ColorTransform",
    "PILColorTransform",
]


class ExtentTransform(Transform):
    """
    Extracts a subregion from the source image and scales it to the output size.

    The fill color is used to map pixels from the source rect that fall outside
    the source image.

    See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform
    """

    def __init__(self, src_rect, output_size, interp=Image.LINEAR, fill=0):
        """
        Args:
            src_rect (x0, y0, x1, y1): src coordinates
            output_size (h, w): dst image size
            interp: PIL interpolation methods
            fill: Fill color used when src_rect extends outside image
        """
        super().__init__()
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        h, w = self.output_size
        if len(img.shape) > 2 and img.shape[2] == 1:
            pil_image = Image.fromarray(img[:, :, 0], mode="L")
        else:
            pil_image = Image.fromarray(img)
        pil_image = pil_image.transform(
            size=(w, h),
            method=Image.EXTENT,
            data=self.src_rect,
            resample=interp if interp else self.interp,
            fill=self.fill,
        )
        ret = np.asarray(pil_image)
        if len(img.shape) > 2 and img.shape[2] == 1:
            ret = np.expand_dims(ret, -1)
        return ret

    def apply_coords(self, coords):
        # Transform image center from source coordinates into output coordinates
        # and then map the new origin to the corner of the output image.
        h, w = self.output_size
        x0, y0, x1, y1 = self.src_rect
        new_coords = coords.astype(np.float32)
        new_coords[:, 0] -= 0.5 * (x0 + x1)
        new_coords[:, 1] -= 0.5 * (y0 + y1)
        new_coords[:, 0] *= w / (x1 - x0)
        new_coords[:, 1] *= h / (y1 - y0)
        new_coords[:, 0] += 0.5 * w
        new_coords[:, 1] += 0.5 * h
        return new_coords

    def apply_segmentation(self, segmentation):
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation
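As a quick usage sketch, separate from the vendored file, the snippet below crops the central region of a dummy image and scales it to a fixed output size; apply_coords follows the same mapping, so the center of the source rect lands at the center of the output. The array sizes are arbitrary, and it assumes a Pillow version that still defines Image.LINEAR (used as the default interp above).

import numpy as np
# src_rect is (x0, y0, x1, y1) in the source image; output_size is (h, w).
img = np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8)
t = ExtentTransform(src_rect=(75, 50, 225, 150), output_size=(100, 100))
out = t.apply_image(img)                      # (100, 100, 3): the selected region, rescaled
pts = np.array([[150.0, 100.0]])              # center of the source rect ...
print(t.apply_coords(pts))                    # ... maps to the output center, (50, 50)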
class ResizeTransform(Transform):
    """
    Resize the image to a target size.
    """

    def __init__(self, h, w, new_h, new_w, interp=None):
        """
        Args:
            h, w (int): original image size
            new_h, new_w (int): new image size
            interp: PIL interpolation methods, defaults to bilinear.
        """
        # TODO decide on PIL vs opencv
        super().__init__()
        if interp is None:
            interp = Image.BILINEAR
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        assert img.shape[:2] == (self.h, self.w)
        assert len(img.shape) <= 4
        interp_method = interp if interp is not None else self.interp

        if img.dtype == np.uint8:
            if len(img.shape) > 2 and img.shape[2] == 1:
                pil_image = Image.fromarray(img[:, :, 0], mode="L")
            else:
                pil_image = Image.fromarray(img)
            pil_image = pil_image.resize((self.new_w, self.new_h), interp_method)
            ret = np.asarray(pil_image)
            if len(img.shape) > 2 and img.shape[2] == 1:
                ret = np.expand_dims(ret, -1)
        else:
            # PIL only supports uint8
            if any(x < 0 for x in img.strides):
                img = np.ascontiguousarray(img)
            img = torch.from_numpy(img)
            shape = list(img.shape)
            shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
            img = img.view(shape_4d).permute(2, 3, 0, 1)  # hw(c) -> nchw
            _PIL_RESIZE_TO_INTERPOLATE_MODE = {
                Image.NEAREST: "nearest",
                Image.BILINEAR: "bilinear",
                Image.BICUBIC: "bicubic",
            }
            mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method]
            align_corners = None if mode == "nearest" else False
            img = F.interpolate(
                img, (self.new_h, self.new_w), mode=mode, align_corners=align_corners
            )
            shape[:2] = (self.new_h, self.new_w)
            ret = img.permute(2, 3, 0, 1).view(shape).numpy()  # nchw -> hw(c)

        return ret

    def apply_coords(self, coords):
        coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w)
        coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h)
        return coords

    def apply_segmentation(self, segmentation):
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation

    def inverse(self):
        return ResizeTransform(self.new_h, self.new_w, self.h, self.w, self.interp)
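A short usage sketch for ResizeTransform, separate from the vendored file: construct it with the old and new sizes, apply it to the image and to point coordinates, and recover the original geometry through inverse(). The sizes below are arbitrary.

import numpy as np
t = ResizeTransform(h=480, w=640, new_h=240, new_w=320)       # halve both dimensions
img = np.zeros((480, 640, 3), dtype=np.uint8)
small = t.apply_image(img)                                    # (240, 320, 3)
coords = np.array([[320.0, 240.0]])                           # a point in the original image
print(t.apply_coords(coords.copy()))                          # -> [[160. 120.]]
print(t.inverse().apply_coords(np.array([[160.0, 120.0]])))   # -> back to [[320. 240.]]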
class RotationTransform(Transform):
    """
    This method returns a copy of this image, rotated the given
    number of degrees counter clockwise around its center.
    """

    def __init__(self, h, w, angle, expand=True, center=None, interp=None):
        """
        Args:
            h, w (int): original image size
            angle (float): degrees for rotation
            expand (bool): choose if the image should be resized to fit the whole
                rotated image (default), or simply cropped
            center (tuple (width, height)): coordinates of the rotation center
                if left to None, the center will be fit to the center of each image
                center has no effect if expand=True because it only affects shifting
            interp: cv2 interpolation method, default cv2.INTER_LINEAR
        """
        super().__init__()
        image_center = np.array((w / 2, h / 2))
        if center is None:
            center = image_center
        if interp is None:
            interp = cv2.INTER_LINEAR
        abs_cos, abs_sin = (abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle))))
        if expand:
            # find the new width and height bounds
            bound_w, bound_h = np.rint(
                [h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin]
            ).astype(int)
        else:
            bound_w, bound_h = w, h

        self._set_attributes(locals())
        self.rm_coords = self.create_rotation_matrix()
        # Needed because of this problem https://github.com/opencv/opencv/issues/11784
        self.rm_image = self.create_rotation_matrix(offset=-0.5)

    def apply_image(self, img, interp=None):
        """
        img should be a numpy array, formatted as Height * Width * Nchannels
        """
        if len(img) == 0 or self.angle % 360 == 0:
            return img
        assert img.shape[:2] == (self.h, self.w)
        interp = interp if interp is not None else self.interp
        return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp)

    def apply_coords(self, coords):
        """
        coords should be a N * 2 array-like, containing N couples of (x, y) points
        """
        coords = np.asarray(coords, dtype=float)
        if len(coords) == 0 or self.angle % 360 == 0:
            return coords
        return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]

    def apply_segmentation(self, segmentation):
        segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST)
        return segmentation

    def create_rotation_matrix(self, offset=0):
        center = (self.center[0] + offset, self.center[1] + offset)
        rm = cv2.getRotationMatrix2D(tuple(center), self.angle, 1)
        if self.expand:
            # Find the coordinates of the center of rotation in the new image
            # The only point for which we know the future coordinates is the center of the image
            rot_im_center = cv2.transform(self.image_center[None, None, :] + offset, rm)[0, 0, :]
            new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center
            # shift the rotation center to the new coordinates
            rm[:, 2] += new_center
        return rm

    def inverse(self):
        """
        The inverse is to rotate it back with expand, and crop to get the original shape.
        """
        if not self.expand:  # Not possible to inverse if a part of the image is lost
            raise NotImplementedError()
        rotation = RotationTransform(
            self.bound_h, self.bound_w, -self.angle, True, None, self.interp
        )
        crop = CropTransform(
            (rotation.bound_w - self.w) // 2, (rotation.bound_h - self.h) // 2, self.w, self.h
        )
        return TransformList([rotation, crop])
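RotationTransform rotates the image and any (x, y) points with matched affine matrices, and with expand=True its inverse is a rotation back plus a center crop. A small sketch, separate from the vendored file and assuming OpenCV is installed since this class depends on cv2:

import numpy as np
t = RotationTransform(h=100, w=200, angle=90, expand=True)
img = np.zeros((100, 200, 3), dtype=np.uint8)
rotated = t.apply_image(img)
print(rotated.shape)                          # (200, 100, 3): bounds expand to fit the rotation
pts = np.array([[200.0, 100.0]])              # bottom-right corner of the original image
print(t.apply_coords(pts.copy()))             # -> [[100.   0.]] in the expanded, rotated frame
restored = t.inverse().apply_coords(t.apply_coords(pts.copy()))
print(np.round(restored))                     # -> [[200. 100.]] again (up to floating-point error)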
class ColorTransform(Transform):
    """
    Generic wrapper for any photometric transforms.
    These transformations should only affect the color space and
    not the coordinate space of the image (e.g. annotation
    coordinates such as bounding boxes should not be changed)
    """

    def __init__(self, op):
        """
        Args:
            op (Callable): operation to be applied to the image,
                which takes in an ndarray and returns an ndarray.
        """
        if not callable(op):
            raise ValueError("op parameter should be callable")
        super().__init__()
        self._set_attributes(locals())

    def apply_image(self, img):
        return self.op(img)

    def apply_coords(self, coords):
        return coords

    def inverse(self):
        return NoOpTransform()

    def apply_segmentation(self, segmentation):
        return segmentation


class PILColorTransform(ColorTransform):
    """
    Generic wrapper for PIL Photometric image transforms,
    which affect the color space and not the coordinate
    space of the image
    """

    def __init__(self, op):
        """
        Args:
            op (Callable): operation to be applied to the image,
                which takes in a PIL Image and returns a transformed
                PIL Image.
                For reference on possible operations see:
                - https://pillow.readthedocs.io/en/stable/
        """
        if not callable(op):
            raise ValueError("op parameter should be callable")
        super().__init__(op)

    def apply_image(self, img):
        img = Image.fromarray(img)
        return np.asarray(super().apply_image(img))
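Both wrappers simply defer to a user-supplied callable, so a photometric jitter is one line. A usage sketch, separate from the vendored file (the brightness factor 1.2 is arbitrary):

import numpy as np
from PIL import ImageEnhance

# ndarray-based: scale brightness by 1.2 while keeping dtype and value range sane
brighten = ColorTransform(lambda x: np.clip(x.astype(np.float32) * 1.2, 0, 255).astype(np.uint8))

# PIL-based: the same idea via ImageEnhance; PILColorTransform converts to/from PIL internally
pil_brighten = PILColorTransform(lambda im: ImageEnhance.Brightness(im).enhance(1.2))

img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
out1 = brighten.apply_image(img)
out2 = pil_brighten.apply_image(img)
# Geometry is untouched: apply_coords is the identity for both wrappers.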
def HFlip_rotated_box(transform, rotated_boxes):
    """
    Apply the horizontal flip transform on rotated boxes.

    Args:
        rotated_boxes (ndarray): Nx5 floating point array of
            (x_center, y_center, width, height, angle_degrees) format
            in absolute coordinates.
    """
    # Transform x_center
    rotated_boxes[:, 0] = transform.width - rotated_boxes[:, 0]
    # Transform angle
    rotated_boxes[:, 4] = -rotated_boxes[:, 4]
    return rotated_boxes


def Resize_rotated_box(transform, rotated_boxes):
    """
    Apply the resizing transform on rotated boxes. For details of how these (approximation)
    formulas are derived, please refer to :meth:`RotatedBoxes.scale`.

    Args:
        rotated_boxes (ndarray): Nx5 floating point array of
            (x_center, y_center, width, height, angle_degrees) format
            in absolute coordinates.
    """
    scale_factor_x = transform.new_w * 1.0 / transform.w
    scale_factor_y = transform.new_h * 1.0 / transform.h
    rotated_boxes[:, 0] *= scale_factor_x
    rotated_boxes[:, 1] *= scale_factor_y
    theta = rotated_boxes[:, 4] * np.pi / 180.0
    c = np.cos(theta)
    s = np.sin(theta)
    rotated_boxes[:, 2] *= np.sqrt(np.square(scale_factor_x * c) + np.square(scale_factor_y * s))
    rotated_boxes[:, 3] *= np.sqrt(np.square(scale_factor_x * s) + np.square(scale_factor_y * c))
    rotated_boxes[:, 4] = np.arctan2(scale_factor_x * s, scale_factor_y * c) * 180 / np.pi

    return rotated_boxes


HFlipTransform.register_type("rotated_box", HFlip_rotated_box)
ResizeTransform.register_type("rotated_box", Resize_rotated_box)

# not necessary any more with latest fvcore
NoOpTransform.register_type("rotated_box", lambda t, x: x)
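register_type is how the geometric transforms above learn to handle extra annotation types: after SomeTransform.register_type("name", func), calling transform.apply_name(data) dispatches to func(transform, data), which is exactly how "rotated_box" is wired up at the end of this file. A sketch, separate from the vendored file, that registers a handler for a made-up "keypoints_xy" payload on HFlipTransform (the data type name and handler are illustrative, not part of detectron2):

import numpy as np
from fvcore.transforms.transform import HFlipTransform

def hflip_keypoints_xy(transform, kpts):
    """Flip an (N, 2) array of absolute (x, y) keypoints horizontally."""
    kpts = kpts.copy()
    kpts[:, 0] = transform.width - kpts[:, 0]
    return kpts

HFlipTransform.register_type("keypoints_xy", hflip_keypoints_xy)

t = HFlipTransform(width=640)
print(t.apply_keypoints_xy(np.array([[10.0, 20.0]])))   # -> [[630.  20.]]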
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/engine/__init__.py
ADDED
@@ -0,0 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates.

from .launch import *
from .train_loop import *

__all__ = [k for k in globals().keys() if not k.startswith("_")]


# prefer to let hooks and defaults live in separate namespaces (therefore not in __all__)
# but still make them available here
from .hooks import *
from .defaults import *
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/engine/defaults.py
ADDED
@@ -0,0 +1,715 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

"""
This file contains components with some default boilerplate logic user may need
in training / testing. They will not work for everyone, but many users may find them useful.

The behavior of functions/classes in this file is subject to change,
since they are meant to represent the "common default behavior" people need in their projects.
"""

import argparse
import logging
import os
import sys
import weakref
from collections import OrderedDict
from typing import Optional
import torch
from fvcore.nn.precise_bn import get_bn_modules
from omegaconf import OmegaConf
from torch.nn.parallel import DistributedDataParallel

import annotator.oneformer.detectron2.data.transforms as T
from annotator.oneformer.detectron2.checkpoint import DetectionCheckpointer
from annotator.oneformer.detectron2.config import CfgNode, LazyConfig
from annotator.oneformer.detectron2.data import (
    MetadataCatalog,
    build_detection_test_loader,
    build_detection_train_loader,
)
from annotator.oneformer.detectron2.evaluation import (
    DatasetEvaluator,
    inference_on_dataset,
    print_csv_format,
    verify_results,
)
from annotator.oneformer.detectron2.modeling import build_model
from annotator.oneformer.detectron2.solver import build_lr_scheduler, build_optimizer
from annotator.oneformer.detectron2.utils import comm
from annotator.oneformer.detectron2.utils.collect_env import collect_env_info
from annotator.oneformer.detectron2.utils.env import seed_all_rng
from annotator.oneformer.detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from annotator.oneformer.detectron2.utils.file_io import PathManager
from annotator.oneformer.detectron2.utils.logger import setup_logger

from . import hooks
from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase

__all__ = [
    "create_ddp_model",
    "default_argument_parser",
    "default_setup",
    "default_writers",
    "DefaultPredictor",
    "DefaultTrainer",
]


def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Create a DistributedDataParallel model if there are >1 processes.

    Args:
        model: a torch.nn.Module
        fp16_compression: add fp16 compression hooks to the ddp object.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
    """  # noqa
    if comm.get_world_size() == 1:
        return model
    if "device_ids" not in kwargs:
        kwargs["device_ids"] = [comm.get_local_rank()]
    ddp = DistributedDataParallel(model, **kwargs)
    if fp16_compression:
        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return ddp
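create_ddp_model is a thin convenience wrapper: on a single process it is a no-op, otherwise it wraps the module in DistributedDataParallel on the local rank's GPU and optionally registers the fp16 compression hook. A usage sketch inside a launched training script, separate from the vendored file (the model construction is a placeholder and assumes a CUDA device is available):

import torch.nn as nn

model = nn.Linear(10, 10).cuda()                      # placeholder for build_model(cfg)
model = create_ddp_model(model, broadcast_buffers=False, fp16_compression=True)
# With world size 1 this returns the module unchanged; under a multi-GPU launch it
# returns a DistributedDataParallel wrapper bound to the local rank's device.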
def default_argument_parser(epilog=None):
    """
    Create a parser with some common arguments used by detectron2 users.

    Args:
        epilog (str): epilog passed to ArgumentParser describing the usage.

    Returns:
        argparse.ArgumentParser:
    """
    parser = argparse.ArgumentParser(
        epilog=epilog
        or f"""
Examples:

Run on single machine:
    $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml

Change some config options:
    $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

Run on multiple machines:
    (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
    (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume from the checkpoint directory. "
        "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )

    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    port = 2**15 + 2**14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2**14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
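This parser is normally combined with launch() and the trainer defined later in this file in a small driver script. A sketch of that pattern, separate from the vendored file; setup() is a placeholder for building a CfgNode from args.config_file and args.opts, and the imports assume the engine package re-exports these names as upstream detectron2 does:

from annotator.oneformer.detectron2.engine import DefaultTrainer, default_argument_parser, launch

def main(args):
    cfg = setup(args)   # placeholder: load and merge the config, then call default_setup(cfg, args)
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )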
def _try_get_key(cfg, *keys, default=None):
    """
    Try select keys from cfg until the first key that exists. Otherwise return default.
    """
    if isinstance(cfg, CfgNode):
        cfg = OmegaConf.create(cfg.dump())
    for k in keys:
        none = object()
        p = OmegaConf.select(cfg, k, default=none)
        if p is not none:
            return p
    return default


def _highlight(code, filename):
    try:
        import pygments
    except ImportError:
        return code

    from pygments.lexers import Python3Lexer, YamlLexer
    from pygments.formatters import Terminal256Formatter

    lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
    code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
    return code


def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the detectron2 logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    Args:
        cfg (CfgNode or omegaconf.DictConfig): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file,
                _highlight(PathManager.open(args.config_file, "r").read(), args.config_file),
            )
        )

    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        if isinstance(cfg, CfgNode):
            logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
            with PathManager.open(path, "w") as f:
                f.write(cfg.dump())
        else:
            LazyConfig.save(cfg, path)
        logger.info("Full config saved to {}".format(path))

    # make sure each worker has a different, yet deterministic seed if specified
    seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
    seed_all_rng(None if seed < 0 else seed + rank)

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = _try_get_key(
            cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
        )


def default_writers(output_dir: str, max_iter: Optional[int] = None):
    """
    Build a list of :class:`EventWriter` to be used.
    It now consists of a :class:`CommonMetricPrinter`,
    :class:`TensorboardXWriter` and :class:`JSONWriter`.

    Args:
        output_dir: directory to store JSON metrics and tensorboard events
        max_iter: the total number of iterations

    Returns:
        list[EventWriter]: a list of :class:`EventWriter` objects.
    """
    PathManager.mkdirs(output_dir)
    return [
        # It may not always print what you want to see, since it prints "common" metrics only.
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(output_dir, "metrics.json")),
        TensorboardXWriter(output_dir),
    ]


class DefaultPredictor:
    """
    Create a simple end-to-end predictor with the given config that runs on
    single device for a single input image.

    Compared to using the model directly, this class does the following additions:

    1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
    2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
    3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
    4. Take one input image and produce a single output, instead of a batch.

    This is meant for simple demo purposes, so it does the above steps automatically.
    This is not meant for benchmarks or running complicated inference logic.
    If you'd like to do anything more complicated, please refer to its source code as
    examples to build and use the model manually.

    Attributes:
        metadata (Metadata): the metadata of the underlying dataset, obtained from
            cfg.DATASETS.TEST.

    Examples:
    ::
        pred = DefaultPredictor(cfg)
        inputs = cv2.imread("input.jpg")
        outputs = pred(inputs)
    """

    def __init__(self, cfg):
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()
        if len(cfg.DATASETS.TEST):
            self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).

        Returns:
            predictions (dict):
                the output of the model for one image only.
                See :doc:`/tutorials/models` for details about the format.
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]
            image = self.aug.get_transform(original_image).apply_image(original_image)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            inputs = {"image": image, "height": height, "width": width}
            predictions = self.model([inputs])[0]
            return predictions
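Building the cfg that feeds DefaultPredictor and reading its output typically looks like the sketch below, separate from the vendored file. The config and weight paths and the score threshold are illustrative, the "instances" field applies to detection models (other meta-architectures return different keys), and the sketch assumes the vendored config module exposes get_cfg as upstream detectron2 does:

import cv2
from annotator.oneformer.detectron2.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file("path/to/mask_rcnn_config.yaml")    # illustrative path
cfg.MODEL.WEIGHTS = "path/to/model_final.pth"            # illustrative path
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5

predictor = DefaultPredictor(cfg)
outputs = predictor(cv2.imread("input.jpg"))             # BGR image, as documented above
instances = outputs["instances"].to("cpu")
print(instances.pred_boxes, instances.scores, instances.pred_classes)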
class DefaultTrainer(TrainerBase):
    """
    A trainer with default training logic. It does the following:

    1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
       defined by the given config. Create a LR scheduler defined by the config.
    2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
       `resume_or_load` is called.
    3. Register a few common hooks defined by the config.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.

    The code of this class has been annotated about restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:

    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    See the :doc:`/tutorials/training` tutorials for more details.

    Note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in detectron2.
    To obtain more stable behavior, write your own training logic with other public APIs.

    Examples:
    ::
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()

    Attributes:
        scheduler:
        checkpointer (DetectionCheckpointer):
        cfg (CfgNode):
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            trainer=weakref.proxy(self),
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
        a `last_checkpoint` file), resume from the file. Resuming means loading all
        available states (eg. optimizer and scheduler) and update iteration counter
        from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.

        Otherwise, this is considered as an independent training. The method will load model
        weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
        from iteration 0.

        Args:
            resume (bool): whether to do resume or not
        """
        self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        if resume and self.checkpointer.has_checkpoint():
            # The checkpoint stores the training iteration that just finished, thus we start
            # at the next iteration
            self.start_iter = self.iter + 1

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # Here the default print/log frequency of each writer is used.
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret

    def build_writers(self):
        """
        Build a list of writers to be used using :func:`default_writers()`.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        """
        return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)

    def train(self):
        """
        Run training.

        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
            assert hasattr(
                self, "_last_eval_results"
            ), "No evaluation results obtained during training!"
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    def run_step(self):
        self._trainer.iter = self.iter
        self._trainer.run_step()

    def state_dict(self):
        ret = super().state_dict()
        ret["_trainer"] = self._trainer.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self._trainer.load_state_dict(state_dict["_trainer"])

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:

        It now calls :func:`detectron2.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """
        Returns:
            DatasetEvaluator or None

        It is not implemented by default.
        """
        raise NotImplementedError(
            """
If you want DefaultTrainer to automatically run evaluation,
please implement `build_evaluator()` in subclasses (see train_net.py for example).
Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
"""
        )
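Since build_evaluator deliberately raises NotImplementedError, the usual pattern is a small subclass that supplies one, which is what lets the EvalHook registered in build_hooks actually run. A sketch, separate from the vendored file, assuming a COCO-format dataset and that the vendored evaluation package exports COCOEvaluator as upstream detectron2 does:

import os
from annotator.oneformer.detectron2.evaluation import COCOEvaluator

class MyTrainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        # Write per-dataset evaluation artifacts next to the checkpoints.
        output_dir = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
        return COCOEvaluator(dataset_name, output_dir=output_dir)

# trainer = MyTrainer(cfg); trainer.resume_or_load(resume=False); trainer.train()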
    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate the given model. The given model is expected to already contain
        weights to evaluate.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                ``cfg.DATASETS.TEST``.

        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    logger.warn(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @staticmethod
    def auto_scale_workers(cfg, num_workers: int):
        """
        When the config is defined for certain number of workers (according to
        ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
        workers currently in use, returns a new cfg where the total batch size
        is scaled so that the per-GPU batch size stays the same as the
        original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.

        Other config options are also scaled accordingly:
        * training steps and warmup steps are scaled inverse proportionally.
        * learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.

        For example, with the original config like the following:

        .. code-block:: yaml

            IMS_PER_BATCH: 16
            BASE_LR: 0.1
            REFERENCE_WORLD_SIZE: 8
            MAX_ITER: 5000
            STEPS: (4000,)
            CHECKPOINT_PERIOD: 1000

        When this config is used on 16 GPUs instead of the reference number 8,
        calling this method will return a new config with:

        .. code-block:: yaml

            IMS_PER_BATCH: 32
            BASE_LR: 0.2
            REFERENCE_WORLD_SIZE: 16
            MAX_ITER: 2500
            STEPS: (2000,)
            CHECKPOINT_PERIOD: 500

        Note that both the original config and this new config can be trained on 16 GPUs.
        It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).

        Returns:
            CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
        """
        old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
        if old_world_size == 0 or old_world_size == num_workers:
            return cfg
        cfg = cfg.clone()
        frozen = cfg.is_frozen()
        cfg.defrost()

        assert (
            cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
        ), "Invalid REFERENCE_WORLD_SIZE in config!"
        scale = num_workers / old_world_size
        bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
        lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
        max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
        warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
        cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
        cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
        cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
        cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers  # maintain invariant
        logger = logging.getLogger(__name__)
        logger.info(
            f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
            f"max_iter={max_iter}, warmup={warmup_iter}."
        )

        if frozen:
            cfg.freeze()
        return cfg


# Access basic attributes from the underlying trainer
for _attr in ["model", "data_loader", "optimizer"]:
    setattr(
        DefaultTrainer,
        _attr,
        property(
            # getter
            lambda self, x=_attr: getattr(self._trainer, x),
            # setter
            lambda self, value, x=_attr: setattr(self._trainer, x, value),
        ),
    )
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/engine/hooks.py
ADDED
@@ -0,0 +1,690 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import datetime
import itertools
import logging
import math
import operator
import os
import tempfile
import time
import warnings
from collections import Counter
import torch
from fvcore.common.checkpoint import Checkpointer
from fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from fvcore.common.param_scheduler import ParamScheduler
from fvcore.common.timer import Timer
from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats

import annotator.oneformer.detectron2.utils.comm as comm
from annotator.oneformer.detectron2.evaluation.testing import flatten_results_dict
from annotator.oneformer.detectron2.solver import LRMultiplier
from annotator.oneformer.detectron2.solver import LRScheduler as _LRScheduler
from annotator.oneformer.detectron2.utils.events import EventStorage, EventWriter
from annotator.oneformer.detectron2.utils.file_io import PathManager

from .train_loop import HookBase

__all__ = [
    "CallbackHook",
    "IterationTimer",
    "PeriodicWriter",
    "PeriodicCheckpointer",
    "BestCheckpointer",
    "LRScheduler",
    "AutogradProfiler",
    "EvalHook",
    "PreciseBN",
    "TorchProfiler",
    "TorchMemoryStats",
]


"""
Implement some common hooks.
"""
+
class CallbackHook(HookBase):
|
51 |
+
"""
|
52 |
+
Create a hook using callback functions provided by the user.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
|
56 |
+
"""
|
57 |
+
Each argument is a function that takes one argument: the trainer.
|
58 |
+
"""
|
59 |
+
self._before_train = before_train
|
60 |
+
self._before_step = before_step
|
61 |
+
self._after_step = after_step
|
62 |
+
self._after_train = after_train
|
63 |
+
|
64 |
+
def before_train(self):
|
65 |
+
if self._before_train:
|
66 |
+
self._before_train(self.trainer)
|
67 |
+
|
68 |
+
def after_train(self):
|
69 |
+
if self._after_train:
|
70 |
+
self._after_train(self.trainer)
|
71 |
+
# The functions may be closures that hold reference to the trainer
|
72 |
+
# Therefore, delete them to avoid circular reference.
|
73 |
+
del self._before_train, self._after_train
|
74 |
+
del self._before_step, self._after_step
|
75 |
+
|
76 |
+
def before_step(self):
|
77 |
+
if self._before_step:
|
78 |
+
self._before_step(self.trainer)
|
79 |
+
|
80 |
+
def after_step(self):
|
81 |
+
if self._after_step:
|
82 |
+
self._after_step(self.trainer)
|
83 |
+
|
84 |
+
|
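CallbackHook is the quickest way to attach ad-hoc logic without writing a HookBase subclass; each callback receives the trainer, so it can reach trainer.iter, trainer.storage, trainer.optimizer and so on. A sketch, separate from the vendored file, assuming the hook is registered on a trainer such as DefaultTrainer (the logging interval is arbitrary):

def log_lr(trainer):
    # Record the current learning rate every 100 iterations.
    if trainer.iter % 100 == 0:
        lr = trainer.optimizer.param_groups[0]["lr"]
        trainer.storage.put_scalar("custom/lr", lr)

hook = CallbackHook(after_step=log_lr)
# trainer.register_hooks([hook])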
class IterationTimer(HookBase):
    """
    Track the time spent for each iteration (each run_step call in the trainer).
    Print a summary in the end of training.

    This hook uses the time between the call to its :meth:`before_step`
    and :meth:`after_step` methods.
    Under the convention that :meth:`before_step` of all hooks should only
    take negligible amount of time, the :class:`IterationTimer` hook should be
    placed at the beginning of the list of hooks to obtain accurate timing.
    """

    def __init__(self, warmup_iter=3):
        """
        Args:
            warmup_iter (int): the number of iterations at the beginning to exclude
                from timing.
        """
        self._warmup_iter = warmup_iter
        self._step_timer = Timer()
        self._start_time = time.perf_counter()
        self._total_timer = Timer()

    def before_train(self):
        self._start_time = time.perf_counter()
        self._total_timer.reset()
        self._total_timer.pause()

    def after_train(self):
        logger = logging.getLogger(__name__)
        total_time = time.perf_counter() - self._start_time
        total_time_minus_hooks = self._total_timer.seconds()
        hook_time = total_time - total_time_minus_hooks

        num_iter = self.trainer.storage.iter + 1 - self.trainer.start_iter - self._warmup_iter

        if num_iter > 0 and total_time_minus_hooks > 0:
            # Speed is meaningful only after warmup
            # NOTE this format is parsed by grep in some scripts
            logger.info(
                "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
                    num_iter,
                    str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
                    total_time_minus_hooks / num_iter,
                )
            )

        logger.info(
            "Total training time: {} ({} on hooks)".format(
                str(datetime.timedelta(seconds=int(total_time))),
                str(datetime.timedelta(seconds=int(hook_time))),
            )
        )

    def before_step(self):
        self._step_timer.reset()
        self._total_timer.resume()

    def after_step(self):
        # +1 because we're in after_step, the current step is done
        # but not yet counted
        iter_done = self.trainer.storage.iter - self.trainer.start_iter + 1
        if iter_done >= self._warmup_iter:
            sec = self._step_timer.seconds()
            self.trainer.storage.put_scalars(time=sec)
        else:
            self._start_time = time.perf_counter()
            self._total_timer.reset()

        self._total_timer.pause()


class PeriodicWriter(HookBase):
    """
    Write events to EventStorage (by calling ``writer.write()``) periodically.

    It is executed every ``period`` iterations and after the last iteration.
    Note that ``period`` does not affect how data is smoothed by each writer.
    """

    def __init__(self, writers, period=20):
        """
        Args:
            writers (list[EventWriter]): a list of EventWriter objects
            period (int):
        """
        self._writers = writers
        for w in writers:
            assert isinstance(w, EventWriter), w
        self._period = period

    def after_step(self):
        if (self.trainer.iter + 1) % self._period == 0 or (
            self.trainer.iter == self.trainer.max_iter - 1
        ):
            for writer in self._writers:
                writer.write()

    def after_train(self):
        for writer in self._writers:
            # If any new data is found (e.g. produced by other after_train),
            # write them before closing
            writer.write()
            writer.close()


class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
    """
    Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.

    Note that when used as a hook,
    it is unable to save additional data other than what's defined
    by the given `checkpointer`.

    It is executed every ``period`` iterations and after the last iteration.
    """

    def before_train(self):
        self.max_iter = self.trainer.max_iter

    def after_step(self):
        # No way to use **kwargs
        self.step(self.trainer.iter)


class BestCheckpointer(HookBase):
    """
    Checkpoints best weights based off given metric.

    This hook should be used in conjunction to and executed after the hook
    that produces the metric, e.g. `EvalHook`.
    """

    def __init__(
        self,
        eval_period: int,
        checkpointer: Checkpointer,
        val_metric: str,
        mode: str = "max",
        file_prefix: str = "model_best",
    ) -> None:
        """
        Args:
            eval_period (int): the period `EvalHook` is set to run.
            checkpointer: the checkpointer object used to save checkpoints.
            val_metric (str): validation metric to track for best checkpoint, e.g. "bbox/AP50"
            mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be
                maximized or minimized, e.g. for "bbox/AP50" it should be "max"
            file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best"
        """
        self._logger = logging.getLogger(__name__)
        self._period = eval_period
        self._val_metric = val_metric
        assert mode in [
            "max",
            "min",
        ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.'
        if mode == "max":
            self._compare = operator.gt
        else:
            self._compare = operator.lt
        self._checkpointer = checkpointer
        self._file_prefix = file_prefix
        self.best_metric = None
        self.best_iter = None

    def _update_best(self, val, iteration):
        if math.isnan(val) or math.isinf(val):
            return False
        self.best_metric = val
        self.best_iter = iteration
        return True

    def _best_checking(self):
        metric_tuple = self.trainer.storage.latest().get(self._val_metric)
        if metric_tuple is None:
            self._logger.warning(
                f"Given val metric {self._val_metric} does not seem to be computed/stored."
                "Will not be checkpointing based on it."
            )
            return
        else:
            latest_metric, metric_iter = metric_tuple

        if self.best_metric is None:
            if self._update_best(latest_metric, metric_iter):
                additional_state = {"iteration": metric_iter}
                self._checkpointer.save(f"{self._file_prefix}", **additional_state)
                self._logger.info(
                    f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps"
                )
        elif self._compare(latest_metric, self.best_metric):
            additional_state = {"iteration": metric_iter}
            self._checkpointer.save(f"{self._file_prefix}", **additional_state)
            self._logger.info(
                f"Saved best model as latest eval score for {self._val_metric} is "
                f"{latest_metric:0.5f}, better than last best score "
                f"{self.best_metric:0.5f} @ iteration {self.best_iter}."
            )
            self._update_best(latest_metric, metric_iter)
        else:
            self._logger.info(
                f"Not saving as latest eval score for {self._val_metric} is {latest_metric:0.5f}, "
                f"not better than best score {self.best_metric:0.5f} @ iteration {self.best_iter}."
            )

    def after_step(self):
        # same conditions as `EvalHook`
        next_iter = self.trainer.iter + 1
        if (
            self._period > 0
            and next_iter % self._period == 0
            and next_iter != self.trainer.max_iter
        ):
            self._best_checking()
|
300 |
+
|
301 |
+
def after_train(self):
|
302 |
+
# same conditions as `EvalHook`
|
303 |
+
if self.trainer.iter + 1 >= self.trainer.max_iter:
|
304 |
+
self._best_checking()
|
305 |
+
|
306 |
+
|
307 |
+
class LRScheduler(HookBase):
|
308 |
+
"""
|
309 |
+
A hook which executes a torch builtin LR scheduler and summarizes the LR.
|
310 |
+
It is executed after every iteration.
|
311 |
+
"""
|
312 |
+
|
313 |
+
def __init__(self, optimizer=None, scheduler=None):
|
314 |
+
"""
|
315 |
+
Args:
|
316 |
+
optimizer (torch.optim.Optimizer):
|
317 |
+
scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler):
|
318 |
+
if a :class:`ParamScheduler` object, it defines the multiplier over the base LR
|
319 |
+
in the optimizer.
|
320 |
+
|
321 |
+
If any argument is not given, will try to obtain it from the trainer.
|
322 |
+
"""
|
323 |
+
self._optimizer = optimizer
|
324 |
+
self._scheduler = scheduler
|
325 |
+
|
326 |
+
def before_train(self):
|
327 |
+
self._optimizer = self._optimizer or self.trainer.optimizer
|
328 |
+
if isinstance(self.scheduler, ParamScheduler):
|
329 |
+
self._scheduler = LRMultiplier(
|
330 |
+
self._optimizer,
|
331 |
+
self.scheduler,
|
332 |
+
self.trainer.max_iter,
|
333 |
+
last_iter=self.trainer.iter - 1,
|
334 |
+
)
|
335 |
+
self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)
|
336 |
+
|
337 |
+
@staticmethod
|
338 |
+
def get_best_param_group_id(optimizer):
|
339 |
+
# NOTE: some heuristics on what LR to summarize
|
340 |
+
# summarize the param group with most parameters
|
341 |
+
largest_group = max(len(g["params"]) for g in optimizer.param_groups)
|
342 |
+
|
343 |
+
if largest_group == 1:
|
344 |
+
# If all groups have one parameter,
|
345 |
+
# then find the most common initial LR, and use it for summary
|
346 |
+
lr_count = Counter([g["lr"] for g in optimizer.param_groups])
|
347 |
+
lr = lr_count.most_common()[0][0]
|
348 |
+
for i, g in enumerate(optimizer.param_groups):
|
349 |
+
if g["lr"] == lr:
|
350 |
+
return i
|
351 |
+
else:
|
352 |
+
for i, g in enumerate(optimizer.param_groups):
|
353 |
+
if len(g["params"]) == largest_group:
|
354 |
+
return i
|
355 |
+
|
356 |
+
def after_step(self):
|
357 |
+
lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
|
358 |
+
self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
|
359 |
+
self.scheduler.step()
|
360 |
+
|
361 |
+
@property
|
362 |
+
def scheduler(self):
|
363 |
+
return self._scheduler or self.trainer.scheduler
|
364 |
+
|
365 |
+
def state_dict(self):
|
366 |
+
if isinstance(self.scheduler, _LRScheduler):
|
367 |
+
return self.scheduler.state_dict()
|
368 |
+
return {}
|
369 |
+
|
370 |
+
def load_state_dict(self, state_dict):
|
371 |
+
if isinstance(self.scheduler, _LRScheduler):
|
372 |
+
logger = logging.getLogger(__name__)
|
373 |
+
logger.info("Loading scheduler from state_dict ...")
|
374 |
+
self.scheduler.load_state_dict(state_dict)
|
375 |
+
|
376 |
+
|
377 |
+
class TorchProfiler(HookBase):
|
378 |
+
"""
|
379 |
+
A hook which runs `torch.profiler.profile`.
|
380 |
+
|
381 |
+
Examples:
|
382 |
+
::
|
383 |
+
hooks.TorchProfiler(
|
384 |
+
lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
|
385 |
+
)
|
386 |
+
|
387 |
+
The above example will run the profiler for iteration 10~20 and dump
|
388 |
+
results to ``OUTPUT_DIR``. We did not profile the first few iterations
|
389 |
+
because they are typically slower than the rest.
|
390 |
+
The result files can be loaded in the ``chrome://tracing`` page in chrome browser,
|
391 |
+
and the tensorboard visualizations can be visualized using
|
392 |
+
``tensorboard --logdir OUTPUT_DIR/log``
|
393 |
+
"""
|
394 |
+
|
395 |
+
def __init__(self, enable_predicate, output_dir, *, activities=None, save_tensorboard=True):
|
396 |
+
"""
|
397 |
+
Args:
|
398 |
+
enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
|
399 |
+
and returns whether to enable the profiler.
|
400 |
+
It will be called once every step, and can be used to select which steps to profile.
|
401 |
+
output_dir (str): the output directory to dump tracing files.
|
402 |
+
activities (iterable): same as in `torch.profiler.profile`.
|
403 |
+
save_tensorboard (bool): whether to save tensorboard visualizations at (output_dir)/log/
|
404 |
+
"""
|
405 |
+
self._enable_predicate = enable_predicate
|
406 |
+
self._activities = activities
|
407 |
+
self._output_dir = output_dir
|
408 |
+
self._save_tensorboard = save_tensorboard
|
409 |
+
|
410 |
+
def before_step(self):
|
411 |
+
if self._enable_predicate(self.trainer):
|
412 |
+
if self._save_tensorboard:
|
413 |
+
on_trace_ready = torch.profiler.tensorboard_trace_handler(
|
414 |
+
os.path.join(
|
415 |
+
self._output_dir,
|
416 |
+
"log",
|
417 |
+
"profiler-tensorboard-iter{}".format(self.trainer.iter),
|
418 |
+
),
|
419 |
+
f"worker{comm.get_rank()}",
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
on_trace_ready = None
|
423 |
+
self._profiler = torch.profiler.profile(
|
424 |
+
activities=self._activities,
|
425 |
+
on_trace_ready=on_trace_ready,
|
426 |
+
record_shapes=True,
|
427 |
+
profile_memory=True,
|
428 |
+
with_stack=True,
|
429 |
+
with_flops=True,
|
430 |
+
)
|
431 |
+
self._profiler.__enter__()
|
432 |
+
else:
|
433 |
+
self._profiler = None
|
434 |
+
|
435 |
+
def after_step(self):
|
436 |
+
if self._profiler is None:
|
437 |
+
return
|
438 |
+
self._profiler.__exit__(None, None, None)
|
439 |
+
if not self._save_tensorboard:
|
440 |
+
PathManager.mkdirs(self._output_dir)
|
441 |
+
out_file = os.path.join(
|
442 |
+
self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
|
443 |
+
)
|
444 |
+
if "://" not in out_file:
|
445 |
+
self._profiler.export_chrome_trace(out_file)
|
446 |
+
else:
|
447 |
+
# Support non-posix filesystems
|
448 |
+
with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
|
449 |
+
tmp_file = os.path.join(d, "tmp.json")
|
450 |
+
self._profiler.export_chrome_trace(tmp_file)
|
451 |
+
with open(tmp_file) as f:
|
452 |
+
content = f.read()
|
453 |
+
with PathManager.open(out_file, "w") as f:
|
454 |
+
f.write(content)
|
455 |
+
|
456 |
+
|
457 |
+
class AutogradProfiler(TorchProfiler):
|
458 |
+
"""
|
459 |
+
A hook which runs `torch.autograd.profiler.profile`.
|
460 |
+
|
461 |
+
Examples:
|
462 |
+
::
|
463 |
+
hooks.AutogradProfiler(
|
464 |
+
lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
|
465 |
+
)
|
466 |
+
|
467 |
+
The above example will run the profiler for iteration 10~20 and dump
|
468 |
+
results to ``OUTPUT_DIR``. We did not profile the first few iterations
|
469 |
+
because they are typically slower than the rest.
|
470 |
+
The result files can be loaded in the ``chrome://tracing`` page in chrome browser.
|
471 |
+
|
472 |
+
Note:
|
473 |
+
When used together with NCCL on older version of GPUs,
|
474 |
+
autograd profiler may cause deadlock because it unnecessarily allocates
|
475 |
+
memory on every device it sees. The memory management calls, if
|
476 |
+
interleaved with NCCL calls, lead to deadlock on GPUs that do not
|
477 |
+
support ``cudaLaunchCooperativeKernelMultiDevice``.
|
478 |
+
"""
|
479 |
+
|
480 |
+
def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
|
481 |
+
"""
|
482 |
+
Args:
|
483 |
+
enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
|
484 |
+
and returns whether to enable the profiler.
|
485 |
+
It will be called once every step, and can be used to select which steps to profile.
|
486 |
+
output_dir (str): the output directory to dump tracing files.
|
487 |
+
use_cuda (bool): same as in `torch.autograd.profiler.profile`.
|
488 |
+
"""
|
489 |
+
warnings.warn("AutogradProfiler has been deprecated in favor of TorchProfiler.")
|
490 |
+
self._enable_predicate = enable_predicate
|
491 |
+
self._use_cuda = use_cuda
|
492 |
+
self._output_dir = output_dir
|
493 |
+
|
494 |
+
def before_step(self):
|
495 |
+
if self._enable_predicate(self.trainer):
|
496 |
+
self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
|
497 |
+
self._profiler.__enter__()
|
498 |
+
else:
|
499 |
+
self._profiler = None
|
500 |
+
|
501 |
+
|
502 |
+
class EvalHook(HookBase):
|
503 |
+
"""
|
504 |
+
Run an evaluation function periodically, and at the end of training.
|
505 |
+
|
506 |
+
It is executed every ``eval_period`` iterations and after the last iteration.
|
507 |
+
"""
|
508 |
+
|
509 |
+
def __init__(self, eval_period, eval_function, eval_after_train=True):
|
510 |
+
"""
|
511 |
+
Args:
|
512 |
+
eval_period (int): the period to run `eval_function`. Set to 0 to
|
513 |
+
not evaluate periodically (but still evaluate after the last iteration
|
514 |
+
if `eval_after_train` is True).
|
515 |
+
eval_function (callable): a function which takes no arguments, and
|
516 |
+
returns a nested dict of evaluation metrics.
|
517 |
+
eval_after_train (bool): whether to evaluate after the last iteration
|
518 |
+
|
519 |
+
Note:
|
520 |
+
This hook must be enabled in all or none workers.
|
521 |
+
If you would like only certain workers to perform evaluation,
|
522 |
+
give other workers a no-op function (`eval_function=lambda: None`).
|
523 |
+
"""
|
524 |
+
self._period = eval_period
|
525 |
+
self._func = eval_function
|
526 |
+
self._eval_after_train = eval_after_train
|
527 |
+
|
528 |
+
def _do_eval(self):
|
529 |
+
results = self._func()
|
530 |
+
|
531 |
+
if results:
|
532 |
+
assert isinstance(
|
533 |
+
results, dict
|
534 |
+
), "Eval function must return a dict. Got {} instead.".format(results)
|
535 |
+
|
536 |
+
flattened_results = flatten_results_dict(results)
|
537 |
+
for k, v in flattened_results.items():
|
538 |
+
try:
|
539 |
+
v = float(v)
|
540 |
+
except Exception as e:
|
541 |
+
raise ValueError(
|
542 |
+
"[EvalHook] eval_function should return a nested dict of float. "
|
543 |
+
"Got '{}: {}' instead.".format(k, v)
|
544 |
+
) from e
|
545 |
+
self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
|
546 |
+
|
547 |
+
# Evaluation may take different time among workers.
|
548 |
+
# A barrier make them start the next iteration together.
|
549 |
+
comm.synchronize()
|
550 |
+
|
551 |
+
def after_step(self):
|
552 |
+
next_iter = self.trainer.iter + 1
|
553 |
+
if self._period > 0 and next_iter % self._period == 0:
|
554 |
+
# do the last eval in after_train
|
555 |
+
if next_iter != self.trainer.max_iter:
|
556 |
+
self._do_eval()
|
557 |
+
|
558 |
+
def after_train(self):
|
559 |
+
# This condition is to prevent the eval from running after a failed training
|
560 |
+
if self._eval_after_train and self.trainer.iter + 1 >= self.trainer.max_iter:
|
561 |
+
self._do_eval()
|
562 |
+
# func is likely a closure that holds reference to the trainer
|
563 |
+
# therefore we clean it to avoid circular reference in the end
|
564 |
+
del self._func
|
565 |
+
|
566 |
+
|
567 |
+
class PreciseBN(HookBase):
|
568 |
+
"""
|
569 |
+
The standard implementation of BatchNorm uses EMA in inference, which is
|
570 |
+
sometimes suboptimal.
|
571 |
+
This class computes the true average of statistics rather than the moving average,
|
572 |
+
and put true averages to every BN layer in the given model.
|
573 |
+
|
574 |
+
It is executed every ``period`` iterations and after the last iteration.
|
575 |
+
"""
|
576 |
+
|
577 |
+
def __init__(self, period, model, data_loader, num_iter):
|
578 |
+
"""
|
579 |
+
Args:
|
580 |
+
period (int): the period this hook is run, or 0 to not run during training.
|
581 |
+
The hook will always run in the end of training.
|
582 |
+
model (nn.Module): a module whose all BN layers in training mode will be
|
583 |
+
updated by precise BN.
|
584 |
+
Note that user is responsible for ensuring the BN layers to be
|
585 |
+
updated are in training mode when this hook is triggered.
|
586 |
+
data_loader (iterable): it will produce data to be run by `model(data)`.
|
587 |
+
num_iter (int): number of iterations used to compute the precise
|
588 |
+
statistics.
|
589 |
+
"""
|
590 |
+
self._logger = logging.getLogger(__name__)
|
591 |
+
if len(get_bn_modules(model)) == 0:
|
592 |
+
self._logger.info(
|
593 |
+
"PreciseBN is disabled because model does not contain BN layers in training mode."
|
594 |
+
)
|
595 |
+
self._disabled = True
|
596 |
+
return
|
597 |
+
|
598 |
+
self._model = model
|
599 |
+
self._data_loader = data_loader
|
600 |
+
self._num_iter = num_iter
|
601 |
+
self._period = period
|
602 |
+
self._disabled = False
|
603 |
+
|
604 |
+
self._data_iter = None
|
605 |
+
|
606 |
+
def after_step(self):
|
607 |
+
next_iter = self.trainer.iter + 1
|
608 |
+
is_final = next_iter == self.trainer.max_iter
|
609 |
+
if is_final or (self._period > 0 and next_iter % self._period == 0):
|
610 |
+
self.update_stats()
|
611 |
+
|
612 |
+
def update_stats(self):
|
613 |
+
"""
|
614 |
+
Update the model with precise statistics. Users can manually call this method.
|
615 |
+
"""
|
616 |
+
if self._disabled:
|
617 |
+
return
|
618 |
+
|
619 |
+
if self._data_iter is None:
|
620 |
+
self._data_iter = iter(self._data_loader)
|
621 |
+
|
622 |
+
def data_loader():
|
623 |
+
for num_iter in itertools.count(1):
|
624 |
+
if num_iter % 100 == 0:
|
625 |
+
self._logger.info(
|
626 |
+
"Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
|
627 |
+
)
|
628 |
+
# This way we can reuse the same iterator
|
629 |
+
yield next(self._data_iter)
|
630 |
+
|
631 |
+
with EventStorage(): # capture events in a new storage to discard them
|
632 |
+
self._logger.info(
|
633 |
+
"Running precise-BN for {} iterations... ".format(self._num_iter)
|
634 |
+
+ "Note that this could produce different statistics every time."
|
635 |
+
)
|
636 |
+
update_bn_stats(self._model, data_loader(), self._num_iter)
|
637 |
+
|
638 |
+
|
639 |
+
class TorchMemoryStats(HookBase):
|
640 |
+
"""
|
641 |
+
Writes pytorch's cuda memory statistics periodically.
|
642 |
+
"""
|
643 |
+
|
644 |
+
def __init__(self, period=20, max_runs=10):
|
645 |
+
"""
|
646 |
+
Args:
|
647 |
+
period (int): Output stats each 'period' iterations
|
648 |
+
max_runs (int): Stop the logging after 'max_runs'
|
649 |
+
"""
|
650 |
+
|
651 |
+
self._logger = logging.getLogger(__name__)
|
652 |
+
self._period = period
|
653 |
+
self._max_runs = max_runs
|
654 |
+
self._runs = 0
|
655 |
+
|
656 |
+
def after_step(self):
|
657 |
+
if self._runs > self._max_runs:
|
658 |
+
return
|
659 |
+
|
660 |
+
if (self.trainer.iter + 1) % self._period == 0 or (
|
661 |
+
self.trainer.iter == self.trainer.max_iter - 1
|
662 |
+
):
|
663 |
+
if torch.cuda.is_available():
|
664 |
+
max_reserved_mb = torch.cuda.max_memory_reserved() / 1024.0 / 1024.0
|
665 |
+
reserved_mb = torch.cuda.memory_reserved() / 1024.0 / 1024.0
|
666 |
+
max_allocated_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
|
667 |
+
allocated_mb = torch.cuda.memory_allocated() / 1024.0 / 1024.0
|
668 |
+
|
669 |
+
self._logger.info(
|
670 |
+
(
|
671 |
+
" iter: {} "
|
672 |
+
" max_reserved_mem: {:.0f}MB "
|
673 |
+
" reserved_mem: {:.0f}MB "
|
674 |
+
" max_allocated_mem: {:.0f}MB "
|
675 |
+
" allocated_mem: {:.0f}MB "
|
676 |
+
).format(
|
677 |
+
self.trainer.iter,
|
678 |
+
max_reserved_mb,
|
679 |
+
reserved_mb,
|
680 |
+
max_allocated_mb,
|
681 |
+
allocated_mb,
|
682 |
+
)
|
683 |
+
)
|
684 |
+
|
685 |
+
self._runs += 1
|
686 |
+
if self._runs == self._max_runs:
|
687 |
+
mem_summary = torch.cuda.memory_summary()
|
688 |
+
self._logger.info("\n" + mem_summary)
|
689 |
+
|
690 |
+
torch.cuda.reset_peak_memory_stats()
|
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/engine/launch.py
ADDED
@@ -0,0 +1,123 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from datetime import timedelta
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from annotator.oneformer.detectron2.utils import comm

__all__ = ["DEFAULT_TIMEOUT", "launch"]

DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def launch(
    main_func,
    # Should be num_processes_per_machine, but kept for compatibility.
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    args=(),
    timeout=DEFAULT_TIMEOUT,
):
    """
    Launch multi-process or distributed training.
    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.

    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of processes per machine. When
            using GPUs, this should be the number of GPUs.
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost
        timeout (timedelta): timeout of the distributed workers
        args (tuple): arguments passed to main_func
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        # https://github.com/pytorch/pytorch/pull/14391
        # TODO prctl in spawned processes

        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        if num_machines > 1 and dist_url.startswith("file://"):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        mp.start_processes(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                args,
                timeout,
            ),
            daemon=False,
        )
    else:
        main_func(*args)


def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    args,
    timeout=DEFAULT_TIMEOUT,
):
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        assert num_gpus_per_machine <= torch.cuda.device_count()
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL" if has_gpu else "GLOO",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise e

    # Setup the local process group.
    comm.create_local_process_group(num_gpus_per_machine)
    if has_gpu:
        torch.cuda.set_device(local_rank)

    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    main_func(*args)
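
A minimal sketch of calling `launch` for single-machine multi-GPU training follows; the `main` entry point and the GPU count are placeholders, not part of the diff.

# Hypothetical usage sketch for launch(); `main` is a placeholder entry point.
def main():
    ...  # build the model / trainer here and call trainer.train()

if __name__ == "__main__":
    launch(
        main,
        num_gpus_per_machine=2,  # assumed GPU count
        num_machines=1,
        machine_rank=0,
        dist_url="auto",         # picks a free local port
        args=(),
    )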
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/engine/train_loop.py
ADDED
@@ -0,0 +1,469 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import logging
import numpy as np
import time
import weakref
from typing import List, Mapping, Optional
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel

import annotator.oneformer.detectron2.utils.comm as comm
from annotator.oneformer.detectron2.utils.events import EventStorage, get_event_storage
from annotator.oneformer.detectron2.utils.logger import _log_api_usage

__all__ = ["HookBase", "TrainerBase", "SimpleTrainer", "AMPTrainer"]


class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    Each hook can implement 4 methods. The way they are called is demonstrated
    in the following snippet:
    ::
        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()
            hook.after_step()
        iter += 1
        hook.after_train()

    Notes:
        1. In the hook method, users can access ``self.trainer`` to access more
           properties about the context (e.g., model, current iteration, or config
           if using :class:`DefaultTrainer`).

        2. A hook that does something in :meth:`before_step` can often be
           implemented equivalently in :meth:`after_step`.
           If the hook takes non-trivial time, it is strongly recommended to
           implement the hook in :meth:`after_step` instead of :meth:`before_step`.
           The convention is that :meth:`before_step` should only take negligible time.

           Following this convention will allow hooks that do care about the difference
           between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
           function properly.

    """

    trainer: "TrainerBase" = None
    """
    A weak reference to the trainer object. Set by the trainer when the hook is registered.
    """

    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_backward(self):
        """
        Called after the backward pass of each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass

    def state_dict(self):
        """
        Hooks are stateless by default, but can be made checkpointable by
        implementing `state_dict` and `load_state_dict`.
        """
        return {}


class TrainerBase:
    """
    Base class for iterative trainer with hooks.

    The only assumption we made here is: the training runs in a loop.
    A subclass can implement what the loop is.
    We made no assumptions about the existence of dataloader, optimizer, model, etc.

    Attributes:
        iter(int): the current iteration.

        start_iter(int): The iteration to start with.
            By convention the minimum possible value is 0.

        max_iter(int): The iteration to end training.

        storage(EventStorage): An EventStorage that's opened during the course of training.
    """

    def __init__(self) -> None:
        self._hooks: List[HookBase] = []
        self.iter: int = 0
        self.start_iter: int = 0
        self.max_iter: int
        self.storage: EventStorage
        _log_api_usage("trainer." + self.__class__.__name__)

    def register_hooks(self, hooks: List[Optional[HookBase]]) -> None:
        """
        Register hooks to the trainer. The hooks are executed in the order
        they are registered.

        Args:
            hooks (list[Optional[HookBase]]): list of hooks
        """
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            # To avoid circular reference, hooks and trainer cannot own each other.
            # This normally does not matter, but will cause memory leak if the
            # involved objects contain __del__:
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)

    def train(self, start_iter: int, max_iter: int):
        """
        Args:
            start_iter, max_iter (int): See docs above
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
                # self.iter == max_iter can be used by `after_train` to
                # tell whether the training successfully finished or failed
                # due to exceptions.
                self.iter += 1
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        self.storage.iter = self.iter
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        # Maintain the invariant that storage.iter == trainer.iter
        # for the entire execution of each step
        self.storage.iter = self.iter

        for h in self._hooks:
            h.before_step()

    def after_backward(self):
        for h in self._hooks:
            h.after_backward()

    def after_step(self):
        for h in self._hooks:
            h.after_step()

    def run_step(self):
        raise NotImplementedError

    def state_dict(self):
        ret = {"iteration": self.iter}
        hooks_state = {}
        for h in self._hooks:
            sd = h.state_dict()
            if sd:
                name = type(h).__qualname__
                if name in hooks_state:
                    # TODO handle repetitive stateful hooks
                    continue
                hooks_state[name] = sd
        if hooks_state:
            ret["hooks"] = hooks_state
        return ret

    def load_state_dict(self, state_dict):
        logger = logging.getLogger(__name__)
        self.iter = state_dict["iteration"]
        for key, value in state_dict.get("hooks", {}).items():
            for h in self._hooks:
                try:
                    name = type(h).__qualname__
                except AttributeError:
                    continue
                if name == key:
                    h.load_state_dict(value)
                    break
            else:
                logger.warning(f"Cannot find the hook '{key}', its state_dict is ignored.")


class SimpleTrainer(TrainerBase):
    """
    A simple trainer for the most common type of task:
    single-cost single-optimizer single-data-source iterative optimization,
    optionally using data-parallelism.
    It assumes that every step, you:

    1. Compute the loss with a data from the data_loader.
    2. Compute the gradients with the above loss.
    3. Update the model with the optimizer.

    All other tasks during training (checkpointing, logging, evaluation, LR schedule)
    are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`.

    If you want to do anything fancier than this,
    either subclass TrainerBase and implement your own `run_step`,
    or write your own training loop.
    """

    def __init__(self, model, data_loader, optimizer, gather_metric_period=1):
        """
        Args:
            model: a torch Module. Takes a data from data_loader and returns a
                dict of losses.
            data_loader: an iterable. Contains data to be used to call model.
            optimizer: a torch optimizer.
            gather_metric_period: an int. Every gather_metric_period iterations
                the metrics are gathered from all the ranks to rank 0 and logged.
        """
        super().__init__()

        """
        We set the model to training mode in the trainer.
        However it's valid to train a model that's in eval mode.
        If you want your model (or a submodule of it) to behave
        like evaluation during training, you can overwrite its train() method.
        """
        model.train()

        self.model = model
        self.data_loader = data_loader
        # to access the data loader iterator, call `self._data_loader_iter`
        self._data_loader_iter_obj = None
        self.optimizer = optimizer
        self.gather_metric_period = gather_metric_period

    def run_step(self):
        """
        Implement the standard training logic described above.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        """
        If you want to do something with the data, you can wrap the dataloader.
        """
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        """
        If you want to do something with the losses, you can wrap the model.
        """
        loss_dict = self.model(data)
        if isinstance(loss_dict, torch.Tensor):
            losses = loss_dict
            loss_dict = {"total_loss": loss_dict}
        else:
            losses = sum(loss_dict.values())

        """
        If you need to accumulate gradients or do something similar, you can
        wrap the optimizer with your custom `zero_grad()` method.
        """
        self.optimizer.zero_grad()
        losses.backward()

        self.after_backward()

        self._write_metrics(loss_dict, data_time)

        """
        If you need gradient clipping/scaling or other processing, you can
        wrap the optimizer with your custom `step()` method. But it is
        suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4
        """
        self.optimizer.step()

    @property
    def _data_loader_iter(self):
        # only create the data loader iterator when it is used
        if self._data_loader_iter_obj is None:
            self._data_loader_iter_obj = iter(self.data_loader)
        return self._data_loader_iter_obj

    def reset_data_loader(self, data_loader_builder):
        """
        Delete and replace the current data loader with a new one, which will be created
        by calling `data_loader_builder` (without argument).
        """
        del self.data_loader
        data_loader = data_loader_builder()
        self.data_loader = data_loader
        self._data_loader_iter_obj = None

    def _write_metrics(
        self,
        loss_dict: Mapping[str, torch.Tensor],
        data_time: float,
        prefix: str = "",
    ) -> None:
        if (self.iter + 1) % self.gather_metric_period == 0:
            SimpleTrainer.write_metrics(loss_dict, data_time, prefix)

    @staticmethod
    def write_metrics(
        loss_dict: Mapping[str, torch.Tensor],
        data_time: float,
        prefix: str = "",
    ) -> None:
        """
        Args:
            loss_dict (dict): dict of scalar losses
            data_time (float): time taken by the dataloader iteration
            prefix (str): prefix for logging keys
        """
        metrics_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()}
        metrics_dict["data_time"] = data_time

        # Gather metrics among all workers for logging
        # This assumes we do DDP-style training, which is currently the only
        # supported method in detectron2.
        all_metrics_dict = comm.gather(metrics_dict)

        if comm.is_main_process():
            storage = get_event_storage()

            # data_time among workers can have high variance. The actual latency
            # caused by data_time is the maximum among workers.
            data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
            storage.put_scalar("data_time", data_time)

            # average the rest metrics
            metrics_dict = {
                k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
            }
            total_losses_reduced = sum(metrics_dict.values())
            if not np.isfinite(total_losses_reduced):
                raise FloatingPointError(
                    f"Loss became infinite or NaN at iteration={storage.iter}!\n"
                    f"loss_dict = {metrics_dict}"
                )

            storage.put_scalar("{}total_loss".format(prefix), total_losses_reduced)
            if len(metrics_dict) > 1:
                storage.put_scalars(**metrics_dict)

    def state_dict(self):
        ret = super().state_dict()
        ret["optimizer"] = self.optimizer.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.optimizer.load_state_dict(state_dict["optimizer"])


class AMPTrainer(SimpleTrainer):
    """
    Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision
    in the training loop.
    """

    def __init__(
        self,
        model,
        data_loader,
        optimizer,
        gather_metric_period=1,
        grad_scaler=None,
        precision: torch.dtype = torch.float16,
        log_grad_scaler: bool = False,
    ):
        """
        Args:
            model, data_loader, optimizer, gather_metric_period: same as in :class:`SimpleTrainer`.
            grad_scaler: torch GradScaler to automatically scale gradients.
            precision: torch.dtype as the target precision to cast to in computations
        """
        unsupported = "AMPTrainer does not support single-process multi-device training!"
        if isinstance(model, DistributedDataParallel):
            assert not (model.device_ids and len(model.device_ids) > 1), unsupported
        assert not isinstance(model, DataParallel), unsupported

        super().__init__(model, data_loader, optimizer, gather_metric_period)

        if grad_scaler is None:
            from torch.cuda.amp import GradScaler

            grad_scaler = GradScaler()
        self.grad_scaler = grad_scaler
        self.precision = precision
        self.log_grad_scaler = log_grad_scaler

    def run_step(self):
        """
        Implement the AMP training logic.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        start = time.perf_counter()
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        with autocast(dtype=self.precision):
            loss_dict = self.model(data)
            if isinstance(loss_dict, torch.Tensor):
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        self.grad_scaler.scale(losses).backward()

        if self.log_grad_scaler:
            storage = get_event_storage()
            storage.put_scalar("[metric]grad_scaler", self.grad_scaler.get_scale())

        self.after_backward()

        self._write_metrics(loss_dict, data_time)

        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()

    def state_dict(self):
        ret = super().state_dict()
        ret["grad_scaler"] = self.grad_scaler.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
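
As a rough sketch of how this trainer is driven (assumptions: a model that returns a loss dict, an iterable dataloader, and a plain SGD optimizer; `build_model` and `data_loader` are placeholders):

# Hypothetical usage sketch for SimpleTrainer; model/data_loader are assumed to exist.
model = build_model(cfg)                    # any nn.Module returning a dict of losses
optimizer = torch.optim.SGD(model.parameters(), lr=0.02)
trainer = SimpleTrainer(model, data_loader, optimizer)
trainer.register_hooks([...])               # e.g. the hooks defined in hooks.py above
trainer.train(start_iter=0, max_iter=90000)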
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/__init__.py
ADDED
@@ -0,0 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .cityscapes_evaluation import CityscapesInstanceEvaluator, CityscapesSemSegEvaluator
from .coco_evaluation import COCOEvaluator
from .rotated_coco_evaluation import RotatedCOCOEvaluator
from .evaluator import DatasetEvaluator, DatasetEvaluators, inference_context, inference_on_dataset
from .lvis_evaluation import LVISEvaluator
from .panoptic_evaluation import COCOPanopticEvaluator
from .pascal_voc_evaluation import PascalVOCDetectionEvaluator
from .sem_seg_evaluation import SemSegEvaluator
from .testing import print_csv_format, verify_results

__all__ = [k for k in globals().keys() if not k.startswith("_")]
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/cityscapes_evaluation.py
ADDED
@@ -0,0 +1,197 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import glob
import logging
import numpy as np
import os
import tempfile
from collections import OrderedDict
import torch
from PIL import Image

from annotator.oneformer.detectron2.data import MetadataCatalog
from annotator.oneformer.detectron2.utils import comm
from annotator.oneformer.detectron2.utils.file_io import PathManager

from .evaluator import DatasetEvaluator


class CityscapesEvaluator(DatasetEvaluator):
    """
    Base class for evaluation using cityscapes API.
    """

    def __init__(self, dataset_name):
        """
        Args:
            dataset_name (str): the name of the dataset.
                It must have the following metadata associated with it:
                "thing_classes", "gt_dir".
        """
        self._metadata = MetadataCatalog.get(dataset_name)
        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)

    def reset(self):
        self._working_dir = tempfile.TemporaryDirectory(prefix="cityscapes_eval_")
        self._temp_dir = self._working_dir.name
        # All workers will write to the same results directory
        # TODO this does not work in distributed training
        assert (
            comm.get_local_size() == comm.get_world_size()
        ), "CityscapesEvaluator currently do not work with multiple machines."
        self._temp_dir = comm.all_gather(self._temp_dir)[0]
        if self._temp_dir != self._working_dir.name:
            self._working_dir.cleanup()
        self._logger.info(
            "Writing cityscapes results to temporary directory {} ...".format(self._temp_dir)
        )


class CityscapesInstanceEvaluator(CityscapesEvaluator):
    """
    Evaluate instance segmentation results on cityscapes dataset using cityscapes API.

    Note:
        * It does not work in multi-machine distributed training.
        * It contains a synchronization, therefore has to be used on all ranks.
        * Only the main process runs evaluation.
    """

    def process(self, inputs, outputs):
        from cityscapesscripts.helpers.labels import name2label

        for input, output in zip(inputs, outputs):
            file_name = input["file_name"]
            basename = os.path.splitext(os.path.basename(file_name))[0]
            pred_txt = os.path.join(self._temp_dir, basename + "_pred.txt")

            if "instances" in output:
                output = output["instances"].to(self._cpu_device)
                num_instances = len(output)
                with open(pred_txt, "w") as fout:
                    for i in range(num_instances):
                        pred_class = output.pred_classes[i]
                        classes = self._metadata.thing_classes[pred_class]
                        class_id = name2label[classes].id
                        score = output.scores[i]
                        mask = output.pred_masks[i].numpy().astype("uint8")
                        png_filename = os.path.join(
                            self._temp_dir, basename + "_{}_{}.png".format(i, classes)
                        )

                        Image.fromarray(mask * 255).save(png_filename)
                        fout.write(
                            "{} {} {}\n".format(os.path.basename(png_filename), class_id, score)
                        )
            else:
                # Cityscapes requires a prediction file for every ground truth image.
                with open(pred_txt, "w") as fout:
                    pass

    def evaluate(self):
        """
        Returns:
            dict: has a key "segm", whose value is a dict of "AP" and "AP50".
        """
        comm.synchronize()
        if comm.get_rank() > 0:
            return
        import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as cityscapes_eval

        self._logger.info("Evaluating results under {} ...".format(self._temp_dir))

        # set some global states in cityscapes evaluation API, before evaluating
        cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir)
        cityscapes_eval.args.predictionWalk = None
        cityscapes_eval.args.JSONOutput = False
        cityscapes_eval.args.colorized = False
        cityscapes_eval.args.gtInstancesFile = os.path.join(self._temp_dir, "gtInstances.json")

        # These lines are adopted from
        # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
        gt_dir = PathManager.get_local_path(self._metadata.gt_dir)
        groundTruthImgList = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_instanceIds.png"))
        assert len(
            groundTruthImgList
        ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
            cityscapes_eval.args.groundTruthSearch
        )
        predictionImgList = []
        for gt in groundTruthImgList:
            predictionImgList.append(cityscapes_eval.getPrediction(gt, cityscapes_eval.args))
        results = cityscapes_eval.evaluateImgLists(
            predictionImgList, groundTruthImgList, cityscapes_eval.args
        )["averages"]

        ret = OrderedDict()
        ret["segm"] = {"AP": results["allAp"] * 100, "AP50": results["allAp50%"] * 100}
        self._working_dir.cleanup()
        return ret


class CityscapesSemSegEvaluator(CityscapesEvaluator):
    """
    Evaluate semantic segmentation results on cityscapes dataset using cityscapes API.

    Note:
        * It does not work in multi-machine distributed training.
        * It contains a synchronization, therefore has to be used on all ranks.
        * Only the main process runs evaluation.
    """

    def process(self, inputs, outputs):
        from cityscapesscripts.helpers.labels import trainId2label

        for input, output in zip(inputs, outputs):
            file_name = input["file_name"]
            basename = os.path.splitext(os.path.basename(file_name))[0]
            pred_filename = os.path.join(self._temp_dir, basename + "_pred.png")

            output = output["sem_seg"].argmax(dim=0).to(self._cpu_device).numpy()
            pred = 255 * np.ones(output.shape, dtype=np.uint8)
            for train_id, label in trainId2label.items():
                if label.ignoreInEval:
                    continue
                pred[output == train_id] = label.id
            Image.fromarray(pred).save(pred_filename)

    def evaluate(self):
        comm.synchronize()
        if comm.get_rank() > 0:
            return
        # Load the Cityscapes eval script *after* setting the required env var,
        # since the script reads CITYSCAPES_DATASET into global variables at load time.
        import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as cityscapes_eval

        self._logger.info("Evaluating results under {} ...".format(self._temp_dir))

        # set some global states in cityscapes evaluation API, before evaluating
        cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir)
        cityscapes_eval.args.predictionWalk = None
        cityscapes_eval.args.JSONOutput = False
        cityscapes_eval.args.colorized = False

        # These lines are adopted from
        # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalPixelLevelSemanticLabeling.py # noqa
        gt_dir = PathManager.get_local_path(self._metadata.gt_dir)
        groundTruthImgList = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_labelIds.png"))
        assert len(
            groundTruthImgList
        ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
            cityscapes_eval.args.groundTruthSearch
        )
        predictionImgList = []
        for gt in groundTruthImgList:
            predictionImgList.append(cityscapes_eval.getPrediction(cityscapes_eval.args, gt))
        results = cityscapes_eval.evaluateImgLists(
            predictionImgList, groundTruthImgList, cityscapes_eval.args
        )
        ret = OrderedDict()
        ret["sem_seg"] = {
            "IoU": 100.0 * results["averageScoreClasses"],
            "iIoU": 100.0 * results["averageScoreInstClasses"],
            "IoU_sup": 100.0 * results["averageScoreCategories"],
            "iIoU_sup": 100.0 * results["averageScoreInstCategories"],
        }
        self._working_dir.cleanup()
        return ret
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/coco_evaluation.py
ADDED
@@ -0,0 +1,722 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import contextlib
|
3 |
+
import copy
|
4 |
+
import io
|
5 |
+
import itertools
|
6 |
+
import json
|
7 |
+
import logging
|
8 |
+
import numpy as np
|
9 |
+
import os
|
10 |
+
import pickle
|
11 |
+
from collections import OrderedDict
|
12 |
+
import annotator.oneformer.pycocotools.mask as mask_util
|
13 |
+
import torch
|
14 |
+
from annotator.oneformer.pycocotools.coco import COCO
|
15 |
+
from annotator.oneformer.pycocotools.cocoeval import COCOeval
|
16 |
+
from tabulate import tabulate
|
17 |
+
|
18 |
+
import annotator.oneformer.detectron2.utils.comm as comm
|
19 |
+
from annotator.oneformer.detectron2.config import CfgNode
|
20 |
+
from annotator.oneformer.detectron2.data import MetadataCatalog
|
21 |
+
from annotator.oneformer.detectron2.data.datasets.coco import convert_to_coco_json
|
22 |
+
from annotator.oneformer.detectron2.structures import Boxes, BoxMode, pairwise_iou
|
23 |
+
from annotator.oneformer.detectron2.utils.file_io import PathManager
|
24 |
+
from annotator.oneformer.detectron2.utils.logger import create_small_table
|
25 |
+
|
26 |
+
from .evaluator import DatasetEvaluator
|
27 |
+
|
28 |
+
try:
|
29 |
+
from annotator.oneformer.detectron2.evaluation.fast_eval_api import COCOeval_opt
|
30 |
+
except ImportError:
|
31 |
+
COCOeval_opt = COCOeval
|
32 |
+
|
33 |
+
|
34 |
+
class COCOEvaluator(DatasetEvaluator):
|
35 |
+
"""
|
36 |
+
Evaluate AR for object proposals, AP for instance detection/segmentation, AP
|
37 |
+
for keypoint detection outputs using COCO's metrics.
|
38 |
+
See http://cocodataset.org/#detection-eval and
|
39 |
+
http://cocodataset.org/#keypoints-eval to understand its metrics.
|
40 |
+
The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
|
41 |
+
the metric cannot be computed (e.g. due to no predictions made).
|
42 |
+
|
43 |
+
In addition to COCO, this evaluator is able to support any bounding box detection,
|
44 |
+
instance segmentation, or keypoint detection dataset.
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
dataset_name,
|
50 |
+
tasks=None,
|
51 |
+
distributed=True,
|
52 |
+
output_dir=None,
|
53 |
+
*,
|
54 |
+
max_dets_per_image=None,
|
55 |
+
use_fast_impl=True,
|
56 |
+
kpt_oks_sigmas=(),
|
57 |
+
allow_cached_coco=True,
|
58 |
+
):
|
59 |
+
"""
|
60 |
+
Args:
|
61 |
+
dataset_name (str): name of the dataset to be evaluated.
|
62 |
+
It must have either the following corresponding metadata:
|
63 |
+
|
64 |
+
"json_file": the path to the COCO format annotation
|
65 |
+
|
66 |
+
Or it must be in detectron2's standard dataset format
|
67 |
+
so it can be converted to COCO format automatically.
|
68 |
+
tasks (tuple[str]): tasks that can be evaluated under the given
|
69 |
+
configuration. A task is one of "bbox", "segm", "keypoints".
|
70 |
+
By default, will infer this automatically from predictions.
|
71 |
+
distributed (True): if True, will collect results from all ranks and run evaluation
|
72 |
+
in the main process.
|
73 |
+
Otherwise, will only evaluate the results in the current process.
|
74 |
+
output_dir (str): optional, an output directory to dump all
|
75 |
+
results predicted on the dataset. The dump contains two files:
|
76 |
+
|
77 |
+
1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
|
78 |
+
contains all the results in the format they are produced by the model.
|
79 |
+
2. "coco_instances_results.json" a json file in COCO's result format.
|
80 |
+
max_dets_per_image (int): limit on the maximum number of detections per image.
|
81 |
+
By default in COCO, this limit is to 100, but this can be customized
|
82 |
+
to be greater, as is needed in evaluation metrics AP fixed and AP pool
|
83 |
+
(see https://arxiv.org/pdf/2102.01066.pdf)
|
84 |
+
This doesn't affect keypoint evaluation.
|
85 |
+
use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
|
86 |
+
Although the results should be very close to the official implementation in COCO
|
87 |
+
API, it is still recommended to compute results with the official API for use in
|
88 |
+
papers. The faster implementation also uses more RAM.
|
89 |
+
kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
|
90 |
+
See http://cocodataset.org/#keypoints-eval
|
91 |
+
When empty, it will use the defaults in COCO.
|
92 |
+
Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
|
93 |
+
allow_cached_coco (bool): Whether to use cached coco json from previous validation
|
94 |
+
runs. You should set this to False if you need to use different validation data.
|
95 |
+
Defaults to True.
|
96 |
+
"""
|
97 |
+
self._logger = logging.getLogger(__name__)
|
98 |
+
self._distributed = distributed
|
99 |
+
self._output_dir = output_dir
|
100 |
+
|
101 |
+
if use_fast_impl and (COCOeval_opt is COCOeval):
|
102 |
+
self._logger.info("Fast COCO eval is not built. Falling back to official COCO eval.")
|
103 |
+
use_fast_impl = False
|
104 |
+
self._use_fast_impl = use_fast_impl
|
105 |
+
|
106 |
+
# COCOeval requires the limit on the number of detections per image (maxDets) to be a list
|
107 |
+
# with at least 3 elements. The default maxDets in COCOeval is [1, 10, 100], in which the
|
108 |
+
# 3rd element (100) is used as the limit on the number of detections per image when
|
109 |
+
# evaluating AP. COCOEvaluator expects an integer for max_dets_per_image, so for COCOeval,
|
110 |
+
# we reformat max_dets_per_image into [1, 10, max_dets_per_image], based on the defaults.
|
111 |
+
if max_dets_per_image is None:
|
112 |
+
max_dets_per_image = [1, 10, 100]
|
113 |
+
else:
|
114 |
+
max_dets_per_image = [1, 10, max_dets_per_image]
|
115 |
+
self._max_dets_per_image = max_dets_per_image
|
116 |
+
|
117 |
+
if tasks is not None and isinstance(tasks, CfgNode):
|
118 |
+
kpt_oks_sigmas = (
|
119 |
+
tasks.TEST.KEYPOINT_OKS_SIGMAS if not kpt_oks_sigmas else kpt_oks_sigmas
|
120 |
+
)
|
121 |
+
self._logger.warn(
|
122 |
+
"COCO Evaluator instantiated using config, this is deprecated behavior."
|
123 |
+
" Please pass in explicit arguments instead."
|
124 |
+
)
|
125 |
+
self._tasks = None # Infering it from predictions should be better
|
126 |
+
else:
|
127 |
+
self._tasks = tasks
|
128 |
+
|
129 |
+
self._cpu_device = torch.device("cpu")
|
130 |
+
|
131 |
+
self._metadata = MetadataCatalog.get(dataset_name)
|
132 |
+
if not hasattr(self._metadata, "json_file"):
|
133 |
+
if output_dir is None:
|
134 |
+
raise ValueError(
|
135 |
+
"output_dir must be provided to COCOEvaluator "
|
136 |
+
"for datasets not in COCO format."
|
137 |
+
)
|
138 |
+
self._logger.info(f"Trying to convert '{dataset_name}' to COCO format ...")
|
139 |
+
|
140 |
+
cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json")
|
141 |
+
self._metadata.json_file = cache_path
|
142 |
+
convert_to_coco_json(dataset_name, cache_path, allow_cached=allow_cached_coco)
|
143 |
+
|
144 |
+
json_file = PathManager.get_local_path(self._metadata.json_file)
|
145 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
146 |
+
self._coco_api = COCO(json_file)
|
147 |
+
|
148 |
+
# Test set json files do not contain annotations (evaluation must be
|
149 |
+
# performed using the COCO evaluation server).
|
150 |
+
self._do_evaluation = "annotations" in self._coco_api.dataset
|
151 |
+
if self._do_evaluation:
|
152 |
+
self._kpt_oks_sigmas = kpt_oks_sigmas
|
153 |
+
|
154 |
+
def reset(self):
|
155 |
+
self._predictions = []
|
156 |
+
|
157 |
+
def process(self, inputs, outputs):
|
158 |
+
"""
|
159 |
+
Args:
|
160 |
+
inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
|
161 |
+
It is a list of dict. Each dict corresponds to an image and
|
162 |
+
contains keys like "height", "width", "file_name", "image_id".
|
163 |
+
outputs: the outputs of a COCO model. It is a list of dicts with key
|
164 |
+
"instances" that contains :class:`Instances`.
|
165 |
+
"""
|
166 |
+
for input, output in zip(inputs, outputs):
|
167 |
+
prediction = {"image_id": input["image_id"]}
|
168 |
+
|
169 |
+
if "instances" in output:
|
170 |
+
instances = output["instances"].to(self._cpu_device)
|
171 |
+
prediction["instances"] = instances_to_coco_json(instances, input["image_id"])
|
172 |
+
if "proposals" in output:
|
173 |
+
prediction["proposals"] = output["proposals"].to(self._cpu_device)
|
174 |
+
if len(prediction) > 1:
|
175 |
+
self._predictions.append(prediction)
|
176 |
+
|
177 |
+
def evaluate(self, img_ids=None):
|
178 |
+
"""
|
179 |
+
Args:
|
180 |
+
img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
|
181 |
+
"""
|
182 |
+
if self._distributed:
|
183 |
+
comm.synchronize()
|
184 |
+
predictions = comm.gather(self._predictions, dst=0)
|
185 |
+
predictions = list(itertools.chain(*predictions))
|
186 |
+
|
187 |
+
if not comm.is_main_process():
|
188 |
+
return {}
|
189 |
+
else:
|
190 |
+
predictions = self._predictions
|
191 |
+
|
192 |
+
if len(predictions) == 0:
|
193 |
+
self._logger.warning("[COCOEvaluator] Did not receive valid predictions.")
|
194 |
+
return {}
|
195 |
+
|
196 |
+
if self._output_dir:
|
197 |
+
PathManager.mkdirs(self._output_dir)
|
198 |
+
file_path = os.path.join(self._output_dir, "instances_predictions.pth")
|
199 |
+
with PathManager.open(file_path, "wb") as f:
|
200 |
+
torch.save(predictions, f)
|
201 |
+
|
202 |
+
self._results = OrderedDict()
|
203 |
+
if "proposals" in predictions[0]:
|
204 |
+
self._eval_box_proposals(predictions)
|
205 |
+
if "instances" in predictions[0]:
|
206 |
+
self._eval_predictions(predictions, img_ids=img_ids)
|
207 |
+
# Copy so the caller can do whatever with results
|
208 |
+
return copy.deepcopy(self._results)
|
209 |
+
|
210 |
+
def _tasks_from_predictions(self, predictions):
|
211 |
+
"""
|
212 |
+
Get COCO API "tasks" (i.e. iou_type) from COCO-format predictions.
|
213 |
+
"""
|
214 |
+
tasks = {"bbox"}
|
215 |
+
for pred in predictions:
|
216 |
+
if "segmentation" in pred:
|
217 |
+
tasks.add("segm")
|
218 |
+
if "keypoints" in pred:
|
219 |
+
tasks.add("keypoints")
|
220 |
+
return sorted(tasks)
|
221 |
+
|
222 |
+
def _eval_predictions(self, predictions, img_ids=None):
|
223 |
+
"""
|
224 |
+
Evaluate predictions. Fill self._results with the metrics of the tasks.
|
225 |
+
"""
|
226 |
+
self._logger.info("Preparing results for COCO format ...")
|
227 |
+
coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
|
228 |
+
tasks = self._tasks or self._tasks_from_predictions(coco_results)
|
229 |
+
|
230 |
+
# unmap the category ids for COCO
|
231 |
+
if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
|
232 |
+
dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
|
233 |
+
all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
|
234 |
+
num_classes = len(all_contiguous_ids)
|
235 |
+
assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
|
236 |
+
|
237 |
+
reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
|
238 |
+
for result in coco_results:
|
239 |
+
category_id = result["category_id"]
|
240 |
+
assert category_id < num_classes, (
|
241 |
+
f"A prediction has class={category_id}, "
|
242 |
+
f"but the dataset only has {num_classes} classes and "
|
243 |
+
f"predicted class id should be in [0, {num_classes - 1}]."
|
244 |
+
)
|
245 |
+
result["category_id"] = reverse_id_mapping[category_id]
|
246 |
+
|
247 |
+
if self._output_dir:
|
248 |
+
file_path = os.path.join(self._output_dir, "coco_instances_results.json")
|
249 |
+
self._logger.info("Saving results to {}".format(file_path))
|
250 |
+
with PathManager.open(file_path, "w") as f:
|
251 |
+
f.write(json.dumps(coco_results))
|
252 |
+
f.flush()
|
253 |
+
|
254 |
+
if not self._do_evaluation:
|
255 |
+
self._logger.info("Annotations are not available for evaluation.")
|
256 |
+
return
|
257 |
+
|
258 |
+
self._logger.info(
|
259 |
+
"Evaluating predictions with {} COCO API...".format(
|
260 |
+
"unofficial" if self._use_fast_impl else "official"
|
261 |
+
)
|
262 |
+
)
|
263 |
+
for task in sorted(tasks):
|
264 |
+
assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
|
265 |
+
coco_eval = (
|
266 |
+
_evaluate_predictions_on_coco(
|
267 |
+
self._coco_api,
|
268 |
+
coco_results,
|
269 |
+
task,
|
270 |
+
kpt_oks_sigmas=self._kpt_oks_sigmas,
|
271 |
+
cocoeval_fn=COCOeval_opt if self._use_fast_impl else COCOeval,
|
272 |
+
img_ids=img_ids,
|
273 |
+
max_dets_per_image=self._max_dets_per_image,
|
274 |
+
)
|
275 |
+
if len(coco_results) > 0
|
276 |
+
else None # cocoapi does not handle empty results very well
|
277 |
+
)
|
278 |
+
|
279 |
+
res = self._derive_coco_results(
|
280 |
+
coco_eval, task, class_names=self._metadata.get("thing_classes")
|
281 |
+
)
|
282 |
+
self._results[task] = res
|
283 |
+
|
284 |
+
def _eval_box_proposals(self, predictions):
|
285 |
+
"""
|
286 |
+
Evaluate the box proposals in predictions.
|
287 |
+
Fill self._results with the metrics for "box_proposals" task.
|
288 |
+
"""
|
289 |
+
if self._output_dir:
|
290 |
+
# Saving generated box proposals to file.
|
291 |
+
# Predicted box_proposals are in XYXY_ABS mode.
|
292 |
+
bbox_mode = BoxMode.XYXY_ABS.value
|
293 |
+
ids, boxes, objectness_logits = [], [], []
|
294 |
+
for prediction in predictions:
|
295 |
+
ids.append(prediction["image_id"])
|
296 |
+
boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy())
|
297 |
+
objectness_logits.append(prediction["proposals"].objectness_logits.numpy())
|
298 |
+
|
299 |
+
proposal_data = {
|
300 |
+
"boxes": boxes,
|
301 |
+
"objectness_logits": objectness_logits,
|
302 |
+
"ids": ids,
|
303 |
+
"bbox_mode": bbox_mode,
|
304 |
+
}
|
305 |
+
with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f:
|
306 |
+
pickle.dump(proposal_data, f)
|
307 |
+
|
308 |
+
if not self._do_evaluation:
|
309 |
+
self._logger.info("Annotations are not available for evaluation.")
|
310 |
+
return
|
311 |
+
|
312 |
+
self._logger.info("Evaluating bbox proposals ...")
|
313 |
+
res = {}
|
314 |
+
areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
|
315 |
+
for limit in [100, 1000]:
|
316 |
+
for area, suffix in areas.items():
|
317 |
+
stats = _evaluate_box_proposals(predictions, self._coco_api, area=area, limit=limit)
|
318 |
+
key = "AR{}@{:d}".format(suffix, limit)
|
319 |
+
res[key] = float(stats["ar"].item() * 100)
|
320 |
+
self._logger.info("Proposal metrics: \n" + create_small_table(res))
|
321 |
+
self._results["box_proposals"] = res
|
322 |
+
|
323 |
+
def _derive_coco_results(self, coco_eval, iou_type, class_names=None):
|
324 |
+
"""
|
325 |
+
Derive the desired score numbers from summarized COCOeval.
|
326 |
+
|
327 |
+
Args:
|
328 |
+
coco_eval (None or COCOEval): None represents no predictions from model.
|
329 |
+
iou_type (str):
|
330 |
+
class_names (None or list[str]): if provided, will use it to predict
|
331 |
+
per-category AP.
|
332 |
+
|
333 |
+
Returns:
|
334 |
+
a dict of {metric name: score}
|
335 |
+
"""
|
336 |
+
|
337 |
+
metrics = {
|
338 |
+
"bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
|
339 |
+
"segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
|
340 |
+
"keypoints": ["AP", "AP50", "AP75", "APm", "APl"],
|
341 |
+
}[iou_type]
|
342 |
+
|
343 |
+
if coco_eval is None:
|
344 |
+
self._logger.warn("No predictions from the model!")
|
345 |
+
return {metric: float("nan") for metric in metrics}
|
346 |
+
|
347 |
+
# the standard metrics
|
348 |
+
results = {
|
349 |
+
metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan")
|
350 |
+
for idx, metric in enumerate(metrics)
|
351 |
+
}
|
352 |
+
self._logger.info(
|
353 |
+
"Evaluation results for {}: \n".format(iou_type) + create_small_table(results)
|
354 |
+
)
|
355 |
+
if not np.isfinite(sum(results.values())):
|
356 |
+
self._logger.info("Some metrics cannot be computed and is shown as NaN.")
|
357 |
+
|
358 |
+
if class_names is None or len(class_names) <= 1:
|
359 |
+
return results
|
360 |
+
# Compute per-category AP
|
361 |
+
# from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa
|
362 |
+
precisions = coco_eval.eval["precision"]
|
363 |
+
# precision has dims (iou, recall, cls, area range, max dets)
|
364 |
+
assert len(class_names) == precisions.shape[2]
|
365 |
+
|
366 |
+
results_per_category = []
|
367 |
+
for idx, name in enumerate(class_names):
|
368 |
+
# area range index 0: all area ranges
|
369 |
+
# max dets index -1: typically 100 per image
|
370 |
+
precision = precisions[:, :, idx, 0, -1]
|
371 |
+
precision = precision[precision > -1]
|
372 |
+
ap = np.mean(precision) if precision.size else float("nan")
|
373 |
+
results_per_category.append(("{}".format(name), float(ap * 100)))
|
374 |
+
|
375 |
+
# tabulate it
|
376 |
+
N_COLS = min(6, len(results_per_category) * 2)
|
377 |
+
results_flatten = list(itertools.chain(*results_per_category))
|
378 |
+
results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
|
379 |
+
table = tabulate(
|
380 |
+
results_2d,
|
381 |
+
tablefmt="pipe",
|
382 |
+
floatfmt=".3f",
|
383 |
+
headers=["category", "AP"] * (N_COLS // 2),
|
384 |
+
numalign="left",
|
385 |
+
)
|
386 |
+
self._logger.info("Per-category {} AP: \n".format(iou_type) + table)
|
387 |
+
|
388 |
+
results.update({"AP-" + name: ap for name, ap in results_per_category})
|
389 |
+
return results
|
390 |
+
|
391 |
+
|
392 |
+
def instances_to_coco_json(instances, img_id):
|
393 |
+
"""
|
394 |
+
Dump an "Instances" object to a COCO-format json that's used for evaluation.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
instances (Instances):
|
398 |
+
img_id (int): the image id
|
399 |
+
|
400 |
+
Returns:
|
401 |
+
list[dict]: list of json annotations in COCO format.
|
402 |
+
"""
|
403 |
+
num_instance = len(instances)
|
404 |
+
if num_instance == 0:
|
405 |
+
return []
|
406 |
+
|
407 |
+
boxes = instances.pred_boxes.tensor.numpy()
|
408 |
+
boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
409 |
+
boxes = boxes.tolist()
|
410 |
+
scores = instances.scores.tolist()
|
411 |
+
classes = instances.pred_classes.tolist()
|
412 |
+
|
413 |
+
has_mask = instances.has("pred_masks")
|
414 |
+
if has_mask:
|
415 |
+
# use RLE to encode the masks, because they are too large and takes memory
|
416 |
+
# since this evaluator stores outputs of the entire dataset
|
417 |
+
rles = [
|
418 |
+
mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
|
419 |
+
for mask in instances.pred_masks
|
420 |
+
]
|
421 |
+
for rle in rles:
|
422 |
+
# "counts" is an array encoded by mask_util as a byte-stream. Python3's
|
423 |
+
# json writer which always produces strings cannot serialize a bytestream
|
424 |
+
# unless you decode it. Thankfully, utf-8 works out (which is also what
|
425 |
+
# the annotator.oneformer.pycocotools/_mask.pyx does).
|
426 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
427 |
+
|
428 |
+
has_keypoints = instances.has("pred_keypoints")
|
429 |
+
if has_keypoints:
|
430 |
+
keypoints = instances.pred_keypoints
|
431 |
+
|
432 |
+
results = []
|
433 |
+
for k in range(num_instance):
|
434 |
+
result = {
|
435 |
+
"image_id": img_id,
|
436 |
+
"category_id": classes[k],
|
437 |
+
"bbox": boxes[k],
|
438 |
+
"score": scores[k],
|
439 |
+
}
|
440 |
+
if has_mask:
|
441 |
+
result["segmentation"] = rles[k]
|
442 |
+
if has_keypoints:
|
443 |
+
# In COCO annotations,
|
444 |
+
# keypoints coordinates are pixel indices.
|
445 |
+
# However our predictions are floating point coordinates.
|
446 |
+
# Therefore we subtract 0.5 to be consistent with the annotation format.
|
447 |
+
# This is the inverse of data loading logic in `datasets/coco.py`.
|
448 |
+
keypoints[k][:, :2] -= 0.5
|
449 |
+
result["keypoints"] = keypoints[k].flatten().tolist()
|
450 |
+
results.append(result)
|
451 |
+
return results
|
452 |
+
|
453 |
+
|
454 |
+
# inspired from Detectron:
|
455 |
+
# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa
|
456 |
+
def _evaluate_box_proposals(dataset_predictions, coco_api, thresholds=None, area="all", limit=None):
|
457 |
+
"""
|
458 |
+
Evaluate detection proposal recall metrics. This function is a much
|
459 |
+
faster alternative to the official COCO API recall evaluation code. However,
|
460 |
+
it produces slightly different results.
|
461 |
+
"""
|
462 |
+
# Record max overlap value for each gt box
|
463 |
+
# Return vector of overlap values
|
464 |
+
areas = {
|
465 |
+
"all": 0,
|
466 |
+
"small": 1,
|
467 |
+
"medium": 2,
|
468 |
+
"large": 3,
|
469 |
+
"96-128": 4,
|
470 |
+
"128-256": 5,
|
471 |
+
"256-512": 6,
|
472 |
+
"512-inf": 7,
|
473 |
+
}
|
474 |
+
area_ranges = [
|
475 |
+
[0**2, 1e5**2], # all
|
476 |
+
[0**2, 32**2], # small
|
477 |
+
[32**2, 96**2], # medium
|
478 |
+
[96**2, 1e5**2], # large
|
479 |
+
[96**2, 128**2], # 96-128
|
480 |
+
[128**2, 256**2], # 128-256
|
481 |
+
[256**2, 512**2], # 256-512
|
482 |
+
[512**2, 1e5**2],
|
483 |
+
] # 512-inf
|
484 |
+
assert area in areas, "Unknown area range: {}".format(area)
|
485 |
+
area_range = area_ranges[areas[area]]
|
486 |
+
gt_overlaps = []
|
487 |
+
num_pos = 0
|
488 |
+
|
489 |
+
for prediction_dict in dataset_predictions:
|
490 |
+
predictions = prediction_dict["proposals"]
|
491 |
+
|
492 |
+
# sort predictions in descending order
|
493 |
+
# TODO maybe remove this and make it explicit in the documentation
|
494 |
+
inds = predictions.objectness_logits.sort(descending=True)[1]
|
495 |
+
predictions = predictions[inds]
|
496 |
+
|
497 |
+
ann_ids = coco_api.getAnnIds(imgIds=prediction_dict["image_id"])
|
498 |
+
anno = coco_api.loadAnns(ann_ids)
|
499 |
+
gt_boxes = [
|
500 |
+
BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
|
501 |
+
for obj in anno
|
502 |
+
if obj["iscrowd"] == 0
|
503 |
+
]
|
504 |
+
gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes
|
505 |
+
gt_boxes = Boxes(gt_boxes)
|
506 |
+
gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0])
|
507 |
+
|
508 |
+
if len(gt_boxes) == 0 or len(predictions) == 0:
|
509 |
+
continue
|
510 |
+
|
511 |
+
valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
|
512 |
+
gt_boxes = gt_boxes[valid_gt_inds]
|
513 |
+
|
514 |
+
num_pos += len(gt_boxes)
|
515 |
+
|
516 |
+
if len(gt_boxes) == 0:
|
517 |
+
continue
|
518 |
+
|
519 |
+
if limit is not None and len(predictions) > limit:
|
520 |
+
predictions = predictions[:limit]
|
521 |
+
|
522 |
+
overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes)
|
523 |
+
|
524 |
+
_gt_overlaps = torch.zeros(len(gt_boxes))
|
525 |
+
for j in range(min(len(predictions), len(gt_boxes))):
|
526 |
+
# find which proposal box maximally covers each gt box
|
527 |
+
# and get the iou amount of coverage for each gt box
|
528 |
+
max_overlaps, argmax_overlaps = overlaps.max(dim=0)
|
529 |
+
|
530 |
+
# find which gt box is 'best' covered (i.e. 'best' = most iou)
|
531 |
+
gt_ovr, gt_ind = max_overlaps.max(dim=0)
|
532 |
+
assert gt_ovr >= 0
|
533 |
+
# find the proposal box that covers the best covered gt box
|
534 |
+
box_ind = argmax_overlaps[gt_ind]
|
535 |
+
# record the iou coverage of this gt box
|
536 |
+
_gt_overlaps[j] = overlaps[box_ind, gt_ind]
|
537 |
+
assert _gt_overlaps[j] == gt_ovr
|
538 |
+
# mark the proposal box and the gt box as used
|
539 |
+
overlaps[box_ind, :] = -1
|
540 |
+
overlaps[:, gt_ind] = -1
|
541 |
+
|
542 |
+
# append recorded iou coverage level
|
543 |
+
gt_overlaps.append(_gt_overlaps)
|
544 |
+
gt_overlaps = (
|
545 |
+
torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32)
|
546 |
+
)
|
547 |
+
gt_overlaps, _ = torch.sort(gt_overlaps)
|
548 |
+
|
549 |
+
if thresholds is None:
|
550 |
+
step = 0.05
|
551 |
+
thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
|
552 |
+
recalls = torch.zeros_like(thresholds)
|
553 |
+
# compute recall for each iou threshold
|
554 |
+
for i, t in enumerate(thresholds):
|
555 |
+
recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
|
556 |
+
# ar = 2 * np.trapz(recalls, thresholds)
|
557 |
+
ar = recalls.mean()
|
558 |
+
return {
|
559 |
+
"ar": ar,
|
560 |
+
"recalls": recalls,
|
561 |
+
"thresholds": thresholds,
|
562 |
+
"gt_overlaps": gt_overlaps,
|
563 |
+
"num_pos": num_pos,
|
564 |
+
}
|
565 |
+
|
566 |
+
|
567 |
+
def _evaluate_predictions_on_coco(
|
568 |
+
coco_gt,
|
569 |
+
coco_results,
|
570 |
+
iou_type,
|
571 |
+
kpt_oks_sigmas=None,
|
572 |
+
cocoeval_fn=COCOeval_opt,
|
573 |
+
img_ids=None,
|
574 |
+
max_dets_per_image=None,
|
575 |
+
):
|
576 |
+
"""
|
577 |
+
Evaluate the coco results using COCOEval API.
|
578 |
+
"""
|
579 |
+
assert len(coco_results) > 0
|
580 |
+
|
581 |
+
if iou_type == "segm":
|
582 |
+
coco_results = copy.deepcopy(coco_results)
|
583 |
+
# When evaluating mask AP, if the results contain bbox, cocoapi will
|
584 |
+
# use the box area as the area of the instance, instead of the mask area.
|
585 |
+
# This leads to a different definition of small/medium/large.
|
586 |
+
# We remove the bbox field to let mask AP use mask area.
|
587 |
+
for c in coco_results:
|
588 |
+
c.pop("bbox", None)
|
589 |
+
|
590 |
+
coco_dt = coco_gt.loadRes(coco_results)
|
591 |
+
coco_eval = cocoeval_fn(coco_gt, coco_dt, iou_type)
|
592 |
+
# For COCO, the default max_dets_per_image is [1, 10, 100].
|
593 |
+
if max_dets_per_image is None:
|
594 |
+
max_dets_per_image = [1, 10, 100] # Default from COCOEval
|
595 |
+
else:
|
596 |
+
assert (
|
597 |
+
len(max_dets_per_image) >= 3
|
598 |
+
), "COCOeval requires maxDets (and max_dets_per_image) to have length at least 3"
|
599 |
+
# In the case that user supplies a custom input for max_dets_per_image,
|
600 |
+
# apply COCOevalMaxDets to evaluate AP with the custom input.
|
601 |
+
if max_dets_per_image[2] != 100:
|
602 |
+
coco_eval = COCOevalMaxDets(coco_gt, coco_dt, iou_type)
|
603 |
+
if iou_type != "keypoints":
|
604 |
+
coco_eval.params.maxDets = max_dets_per_image
|
605 |
+
|
606 |
+
if img_ids is not None:
|
607 |
+
coco_eval.params.imgIds = img_ids
|
608 |
+
|
609 |
+
if iou_type == "keypoints":
|
610 |
+
# Use the COCO default keypoint OKS sigmas unless overrides are specified
|
611 |
+
if kpt_oks_sigmas:
|
612 |
+
assert hasattr(coco_eval.params, "kpt_oks_sigmas"), "annotator.oneformer.pycocotools is too old!"
|
613 |
+
coco_eval.params.kpt_oks_sigmas = np.array(kpt_oks_sigmas)
|
614 |
+
# COCOAPI requires every detection and every gt to have keypoints, so
|
615 |
+
# we just take the first entry from both
|
616 |
+
num_keypoints_dt = len(coco_results[0]["keypoints"]) // 3
|
617 |
+
num_keypoints_gt = len(next(iter(coco_gt.anns.values()))["keypoints"]) // 3
|
618 |
+
num_keypoints_oks = len(coco_eval.params.kpt_oks_sigmas)
|
619 |
+
assert num_keypoints_oks == num_keypoints_dt == num_keypoints_gt, (
|
620 |
+
f"[COCOEvaluator] Prediction contain {num_keypoints_dt} keypoints. "
|
621 |
+
f"Ground truth contains {num_keypoints_gt} keypoints. "
|
622 |
+
f"The length of cfg.TEST.KEYPOINT_OKS_SIGMAS is {num_keypoints_oks}. "
|
623 |
+
"They have to agree with each other. For meaning of OKS, please refer to "
|
624 |
+
"http://cocodataset.org/#keypoints-eval."
|
625 |
+
)
|
626 |
+
|
627 |
+
coco_eval.evaluate()
|
628 |
+
coco_eval.accumulate()
|
629 |
+
coco_eval.summarize()
|
630 |
+
|
631 |
+
return coco_eval
|
632 |
+
|
633 |
+
|
634 |
+
class COCOevalMaxDets(COCOeval):
|
635 |
+
"""
|
636 |
+
Modified version of COCOeval for evaluating AP with a custom
|
637 |
+
maxDets (by default for COCO, maxDets is 100)
|
638 |
+
"""
|
639 |
+
|
640 |
+
def summarize(self):
|
641 |
+
"""
|
642 |
+
Compute and display summary metrics for evaluation results given
|
643 |
+
a custom value for max_dets_per_image
|
644 |
+
"""
|
645 |
+
|
646 |
+
def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100):
|
647 |
+
p = self.params
|
648 |
+
iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
|
649 |
+
titleStr = "Average Precision" if ap == 1 else "Average Recall"
|
650 |
+
typeStr = "(AP)" if ap == 1 else "(AR)"
|
651 |
+
iouStr = (
|
652 |
+
"{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
|
653 |
+
if iouThr is None
|
654 |
+
else "{:0.2f}".format(iouThr)
|
655 |
+
)
|
656 |
+
|
657 |
+
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
|
658 |
+
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
|
659 |
+
if ap == 1:
|
660 |
+
# dimension of precision: [TxRxKxAxM]
|
661 |
+
s = self.eval["precision"]
|
662 |
+
# IoU
|
663 |
+
if iouThr is not None:
|
664 |
+
t = np.where(iouThr == p.iouThrs)[0]
|
665 |
+
s = s[t]
|
666 |
+
s = s[:, :, :, aind, mind]
|
667 |
+
else:
|
668 |
+
# dimension of recall: [TxKxAxM]
|
669 |
+
s = self.eval["recall"]
|
670 |
+
if iouThr is not None:
|
671 |
+
t = np.where(iouThr == p.iouThrs)[0]
|
672 |
+
s = s[t]
|
673 |
+
s = s[:, :, aind, mind]
|
674 |
+
if len(s[s > -1]) == 0:
|
675 |
+
mean_s = -1
|
676 |
+
else:
|
677 |
+
mean_s = np.mean(s[s > -1])
|
678 |
+
print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
|
679 |
+
return mean_s
|
680 |
+
|
681 |
+
def _summarizeDets():
|
682 |
+
stats = np.zeros((12,))
|
683 |
+
# Evaluate AP using the custom limit on maximum detections per image
|
684 |
+
stats[0] = _summarize(1, maxDets=self.params.maxDets[2])
|
685 |
+
stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2])
|
686 |
+
stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2])
|
687 |
+
stats[3] = _summarize(1, areaRng="small", maxDets=self.params.maxDets[2])
|
688 |
+
stats[4] = _summarize(1, areaRng="medium", maxDets=self.params.maxDets[2])
|
689 |
+
stats[5] = _summarize(1, areaRng="large", maxDets=self.params.maxDets[2])
|
690 |
+
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
|
691 |
+
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
|
692 |
+
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
|
693 |
+
stats[9] = _summarize(0, areaRng="small", maxDets=self.params.maxDets[2])
|
694 |
+
stats[10] = _summarize(0, areaRng="medium", maxDets=self.params.maxDets[2])
|
695 |
+
stats[11] = _summarize(0, areaRng="large", maxDets=self.params.maxDets[2])
|
696 |
+
return stats
|
697 |
+
|
698 |
+
def _summarizeKps():
|
699 |
+
stats = np.zeros((10,))
|
700 |
+
stats[0] = _summarize(1, maxDets=20)
|
701 |
+
stats[1] = _summarize(1, maxDets=20, iouThr=0.5)
|
702 |
+
stats[2] = _summarize(1, maxDets=20, iouThr=0.75)
|
703 |
+
stats[3] = _summarize(1, maxDets=20, areaRng="medium")
|
704 |
+
stats[4] = _summarize(1, maxDets=20, areaRng="large")
|
705 |
+
stats[5] = _summarize(0, maxDets=20)
|
706 |
+
stats[6] = _summarize(0, maxDets=20, iouThr=0.5)
|
707 |
+
stats[7] = _summarize(0, maxDets=20, iouThr=0.75)
|
708 |
+
stats[8] = _summarize(0, maxDets=20, areaRng="medium")
|
709 |
+
stats[9] = _summarize(0, maxDets=20, areaRng="large")
|
710 |
+
return stats
|
711 |
+
|
712 |
+
if not self.eval:
|
713 |
+
raise Exception("Please run accumulate() first")
|
714 |
+
iouType = self.params.iouType
|
715 |
+
if iouType == "segm" or iouType == "bbox":
|
716 |
+
summarize = _summarizeDets
|
717 |
+
elif iouType == "keypoints":
|
718 |
+
summarize = _summarizeKps
|
719 |
+
self.stats = summarize()
|
720 |
+
|
721 |
+
def __str__(self):
|
722 |
+
self.summarize()
|
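(A minimal wiring sketch for the evaluator defined above, assuming a registered COCO-format dataset named "my_coco_val" and a `model`/`val_loader` pair supplied by the surrounding code; only COCOEvaluator and inference_on_dataset are defined in the files added by this commit.)

    from annotator.oneformer.detectron2.evaluation.coco_evaluation import COCOEvaluator
    from annotator.oneformer.detectron2.evaluation.evaluator import inference_on_dataset

    # Build the evaluator for a COCO-format dataset and run it over a test loader.
    evaluator = COCOEvaluator("my_coco_val", distributed=False, output_dir="./coco_eval")
    results = inference_on_dataset(model, val_loader, evaluator)
    # `results` is a dict such as {"bbox": {"AP": ..., "AP50": ...}, "segm": {...}}
    print(results["bbox"]["AP"])
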
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/evaluator.py
ADDED
@@ -0,0 +1,224 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import datetime
import logging
import time
from collections import OrderedDict, abc
from contextlib import ExitStack, contextmanager
from typing import List, Union
import torch
from torch import nn

from annotator.oneformer.detectron2.utils.comm import get_world_size, is_main_process
from annotator.oneformer.detectron2.utils.logger import log_every_n_seconds


class DatasetEvaluator:
    """
    Base class for a dataset evaluator.

    The function :func:`inference_on_dataset` runs the model over
    all samples in the dataset, and have a DatasetEvaluator to process the inputs/outputs.

    This class will accumulate information of the inputs/outputs (by :meth:`process`),
    and produce evaluation results in the end (by :meth:`evaluate`).
    """

    def reset(self):
        """
        Preparation for a new round of evaluation.
        Should be called before starting a round of evaluation.
        """
        pass

    def process(self, inputs, outputs):
        """
        Process the pair of inputs and outputs.
        If they contain batches, the pairs can be consumed one-by-one using `zip`:

        .. code-block:: python

            for input_, output in zip(inputs, outputs):
                # do evaluation on single input/output pair
                ...

        Args:
            inputs (list): the inputs that's used to call the model.
            outputs (list): the return value of `model(inputs)`
        """
        pass

    def evaluate(self):
        """
        Evaluate/summarize the performance, after processing all input/output pairs.

        Returns:
            dict:
                A new evaluator class can return a dict of arbitrary format
                as long as the user can process the results.
                In our train_net.py, we expect the following format:

                * key: the name of the task (e.g., bbox)
                * value: a dict of {metric name: score}, e.g.: {"AP50": 80}
        """
        pass


class DatasetEvaluators(DatasetEvaluator):
    """
    Wrapper class to combine multiple :class:`DatasetEvaluator` instances.

    This class dispatches every evaluation call to
    all of its :class:`DatasetEvaluator`.
    """

    def __init__(self, evaluators):
        """
        Args:
            evaluators (list): the evaluators to combine.
        """
        super().__init__()
        self._evaluators = evaluators

    def reset(self):
        for evaluator in self._evaluators:
            evaluator.reset()

    def process(self, inputs, outputs):
        for evaluator in self._evaluators:
            evaluator.process(inputs, outputs)

    def evaluate(self):
        results = OrderedDict()
        for evaluator in self._evaluators:
            result = evaluator.evaluate()
            if is_main_process() and result is not None:
                for k, v in result.items():
                    assert (
                        k not in results
                    ), "Different evaluators produce results with the same key {}".format(k)
                    results[k] = v
        return results


def inference_on_dataset(
    model, data_loader, evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None]
):
    """
    Run model on the data_loader and evaluate the metrics with evaluator.
    Also benchmark the inference speed of `model.__call__` accurately.
    The model will be used in eval mode.

    Args:
        model (callable): a callable which takes an object from
            `data_loader` and returns some outputs.

            If it's an nn.Module, it will be temporarily set to `eval` mode.
            If you wish to evaluate a model in `training` mode instead, you can
            wrap the given model and override its behavior of `.eval()` and `.train()`.
        data_loader: an iterable object with a length.
            The elements it generates will be the inputs to the model.
        evaluator: the evaluator(s) to run. Use `None` if you only want to benchmark,
            but don't want to do any evaluation.

    Returns:
        The return value of `evaluator.evaluate()`
    """
    num_devices = get_world_size()
    logger = logging.getLogger(__name__)
    logger.info("Start inference on {} batches".format(len(data_loader)))

    total = len(data_loader)  # inference data loader must have a fixed length
    if evaluator is None:
        # create a no-op evaluator
        evaluator = DatasetEvaluators([])
    if isinstance(evaluator, abc.MutableSequence):
        evaluator = DatasetEvaluators(evaluator)
    evaluator.reset()

    num_warmup = min(5, total - 1)
    start_time = time.perf_counter()
    total_data_time = 0
    total_compute_time = 0
    total_eval_time = 0
    with ExitStack() as stack:
        if isinstance(model, nn.Module):
            stack.enter_context(inference_context(model))
        stack.enter_context(torch.no_grad())

        start_data_time = time.perf_counter()
        for idx, inputs in enumerate(data_loader):
            total_data_time += time.perf_counter() - start_data_time
            if idx == num_warmup:
                start_time = time.perf_counter()
                total_data_time = 0
                total_compute_time = 0
                total_eval_time = 0

            start_compute_time = time.perf_counter()
            outputs = model(inputs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time

            start_eval_time = time.perf_counter()
            evaluator.process(inputs, outputs)
            total_eval_time += time.perf_counter() - start_eval_time

            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            data_seconds_per_iter = total_data_time / iters_after_start
            compute_seconds_per_iter = total_compute_time / iters_after_start
            eval_seconds_per_iter = total_eval_time / iters_after_start
            total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start
            if idx >= num_warmup * 2 or compute_seconds_per_iter > 5:
                eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    (
                        f"Inference done {idx + 1}/{total}. "
                        f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
                        f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
                        f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
                        f"Total: {total_seconds_per_iter:.4f} s/iter. "
                        f"ETA={eta}"
                    ),
                    n=5,
                )
            start_data_time = time.perf_counter()

    # Measure the time only for this worker (before the synchronization barrier)
    total_time = time.perf_counter() - start_time
    total_time_str = str(datetime.timedelta(seconds=total_time))
    # NOTE this format is parsed by grep
    logger.info(
        "Total inference time: {} ({:.6f} s / iter per device, on {} devices)".format(
            total_time_str, total_time / (total - num_warmup), num_devices
        )
    )
    total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / iter per device, on {} devices)".format(
            total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
        )
    )

    results = evaluator.evaluate()
    # An evaluator may return None when not in main process.
    # Replace it by an empty dict instead to make it easier for downstream code to handle
    if results is None:
        results = {}
    return results


@contextmanager
def inference_context(model):
    """
    A context where the model is temporarily changed to eval mode,
    and restored to previous mode afterwards.

    Args:
        model: a torch Module
    """
    training_mode = model.training
    model.eval()
    yield
    model.train(training_mode)
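(A small sketch of how the DatasetEvaluator interface above is meant to be extended; the counting logic is illustrative and not part of this commit.)

    from annotator.oneformer.detectron2.evaluation.evaluator import DatasetEvaluator

    class InstanceCounter(DatasetEvaluator):
        """Toy evaluator that counts predicted instances across the dataset."""

        def reset(self):
            self.count = 0

        def process(self, inputs, outputs):
            for output in outputs:
                if "instances" in output:
                    self.count += len(output["instances"])

        def evaluate(self):
            # Same {task: {metric: value}} shape that downstream code expects.
            return {"counting": {"num_instances": self.count}}
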
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/fast_eval_api.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import logging
import numpy as np
import time
from annotator.oneformer.pycocotools.cocoeval import COCOeval

from annotator.oneformer.detectron2 import _C

logger = logging.getLogger(__name__)


class COCOeval_opt(COCOeval):
    """
    This is a slightly modified version of the original COCO API, where the functions evaluateImg()
    and accumulate() are implemented in C++ to speedup evaluation
    """

    def evaluate(self):
        """
        Run per image evaluation on given images and store results in self.evalImgs_cpp, a
        datastructure that isn't readable from Python but is used by a c++ implementation of
        accumulate(). Unlike the original COCO PythonAPI, we don't populate the datastructure
        self.evalImgs because this datastructure is a computational bottleneck.
        :return: None
        """
        tic = time.time()

        p = self.params
        # add backward compatibility if useSegm is specified in params
        if p.useSegm is not None:
            p.iouType = "segm" if p.useSegm == 1 else "bbox"
        logger.info("Evaluate annotation type *{}*".format(p.iouType))
        p.imgIds = list(np.unique(p.imgIds))
        if p.useCats:
            p.catIds = list(np.unique(p.catIds))
        p.maxDets = sorted(p.maxDets)
        self.params = p

        self._prepare()  # bottleneck

        # loop through images, area range, max detection number
        catIds = p.catIds if p.useCats else [-1]

        if p.iouType == "segm" or p.iouType == "bbox":
            computeIoU = self.computeIoU
        elif p.iouType == "keypoints":
            computeIoU = self.computeOks
        self.ious = {
            (imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds
        }  # bottleneck

        maxDet = p.maxDets[-1]

        # <<<< Beginning of code differences with original COCO API
        def convert_instances_to_cpp(instances, is_det=False):
            # Convert annotations for a list of instances in an image to a format that's fast
            # to access in C++
            instances_cpp = []
            for instance in instances:
                instance_cpp = _C.InstanceAnnotation(
                    int(instance["id"]),
                    instance["score"] if is_det else instance.get("score", 0.0),
                    instance["area"],
                    bool(instance.get("iscrowd", 0)),
                    bool(instance.get("ignore", 0)),
                )
                instances_cpp.append(instance_cpp)
            return instances_cpp

        # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++
        ground_truth_instances = [
            [convert_instances_to_cpp(self._gts[imgId, catId]) for catId in p.catIds]
            for imgId in p.imgIds
        ]
        detected_instances = [
            [convert_instances_to_cpp(self._dts[imgId, catId], is_det=True) for catId in p.catIds]
            for imgId in p.imgIds
        ]
        ious = [[self.ious[imgId, catId] for catId in catIds] for imgId in p.imgIds]

        if not p.useCats:
            # For each image, flatten per-category lists into a single list
            ground_truth_instances = [[[o for c in i for o in c]] for i in ground_truth_instances]
            detected_instances = [[[o for c in i for o in c]] for i in detected_instances]

        # Call C++ implementation of self.evaluateImgs()
        self._evalImgs_cpp = _C.COCOevalEvaluateImages(
            p.areaRng, maxDet, p.iouThrs, ious, ground_truth_instances, detected_instances
        )
        self._evalImgs = None

        self._paramsEval = copy.deepcopy(self.params)
        toc = time.time()
        logger.info("COCOeval_opt.evaluate() finished in {:0.2f} seconds.".format(toc - tic))
        # >>>> End of code differences with original COCO API

    def accumulate(self):
        """
        Accumulate per image evaluation results and store the result in self.eval. Does not
        support changing parameter settings from those used by self.evaluate()
        """
        logger.info("Accumulating evaluation results...")
        tic = time.time()
        assert hasattr(
            self, "_evalImgs_cpp"
        ), "evaluate() must be called before accmulate() is called."

        self.eval = _C.COCOevalAccumulate(self._paramsEval, self._evalImgs_cpp)

        # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections
        self.eval["recall"] = np.array(self.eval["recall"]).reshape(
            self.eval["counts"][:1] + self.eval["counts"][2:]
        )

        # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X
        # num_area_ranges X num_max_detections
        self.eval["precision"] = np.array(self.eval["precision"]).reshape(self.eval["counts"])
        self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"])
        toc = time.time()
        logger.info("COCOeval_opt.accumulate() finished in {:0.2f} seconds.".format(toc - tic))
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/lvis_evaluation.py
ADDED
@@ -0,0 +1,380 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import copy
|
3 |
+
import itertools
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import pickle
|
8 |
+
from collections import OrderedDict
|
9 |
+
import torch
|
10 |
+
|
11 |
+
import annotator.oneformer.detectron2.utils.comm as comm
|
12 |
+
from annotator.oneformer.detectron2.config import CfgNode
|
13 |
+
from annotator.oneformer.detectron2.data import MetadataCatalog
|
14 |
+
from annotator.oneformer.detectron2.structures import Boxes, BoxMode, pairwise_iou
|
15 |
+
from annotator.oneformer.detectron2.utils.file_io import PathManager
|
16 |
+
from annotator.oneformer.detectron2.utils.logger import create_small_table
|
17 |
+
|
18 |
+
from .coco_evaluation import instances_to_coco_json
|
19 |
+
from .evaluator import DatasetEvaluator
|
20 |
+
|
21 |
+
|
22 |
+
class LVISEvaluator(DatasetEvaluator):
|
23 |
+
"""
|
24 |
+
Evaluate object proposal and instance detection/segmentation outputs using
|
25 |
+
LVIS's metrics and evaluation API.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
dataset_name,
|
31 |
+
tasks=None,
|
32 |
+
distributed=True,
|
33 |
+
output_dir=None,
|
34 |
+
*,
|
35 |
+
max_dets_per_image=None,
|
36 |
+
):
|
37 |
+
"""
|
38 |
+
Args:
|
39 |
+
dataset_name (str): name of the dataset to be evaluated.
|
40 |
+
It must have the following corresponding metadata:
|
41 |
+
"json_file": the path to the LVIS format annotation
|
42 |
+
tasks (tuple[str]): tasks that can be evaluated under the given
|
43 |
+
configuration. A task is one of "bbox", "segm".
|
44 |
+
By default, will infer this automatically from predictions.
|
45 |
+
distributed (True): if True, will collect results from all ranks for evaluation.
|
46 |
+
Otherwise, will evaluate the results in the current process.
|
47 |
+
output_dir (str): optional, an output directory to dump results.
|
48 |
+
max_dets_per_image (None or int): limit on maximum detections per image in evaluating AP
|
49 |
+
This limit, by default of the LVIS dataset, is 300.
|
50 |
+
"""
|
51 |
+
from lvis import LVIS
|
52 |
+
|
53 |
+
self._logger = logging.getLogger(__name__)
|
54 |
+
|
55 |
+
if tasks is not None and isinstance(tasks, CfgNode):
|
56 |
+
self._logger.warn(
|
57 |
+
"COCO Evaluator instantiated using config, this is deprecated behavior."
|
58 |
+
" Please pass in explicit arguments instead."
|
59 |
+
)
|
60 |
+
self._tasks = None # Infering it from predictions should be better
|
61 |
+
else:
|
62 |
+
self._tasks = tasks
|
63 |
+
|
64 |
+
self._distributed = distributed
|
65 |
+
self._output_dir = output_dir
|
66 |
+
self._max_dets_per_image = max_dets_per_image
|
67 |
+
|
68 |
+
self._cpu_device = torch.device("cpu")
|
69 |
+
|
70 |
+
self._metadata = MetadataCatalog.get(dataset_name)
|
71 |
+
json_file = PathManager.get_local_path(self._metadata.json_file)
|
72 |
+
self._lvis_api = LVIS(json_file)
|
73 |
+
# Test set json files do not contain annotations (evaluation must be
|
74 |
+
# performed using the LVIS evaluation server).
|
75 |
+
self._do_evaluation = len(self._lvis_api.get_ann_ids()) > 0
|
76 |
+
|
77 |
+
def reset(self):
|
78 |
+
self._predictions = []
|
79 |
+
|
80 |
+
def process(self, inputs, outputs):
|
81 |
+
"""
|
82 |
+
Args:
|
83 |
+
inputs: the inputs to a LVIS model (e.g., GeneralizedRCNN).
|
84 |
+
It is a list of dict. Each dict corresponds to an image and
|
85 |
+
contains keys like "height", "width", "file_name", "image_id".
|
86 |
+
outputs: the outputs of a LVIS model. It is a list of dicts with key
|
87 |
+
"instances" that contains :class:`Instances`.
|
88 |
+
"""
|
89 |
+
for input, output in zip(inputs, outputs):
|
90 |
+
prediction = {"image_id": input["image_id"]}
|
91 |
+
|
92 |
+
if "instances" in output:
|
93 |
+
instances = output["instances"].to(self._cpu_device)
|
94 |
+
prediction["instances"] = instances_to_coco_json(instances, input["image_id"])
|
95 |
+
if "proposals" in output:
|
96 |
+
prediction["proposals"] = output["proposals"].to(self._cpu_device)
|
97 |
+
self._predictions.append(prediction)
|
98 |
+
|
99 |
+
def evaluate(self):
|
100 |
+
if self._distributed:
|
101 |
+
comm.synchronize()
|
102 |
+
predictions = comm.gather(self._predictions, dst=0)
|
103 |
+
predictions = list(itertools.chain(*predictions))
|
104 |
+
|
105 |
+
if not comm.is_main_process():
|
106 |
+
return
|
107 |
+
else:
|
108 |
+
predictions = self._predictions
|
109 |
+
|
110 |
+
        if len(predictions) == 0:
            self._logger.warning("[LVISEvaluator] Did not receive valid predictions.")
            return {}

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)
            file_path = os.path.join(self._output_dir, "instances_predictions.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(predictions, f)

        self._results = OrderedDict()
        if "proposals" in predictions[0]:
            self._eval_box_proposals(predictions)
        if "instances" in predictions[0]:
            self._eval_predictions(predictions)
        # Copy so the caller can do whatever with results
        return copy.deepcopy(self._results)

    def _tasks_from_predictions(self, predictions):
        for pred in predictions:
            if "segmentation" in pred:
                return ("bbox", "segm")
        return ("bbox",)

    def _eval_predictions(self, predictions):
        """
        Evaluate predictions. Fill self._results with the metrics of the tasks.

        Args:
            predictions (list[dict]): list of outputs from the model
        """
        self._logger.info("Preparing results in the LVIS format ...")
        lvis_results = list(itertools.chain(*[x["instances"] for x in predictions]))
        tasks = self._tasks or self._tasks_from_predictions(lvis_results)

        # LVIS evaluator can be used to evaluate results for COCO dataset categories.
        # In this case `_metadata` variable will have a field with COCO-specific category mapping.
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            reverse_id_mapping = {
                v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
            }
            for result in lvis_results:
                result["category_id"] = reverse_id_mapping[result["category_id"]]
        else:
            # unmap the category ids for LVIS (from 0-indexed to 1-indexed)
            for result in lvis_results:
                result["category_id"] += 1

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "lvis_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(lvis_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating predictions ...")
        for task in sorted(tasks):
            res = _evaluate_predictions_on_lvis(
                self._lvis_api,
                lvis_results,
                task,
                max_dets_per_image=self._max_dets_per_image,
                class_names=self._metadata.get("thing_classes"),
            )
            self._results[task] = res

    def _eval_box_proposals(self, predictions):
        """
        Evaluate the box proposals in predictions.
        Fill self._results with the metrics for "box_proposals" task.
        """
        if self._output_dir:
            # Saving generated box proposals to file.
            # Predicted box_proposals are in XYXY_ABS mode.
            bbox_mode = BoxMode.XYXY_ABS.value
            ids, boxes, objectness_logits = [], [], []
            for prediction in predictions:
                ids.append(prediction["image_id"])
                boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy())
                objectness_logits.append(prediction["proposals"].objectness_logits.numpy())

            proposal_data = {
                "boxes": boxes,
                "objectness_logits": objectness_logits,
                "ids": ids,
                "bbox_mode": bbox_mode,
            }
            with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f:
                pickle.dump(proposal_data, f)

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating bbox proposals ...")
        res = {}
        areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
        for limit in [100, 1000]:
            for area, suffix in areas.items():
                stats = _evaluate_box_proposals(predictions, self._lvis_api, area=area, limit=limit)
                key = "AR{}@{:d}".format(suffix, limit)
                res[key] = float(stats["ar"].item() * 100)
        self._logger.info("Proposal metrics: \n" + create_small_table(res))
        self._results["box_proposals"] = res


# inspired from Detectron:
# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa
def _evaluate_box_proposals(dataset_predictions, lvis_api, thresholds=None, area="all", limit=None):
    """
    Evaluate detection proposal recall metrics. This function is a much
    faster alternative to the official LVIS API recall evaluation code. However,
    it produces slightly different results.
    """
    # Record max overlap value for each gt box
    # Return vector of overlap values
    areas = {
        "all": 0,
        "small": 1,
        "medium": 2,
        "large": 3,
        "96-128": 4,
        "128-256": 5,
        "256-512": 6,
        "512-inf": 7,
    }
    area_ranges = [
        [0**2, 1e5**2],  # all
        [0**2, 32**2],  # small
        [32**2, 96**2],  # medium
        [96**2, 1e5**2],  # large
        [96**2, 128**2],  # 96-128
        [128**2, 256**2],  # 128-256
        [256**2, 512**2],  # 256-512
        [512**2, 1e5**2],
    ]  # 512-inf
    assert area in areas, "Unknown area range: {}".format(area)
    area_range = area_ranges[areas[area]]
    gt_overlaps = []
    num_pos = 0

    for prediction_dict in dataset_predictions:
        predictions = prediction_dict["proposals"]

        # sort predictions in descending order
        # TODO maybe remove this and make it explicit in the documentation
        inds = predictions.objectness_logits.sort(descending=True)[1]
        predictions = predictions[inds]

        ann_ids = lvis_api.get_ann_ids(img_ids=[prediction_dict["image_id"]])
        anno = lvis_api.load_anns(ann_ids)
        gt_boxes = [
            BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) for obj in anno
        ]
        gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4)  # guard against no boxes
        gt_boxes = Boxes(gt_boxes)
        gt_areas = torch.as_tensor([obj["area"] for obj in anno])

        if len(gt_boxes) == 0 or len(predictions) == 0:
            continue

        valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
        gt_boxes = gt_boxes[valid_gt_inds]

        num_pos += len(gt_boxes)

        if len(gt_boxes) == 0:
            continue

        if limit is not None and len(predictions) > limit:
            predictions = predictions[:limit]

        overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes)

        _gt_overlaps = torch.zeros(len(gt_boxes))
        for j in range(min(len(predictions), len(gt_boxes))):
            # find which proposal box maximally covers each gt box
            # and get the iou amount of coverage for each gt box
            max_overlaps, argmax_overlaps = overlaps.max(dim=0)

            # find which gt box is 'best' covered (i.e. 'best' = most iou)
            gt_ovr, gt_ind = max_overlaps.max(dim=0)
            assert gt_ovr >= 0
            # find the proposal box that covers the best covered gt box
            box_ind = argmax_overlaps[gt_ind]
            # record the iou coverage of this gt box
            _gt_overlaps[j] = overlaps[box_ind, gt_ind]
            assert _gt_overlaps[j] == gt_ovr
            # mark the proposal box and the gt box as used
            overlaps[box_ind, :] = -1
            overlaps[:, gt_ind] = -1

        # append recorded iou coverage level
        gt_overlaps.append(_gt_overlaps)
    gt_overlaps = (
        torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32)
    )
    gt_overlaps, _ = torch.sort(gt_overlaps)

    if thresholds is None:
        step = 0.05
        thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
    recalls = torch.zeros_like(thresholds)
    # compute recall for each iou threshold
    for i, t in enumerate(thresholds):
        recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
    # ar = 2 * np.trapz(recalls, thresholds)
    ar = recalls.mean()
    return {
        "ar": ar,
        "recalls": recalls,
        "thresholds": thresholds,
        "gt_overlaps": gt_overlaps,
        "num_pos": num_pos,
    }


def _evaluate_predictions_on_lvis(
    lvis_gt, lvis_results, iou_type, max_dets_per_image=None, class_names=None
):
    """
    Args:
        iou_type (str):
        max_dets_per_image (None or int): limit on maximum detections per image in evaluating AP
            This limit, by default of the LVIS dataset, is 300.
        class_names (None or list[str]): if provided, will use it to predict
            per-category AP.

    Returns:
        a dict of {metric name: score}
    """
    metrics = {
        "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"],
        "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"],
    }[iou_type]

    logger = logging.getLogger(__name__)

    if len(lvis_results) == 0:  # TODO: check if needed
        logger.warn("No predictions from the model!")
        return {metric: float("nan") for metric in metrics}

    if iou_type == "segm":
        lvis_results = copy.deepcopy(lvis_results)
        # When evaluating mask AP, if the results contain bbox, LVIS API will
        # use the box area as the area of the instance, instead of the mask area.
        # This leads to a different definition of small/medium/large.
        # We remove the bbox field to let mask AP use mask area.
        for c in lvis_results:
            c.pop("bbox", None)

    if max_dets_per_image is None:
        max_dets_per_image = 300  # Default for LVIS dataset

    from lvis import LVISEval, LVISResults

    logger.info(f"Evaluating with max detections per image = {max_dets_per_image}")
    lvis_results = LVISResults(lvis_gt, lvis_results, max_dets=max_dets_per_image)
    lvis_eval = LVISEval(lvis_gt, lvis_results, iou_type)
    lvis_eval.run()
    lvis_eval.print_results()

    # Pull the standard metrics from the LVIS results
    results = lvis_eval.get_results()
    results = {metric: float(results[metric] * 100) for metric in metrics}
    logger.info("Evaluation results for {}: \n".format(iou_type) + create_small_table(results))
    return results
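The average-recall numbers reported by `_eval_box_proposals` come from the per-GT overlap vector built above. A minimal sketch of that final step, with made-up overlap values (only `torch` is needed):

```python
import torch

# Hypothetical best-overlap value for each ground-truth box, as produced by
# the greedy matching loop in _evaluate_box_proposals above.
gt_overlaps = torch.tensor([0.92, 0.71, 0.55, 0.48, 0.30])
num_pos = 5  # number of ground-truth boxes considered

thresholds = torch.arange(0.5, 0.95 + 1e-5, 0.05)
recalls = torch.zeros_like(thresholds)
for i, t in enumerate(thresholds):
    # fraction of GT boxes covered by some proposal at IoU >= t
    recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)

ar = recalls.mean()  # AR = mean recall over IoU thresholds 0.5:0.05:0.95
print(float(ar))
```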
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/panoptic_evaluation.py
ADDED
@@ -0,0 +1,199 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import contextlib
import io
import itertools
import json
import logging
import numpy as np
import os
import tempfile
from collections import OrderedDict
from typing import Optional
from PIL import Image
from tabulate import tabulate

from annotator.oneformer.detectron2.data import MetadataCatalog
from annotator.oneformer.detectron2.utils import comm
from annotator.oneformer.detectron2.utils.file_io import PathManager

from .evaluator import DatasetEvaluator

logger = logging.getLogger(__name__)


class COCOPanopticEvaluator(DatasetEvaluator):
    """
    Evaluate Panoptic Quality metrics on COCO using PanopticAPI.
    It saves panoptic segmentation prediction in `output_dir`

    It contains a synchronize call and has to be called from all workers.
    """

    def __init__(self, dataset_name: str, output_dir: Optional[str] = None):
        """
        Args:
            dataset_name: name of the dataset
            output_dir: output directory to save results for evaluation.
        """
        self._metadata = MetadataCatalog.get(dataset_name)
        self._thing_contiguous_id_to_dataset_id = {
            v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
        }
        self._stuff_contiguous_id_to_dataset_id = {
            v: k for k, v in self._metadata.stuff_dataset_id_to_contiguous_id.items()
        }

        self._output_dir = output_dir
        if self._output_dir is not None:
            PathManager.mkdirs(self._output_dir)

    def reset(self):
        self._predictions = []

    def _convert_category_id(self, segment_info):
        isthing = segment_info.pop("isthing", None)
        if isthing is None:
            # the model produces panoptic category id directly. No more conversion needed
            return segment_info
        if isthing is True:
            segment_info["category_id"] = self._thing_contiguous_id_to_dataset_id[
                segment_info["category_id"]
            ]
        else:
            segment_info["category_id"] = self._stuff_contiguous_id_to_dataset_id[
                segment_info["category_id"]
            ]
        return segment_info

    def process(self, inputs, outputs):
        from panopticapi.utils import id2rgb

        for input, output in zip(inputs, outputs):
            panoptic_img, segments_info = output["panoptic_seg"]
            panoptic_img = panoptic_img.cpu().numpy()
            if segments_info is None:
                # If "segments_info" is None, we assume "panoptic_img" is a
                # H*W int32 image storing the panoptic_id in the format of
                # category_id * label_divisor + instance_id. We reserve -1 for
                # VOID label, and add 1 to panoptic_img since the official
                # evaluation script uses 0 for VOID label.
                label_divisor = self._metadata.label_divisor
                segments_info = []
                for panoptic_label in np.unique(panoptic_img):
                    if panoptic_label == -1:
                        # VOID region.
                        continue
                    pred_class = panoptic_label // label_divisor
                    isthing = (
                        pred_class in self._metadata.thing_dataset_id_to_contiguous_id.values()
                    )
                    segments_info.append(
                        {
                            "id": int(panoptic_label) + 1,
                            "category_id": int(pred_class),
                            "isthing": bool(isthing),
                        }
                    )
                # Official evaluation script uses 0 for VOID label.
                panoptic_img += 1

            file_name = os.path.basename(input["file_name"])
            file_name_png = os.path.splitext(file_name)[0] + ".png"
            with io.BytesIO() as out:
                Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
                segments_info = [self._convert_category_id(x) for x in segments_info]
                self._predictions.append(
                    {
                        "image_id": input["image_id"],
                        "file_name": file_name_png,
                        "png_string": out.getvalue(),
                        "segments_info": segments_info,
                    }
                )

    def evaluate(self):
        comm.synchronize()

        self._predictions = comm.gather(self._predictions)
        self._predictions = list(itertools.chain(*self._predictions))
        if not comm.is_main_process():
            return

        # PanopticApi requires local files
        gt_json = PathManager.get_local_path(self._metadata.panoptic_json)
        gt_folder = PathManager.get_local_path(self._metadata.panoptic_root)

        with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
            logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
            for p in self._predictions:
                with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
                    f.write(p.pop("png_string"))

            with open(gt_json, "r") as f:
                json_data = json.load(f)
            json_data["annotations"] = self._predictions

            output_dir = self._output_dir or pred_dir
            predictions_json = os.path.join(output_dir, "predictions.json")
            with PathManager.open(predictions_json, "w") as f:
                f.write(json.dumps(json_data))

            from panopticapi.evaluation import pq_compute

            with contextlib.redirect_stdout(io.StringIO()):
                pq_res = pq_compute(
                    gt_json,
                    PathManager.get_local_path(predictions_json),
                    gt_folder=gt_folder,
                    pred_folder=pred_dir,
                )

        res = {}
        res["PQ"] = 100 * pq_res["All"]["pq"]
        res["SQ"] = 100 * pq_res["All"]["sq"]
        res["RQ"] = 100 * pq_res["All"]["rq"]
        res["PQ_th"] = 100 * pq_res["Things"]["pq"]
        res["SQ_th"] = 100 * pq_res["Things"]["sq"]
        res["RQ_th"] = 100 * pq_res["Things"]["rq"]
        res["PQ_st"] = 100 * pq_res["Stuff"]["pq"]
        res["SQ_st"] = 100 * pq_res["Stuff"]["sq"]
        res["RQ_st"] = 100 * pq_res["Stuff"]["rq"]

        results = OrderedDict({"panoptic_seg": res})
        _print_panoptic_results(pq_res)

        return results


def _print_panoptic_results(pq_res):
    headers = ["", "PQ", "SQ", "RQ", "#categories"]
    data = []
    for name in ["All", "Things", "Stuff"]:
        row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]]
        data.append(row)
    table = tabulate(
        data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
    )
    logger.info("Panoptic Evaluation Results:\n" + table)


if __name__ == "__main__":
    from annotator.oneformer.detectron2.utils.logger import setup_logger

    logger = setup_logger()
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--gt-json")
    parser.add_argument("--gt-dir")
    parser.add_argument("--pred-json")
    parser.add_argument("--pred-dir")
    args = parser.parse_args()

    from panopticapi.evaluation import pq_compute

    with contextlib.redirect_stdout(io.StringIO()):
        pq_res = pq_compute(
            args.gt_json, args.pred_json, gt_folder=args.gt_dir, pred_folder=args.pred_dir
        )
    _print_panoptic_results(pq_res)
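`COCOPanopticEvaluator.process` assumes panoptic ids encoded as `category_id * label_divisor + instance_id`, with -1 reserved for VOID. A small sketch of that decoding on made-up values (the `isthing` lookup against dataset metadata is omitted here):

```python
import numpy as np

label_divisor = 1000
# Hypothetical H*W panoptic map; -1 marks VOID regions.
panoptic_img = np.array([[2001, 2001, -1],
                         [2002,   17, 17]], dtype=np.int32)

segments_info = []
for panoptic_label in np.unique(panoptic_img):
    if panoptic_label == -1:
        continue  # VOID region
    pred_class = panoptic_label // label_divisor
    segments_info.append({"id": int(panoptic_label) + 1, "category_id": int(pred_class)})

print(segments_info)
# [{'id': 18, 'category_id': 0}, {'id': 2002, 'category_id': 2}, {'id': 2003, 'category_id': 2}]
```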
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/pascal_voc_evaluation.py
ADDED
@@ -0,0 +1,300 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import logging
import numpy as np
import os
import tempfile
import xml.etree.ElementTree as ET
from collections import OrderedDict, defaultdict
from functools import lru_cache
import torch

from annotator.oneformer.detectron2.data import MetadataCatalog
from annotator.oneformer.detectron2.utils import comm
from annotator.oneformer.detectron2.utils.file_io import PathManager

from .evaluator import DatasetEvaluator


class PascalVOCDetectionEvaluator(DatasetEvaluator):
    """
    Evaluate Pascal VOC style AP for Pascal VOC dataset.
    It contains a synchronization, therefore has to be called from all ranks.

    Note that the concept of AP can be implemented in different ways and may not
    produce identical results. This class mimics the implementation of the official
    Pascal VOC Matlab API, and should produce similar but not identical results to the
    official API.
    """

    def __init__(self, dataset_name):
        """
        Args:
            dataset_name (str): name of the dataset, e.g., "voc_2007_test"
        """
        self._dataset_name = dataset_name
        meta = MetadataCatalog.get(dataset_name)

        # Too many tiny files, download all to local for speed.
        annotation_dir_local = PathManager.get_local_path(
            os.path.join(meta.dirname, "Annotations/")
        )
        self._anno_file_template = os.path.join(annotation_dir_local, "{}.xml")
        self._image_set_path = os.path.join(meta.dirname, "ImageSets", "Main", meta.split + ".txt")
        self._class_names = meta.thing_classes
        assert meta.year in [2007, 2012], meta.year
        self._is_2007 = meta.year == 2007
        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)

    def reset(self):
        self._predictions = defaultdict(list)  # class name -> list of prediction strings

    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            image_id = input["image_id"]
            instances = output["instances"].to(self._cpu_device)
            boxes = instances.pred_boxes.tensor.numpy()
            scores = instances.scores.tolist()
            classes = instances.pred_classes.tolist()
            for box, score, cls in zip(boxes, scores, classes):
                xmin, ymin, xmax, ymax = box
                # The inverse of data loading logic in `datasets/pascal_voc.py`
                xmin += 1
                ymin += 1
                self._predictions[cls].append(
                    f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}"
                )

    def evaluate(self):
        """
        Returns:
            dict: has a key "bbox", whose value is a dict of "AP", "AP50", and "AP75".
        """
        all_predictions = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return
        predictions = defaultdict(list)
        for predictions_per_rank in all_predictions:
            for clsid, lines in predictions_per_rank.items():
                predictions[clsid].extend(lines)
        del all_predictions

        self._logger.info(
            "Evaluating {} using {} metric. "
            "Note that results do not use the official Matlab API.".format(
                self._dataset_name, 2007 if self._is_2007 else 2012
            )
        )

        with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname:
            res_file_template = os.path.join(dirname, "{}.txt")

            aps = defaultdict(list)  # iou -> ap per class
            for cls_id, cls_name in enumerate(self._class_names):
                lines = predictions.get(cls_id, [""])

                with open(res_file_template.format(cls_name), "w") as f:
                    f.write("\n".join(lines))

                for thresh in range(50, 100, 5):
                    rec, prec, ap = voc_eval(
                        res_file_template,
                        self._anno_file_template,
                        self._image_set_path,
                        cls_name,
                        ovthresh=thresh / 100.0,
                        use_07_metric=self._is_2007,
                    )
                    aps[thresh].append(ap * 100)

        ret = OrderedDict()
        mAP = {iou: np.mean(x) for iou, x in aps.items()}
        ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
        return ret


##############################################################################
#
# Below code is modified from
# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py
# --------------------------------------------------------
# Fast/er R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Bharath Hariharan
# --------------------------------------------------------

"""Python implementation of the PASCAL VOC devkit's AP evaluation code."""


@lru_cache(maxsize=None)
def parse_rec(filename):
    """Parse a PASCAL VOC xml file."""
    with PathManager.open(filename) as f:
        tree = ET.parse(f)
    objects = []
    for obj in tree.findall("object"):
        obj_struct = {}
        obj_struct["name"] = obj.find("name").text
        obj_struct["pose"] = obj.find("pose").text
        obj_struct["truncated"] = int(obj.find("truncated").text)
        obj_struct["difficult"] = int(obj.find("difficult").text)
        bbox = obj.find("bndbox")
        obj_struct["bbox"] = [
            int(bbox.find("xmin").text),
            int(bbox.find("ymin").text),
            int(bbox.find("xmax").text),
            int(bbox.find("ymax").text),
        ]
        objects.append(obj_struct)

    return objects


def voc_ap(rec, prec, use_07_metric=False):
    """Compute VOC AP given precision and recall. If use_07_metric is true, uses
    the VOC 07 11-point method (default:False).
    """
    if use_07_metric:
        # 11 point metric
        ap = 0.0
        for t in np.arange(0.0, 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.0
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.0], rec, [1.0]))
        mpre = np.concatenate(([0.0], prec, [0.0]))

        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])

    Top level function that does the PASCAL VOC evaluation.

    detpath: Path to detections
        detpath.format(classname) should produce the detection results file.
    annopath: Path to annotations
        annopath.format(imagename) should be the xml annotations file.
    imagesetfile: Text file containing the list of images, one image per line.
    classname: Category name (duh)
    [ovthresh]: Overlap threshold (default = 0.5)
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)
    """
    # assumes detections are in detpath.format(classname)
    # assumes annotations are in annopath.format(imagename)
    # assumes imagesetfile is a text file with each line an image name

    # first load gt
    # read list of images
    with PathManager.open(imagesetfile, "r") as f:
        lines = f.readlines()
    imagenames = [x.strip() for x in lines]

    # load annots
    recs = {}
    for imagename in imagenames:
        recs[imagename] = parse_rec(annopath.format(imagename))

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj["name"] == classname]
        bbox = np.array([x["bbox"] for x in R])
        difficult = np.array([x["difficult"] for x in R]).astype(bool)
        # difficult = np.array([False for x in R]).astype(bool)  # treat all "difficult" as GT
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det}

    # read dets
    detfile = detpath.format(classname)
    with open(detfile, "r") as f:
        lines = f.readlines()

    splitlines = [x.strip().split(" ") for x in lines]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]] for x in splitlines]).reshape(-1, 4)

    # sort by confidence
    sorted_ind = np.argsort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R["bbox"].astype(float)

        if BBGT.size > 0:
            # compute overlaps
            # intersection
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
            ih = np.maximum(iymax - iymin + 1.0, 0.0)
            inters = iw * ih

            # union
            uni = (
                (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
                + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
                - inters
            )

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            if not R["difficult"][jmax]:
                if not R["det"][jmax]:
                    tp[d] = 1.0
                    R["det"][jmax] = 1
                else:
                    fp[d] = 1.0
        else:
            fp[d] = 1.0

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap
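A quick sketch of `voc_ap` on a toy precision/recall curve, assuming the module is importable under the path this diff adds it to:

```python
import numpy as np

from annotator.oneformer.detectron2.evaluation.pascal_voc_evaluation import voc_ap

# Toy PR curve: precision 1.0 at recall 0.5, dropping to 0.5 at recall 1.0.
rec = np.array([0.5, 1.0])
prec = np.array([1.0, 0.5])

print(voc_ap(rec, prec, use_07_metric=False))  # area under the precision envelope, here 0.75
print(voc_ap(rec, prec, use_07_metric=True))   # VOC07 11-point approximation of the same curve
```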
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/rotated_coco_evaluation.py
ADDED
@@ -0,0 +1,207 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import json
import numpy as np
import os
import torch
from annotator.oneformer.pycocotools.cocoeval import COCOeval, maskUtils

from annotator.oneformer.detectron2.structures import BoxMode, RotatedBoxes, pairwise_iou_rotated
from annotator.oneformer.detectron2.utils.file_io import PathManager

from .coco_evaluation import COCOEvaluator


class RotatedCOCOeval(COCOeval):
    @staticmethod
    def is_rotated(box_list):
        if type(box_list) == np.ndarray:
            return box_list.shape[1] == 5
        elif type(box_list) == list:
            if box_list == []:  # cannot decide the box_dim
                return False
            return np.all(
                np.array(
                    [
                        (len(obj) == 5) and ((type(obj) == list) or (type(obj) == np.ndarray))
                        for obj in box_list
                    ]
                )
            )
        return False

    @staticmethod
    def boxlist_to_tensor(boxlist, output_box_dim):
        if type(boxlist) == np.ndarray:
            box_tensor = torch.from_numpy(boxlist)
        elif type(boxlist) == list:
            if boxlist == []:
                return torch.zeros((0, output_box_dim), dtype=torch.float32)
            else:
                box_tensor = torch.FloatTensor(boxlist)
        else:
            raise Exception("Unrecognized boxlist type")

        input_box_dim = box_tensor.shape[1]
        if input_box_dim != output_box_dim:
            if input_box_dim == 4 and output_box_dim == 5:
                box_tensor = BoxMode.convert(box_tensor, BoxMode.XYWH_ABS, BoxMode.XYWHA_ABS)
            else:
                raise Exception(
                    "Unable to convert from {}-dim box to {}-dim box".format(
                        input_box_dim, output_box_dim
                    )
                )
        return box_tensor

    def compute_iou_dt_gt(self, dt, gt, is_crowd):
        if self.is_rotated(dt) or self.is_rotated(gt):
            # TODO: take is_crowd into consideration
            assert all(c == 0 for c in is_crowd)
            dt = RotatedBoxes(self.boxlist_to_tensor(dt, output_box_dim=5))
            gt = RotatedBoxes(self.boxlist_to_tensor(gt, output_box_dim=5))
            return pairwise_iou_rotated(dt, gt)
        else:
            # This is the same as the classical COCO evaluation
            return maskUtils.iou(dt, gt, is_crowd)

    def computeIoU(self, imgId, catId):
        p = self.params
        if p.useCats:
            gt = self._gts[imgId, catId]
            dt = self._dts[imgId, catId]
        else:
            gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
            dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
        if len(gt) == 0 and len(dt) == 0:
            return []
        inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
        dt = [dt[i] for i in inds]
        if len(dt) > p.maxDets[-1]:
            dt = dt[0 : p.maxDets[-1]]

        assert p.iouType == "bbox", "unsupported iouType for iou computation"

        g = [g["bbox"] for g in gt]
        d = [d["bbox"] for d in dt]

        # compute iou between each dt and gt region
        iscrowd = [int(o["iscrowd"]) for o in gt]

        # Note: this function is copied from cocoeval.py in cocoapi
        # and the major difference is here.
        ious = self.compute_iou_dt_gt(d, g, iscrowd)
        return ious


class RotatedCOCOEvaluator(COCOEvaluator):
    """
    Evaluate object proposal/instance detection outputs using COCO-like metrics and APIs,
    with rotated boxes support.
    Note: this uses IOU only and does not consider angle differences.
    """

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
                It is a list of dict. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name", "image_id".
            outputs: the outputs of a COCO model. It is a list of dicts with key
                "instances" that contains :class:`Instances`.
        """
        for input, output in zip(inputs, outputs):
            prediction = {"image_id": input["image_id"]}

            if "instances" in output:
                instances = output["instances"].to(self._cpu_device)

                prediction["instances"] = self.instances_to_json(instances, input["image_id"])
            if "proposals" in output:
                prediction["proposals"] = output["proposals"].to(self._cpu_device)
            self._predictions.append(prediction)

    def instances_to_json(self, instances, img_id):
        num_instance = len(instances)
        if num_instance == 0:
            return []

        boxes = instances.pred_boxes.tensor.numpy()
        if boxes.shape[1] == 4:
            boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
        boxes = boxes.tolist()
        scores = instances.scores.tolist()
        classes = instances.pred_classes.tolist()

        results = []
        for k in range(num_instance):
            result = {
                "image_id": img_id,
                "category_id": classes[k],
                "bbox": boxes[k],
                "score": scores[k],
            }

            results.append(result)
        return results

    def _eval_predictions(self, predictions, img_ids=None):  # img_ids: unused
        """
        Evaluate predictions on the given tasks.
        Fill self._results with the metrics of the tasks.
        """
        self._logger.info("Preparing results for COCO format ...")
        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))

        # unmap the category ids for COCO
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            reverse_id_mapping = {
                v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
            }
            for result in coco_results:
                result["category_id"] = reverse_id_mapping[result["category_id"]]

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "coco_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(coco_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating predictions ...")

        assert self._tasks is None or set(self._tasks) == {
            "bbox"
        }, "[RotatedCOCOEvaluator] Only bbox evaluation is supported"
        coco_eval = (
            self._evaluate_predictions_on_coco(self._coco_api, coco_results)
            if len(coco_results) > 0
            else None  # cocoapi does not handle empty results very well
        )

        task = "bbox"
        res = self._derive_coco_results(
            coco_eval, task, class_names=self._metadata.get("thing_classes")
        )
        self._results[task] = res

    def _evaluate_predictions_on_coco(self, coco_gt, coco_results):
        """
        Evaluate the coco results using COCOEval API.
        """
        assert len(coco_results) > 0

        coco_dt = coco_gt.loadRes(coco_results)

        # Only bbox is supported for now
        coco_eval = RotatedCOCOeval(coco_gt, coco_dt, iouType="bbox")

        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        return coco_eval
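A minimal sketch of the rotated-IoU path taken in `compute_iou_dt_gt`, with made-up boxes; it relies on the same `RotatedBoxes` and `pairwise_iou_rotated` imports this file already uses:

```python
import torch

from annotator.oneformer.detectron2.structures import RotatedBoxes, pairwise_iou_rotated

# (x_center, y_center, width, height, angle in degrees) boxes; values are made up.
dt = RotatedBoxes(torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.0]]))
gt = RotatedBoxes(torch.tensor([[50.0, 50.0, 20.0, 10.0, 30.0]]))

ious = pairwise_iou_rotated(dt, gt)  # (num_dt, num_gt) IoU matrix
print(ious.shape, float(ious[0, 0]))
```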
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/sem_seg_evaluation.py
ADDED
@@ -0,0 +1,265 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import json
import logging
import numpy as np
import os
from collections import OrderedDict
from typing import Optional, Union
import annotator.oneformer.pycocotools.mask as mask_util
import torch
from PIL import Image

from annotator.oneformer.detectron2.data import DatasetCatalog, MetadataCatalog
from annotator.oneformer.detectron2.utils.comm import all_gather, is_main_process, synchronize
from annotator.oneformer.detectron2.utils.file_io import PathManager

from .evaluator import DatasetEvaluator

_CV2_IMPORTED = True
try:
    import cv2  # noqa
except ImportError:
    # OpenCV is an optional dependency at the moment
    _CV2_IMPORTED = False


def load_image_into_numpy_array(
    filename: str,
    copy: bool = False,
    dtype: Optional[Union[np.dtype, str]] = None,
) -> np.ndarray:
    with PathManager.open(filename, "rb") as f:
        array = np.array(Image.open(f), copy=copy, dtype=dtype)
    return array


class SemSegEvaluator(DatasetEvaluator):
    """
    Evaluate semantic segmentation metrics.
    """

    def __init__(
        self,
        dataset_name,
        distributed=True,
        output_dir=None,
        *,
        sem_seg_loading_fn=load_image_into_numpy_array,
        num_classes=None,
        ignore_label=None,
    ):
        """
        Args:
            dataset_name (str): name of the dataset to be evaluated.
            distributed (bool): if True, will collect results from all ranks for evaluation.
                Otherwise, will evaluate the results in the current process.
            output_dir (str): an output directory to dump results.
            sem_seg_loading_fn: function to read sem seg file and load into numpy array.
                Default provided, but projects can customize.
            num_classes, ignore_label: deprecated argument
        """
        self._logger = logging.getLogger(__name__)
        if num_classes is not None:
            self._logger.warn(
                "SemSegEvaluator(num_classes) is deprecated! It should be obtained from metadata."
            )
        if ignore_label is not None:
            self._logger.warn(
                "SemSegEvaluator(ignore_label) is deprecated! It should be obtained from metadata."
            )
        self._dataset_name = dataset_name
        self._distributed = distributed
        self._output_dir = output_dir

        self._cpu_device = torch.device("cpu")

        self.input_file_to_gt_file = {
            dataset_record["file_name"]: dataset_record["sem_seg_file_name"]
            for dataset_record in DatasetCatalog.get(dataset_name)
        }

        meta = MetadataCatalog.get(dataset_name)
        # Dict that maps contiguous training ids to COCO category ids
        try:
            c2d = meta.stuff_dataset_id_to_contiguous_id
            self._contiguous_id_to_dataset_id = {v: k for k, v in c2d.items()}
        except AttributeError:
            self._contiguous_id_to_dataset_id = None
        self._class_names = meta.stuff_classes
        self.sem_seg_loading_fn = sem_seg_loading_fn
        self._num_classes = len(meta.stuff_classes)
        if num_classes is not None:
            assert self._num_classes == num_classes, f"{self._num_classes} != {num_classes}"
        self._ignore_label = ignore_label if ignore_label is not None else meta.ignore_label

        # This is because cv2.erode did not work for int datatype. Only works for uint8.
        self._compute_boundary_iou = True
        if not _CV2_IMPORTED:
            self._compute_boundary_iou = False
            self._logger.warn(
                """Boundary IoU calculation requires OpenCV. B-IoU metrics are
                not going to be computed because OpenCV is not available to import."""
            )
        if self._num_classes >= np.iinfo(np.uint8).max:
            self._compute_boundary_iou = False
            self._logger.warn(
                f"""SemSegEvaluator(num_classes) is more than supported value for Boundary IoU calculation!
                B-IoU metrics are not going to be computed. Max allowed value (exclusive)
                for num_classes for calculating Boundary IoU is {np.iinfo(np.uint8).max}.
                The number of classes of dataset {self._dataset_name} is {self._num_classes}"""
            )

    def reset(self):
        self._conf_matrix = np.zeros((self._num_classes + 1, self._num_classes + 1), dtype=np.int64)
        self._b_conf_matrix = np.zeros(
            (self._num_classes + 1, self._num_classes + 1), dtype=np.int64
        )
        self._predictions = []

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a model.
                It is a list of dicts. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name".
            outputs: the outputs of a model. It is either list of semantic segmentation predictions
                (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic
                segmentation prediction in the same format.
        """
        for input, output in zip(inputs, outputs):
            output = output["sem_seg"].argmax(dim=0).to(self._cpu_device)
            pred = np.array(output, dtype=np.int)
            gt_filename = self.input_file_to_gt_file[input["file_name"]]
            gt = self.sem_seg_loading_fn(gt_filename, dtype=np.int)

            gt[gt == self._ignore_label] = self._num_classes

            self._conf_matrix += np.bincount(
                (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
                minlength=self._conf_matrix.size,
            ).reshape(self._conf_matrix.shape)

            if self._compute_boundary_iou:
                b_gt = self._mask_to_boundary(gt.astype(np.uint8))
                b_pred = self._mask_to_boundary(pred.astype(np.uint8))

                self._b_conf_matrix += np.bincount(
                    (self._num_classes + 1) * b_pred.reshape(-1) + b_gt.reshape(-1),
                    minlength=self._conf_matrix.size,
                ).reshape(self._conf_matrix.shape)

            self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))

    def evaluate(self):
        """
        Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):

        * Mean intersection-over-union averaged across classes (mIoU)
        * Frequency Weighted IoU (fwIoU)
        * Mean pixel accuracy averaged across classes (mACC)
        * Pixel Accuracy (pACC)
        """
        if self._distributed:
            synchronize()
            conf_matrix_list = all_gather(self._conf_matrix)
            b_conf_matrix_list = all_gather(self._b_conf_matrix)
            self._predictions = all_gather(self._predictions)
            self._predictions = list(itertools.chain(*self._predictions))
            if not is_main_process():
                return

            self._conf_matrix = np.zeros_like(self._conf_matrix)
            for conf_matrix in conf_matrix_list:
                self._conf_matrix += conf_matrix

            self._b_conf_matrix = np.zeros_like(self._b_conf_matrix)
            for b_conf_matrix in b_conf_matrix_list:
                self._b_conf_matrix += b_conf_matrix

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)
            file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(self._predictions))

        acc = np.full(self._num_classes, np.nan, dtype=np.float)
        iou = np.full(self._num_classes, np.nan, dtype=np.float)
        tp = self._conf_matrix.diagonal()[:-1].astype(np.float)
        pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float)
        class_weights = pos_gt / np.sum(pos_gt)
        pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float)
        acc_valid = pos_gt > 0
        acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
        union = pos_gt + pos_pred - tp
        iou_valid = np.logical_and(acc_valid, union > 0)
        iou[iou_valid] = tp[iou_valid] / union[iou_valid]
        macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
        miou = np.sum(iou[iou_valid]) / np.sum(iou_valid)
        fiou = np.sum(iou[iou_valid] * class_weights[iou_valid])
        pacc = np.sum(tp) / np.sum(pos_gt)

        if self._compute_boundary_iou:
            b_iou = np.full(self._num_classes, np.nan, dtype=np.float)
            b_tp = self._b_conf_matrix.diagonal()[:-1].astype(np.float)
            b_pos_gt = np.sum(self._b_conf_matrix[:-1, :-1], axis=0).astype(np.float)
            b_pos_pred = np.sum(self._b_conf_matrix[:-1, :-1], axis=1).astype(np.float)
            b_union = b_pos_gt + b_pos_pred - b_tp
            b_iou_valid = b_union > 0
            b_iou[b_iou_valid] = b_tp[b_iou_valid] / b_union[b_iou_valid]

        res = {}
        res["mIoU"] = 100 * miou
        res["fwIoU"] = 100 * fiou
        for i, name in enumerate(self._class_names):
            res[f"IoU-{name}"] = 100 * iou[i]
            if self._compute_boundary_iou:
                res[f"BoundaryIoU-{name}"] = 100 * b_iou[i]
                res[f"min(IoU, B-Iou)-{name}"] = 100 * min(iou[i], b_iou[i])
        res["mACC"] = 100 * macc
        res["pACC"] = 100 * pacc
        for i, name in enumerate(self._class_names):
            res[f"ACC-{name}"] = 100 * acc[i]

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(res, f)
        results = OrderedDict({"sem_seg": res})
        self._logger.info(results)
        return results

    def encode_json_sem_seg(self, sem_seg, input_file_name):
        """
        Convert semantic segmentation to COCO stuff format with segments encoded as RLEs.
        See http://cocodataset.org/#format-results
        """
        json_list = []
        for label in np.unique(sem_seg):
            if self._contiguous_id_to_dataset_id is not None:
                assert (
                    label in self._contiguous_id_to_dataset_id
                ), "Label {} is not in the metadata info for {}".format(label, self._dataset_name)
                dataset_id = self._contiguous_id_to_dataset_id[label]
            else:
                dataset_id = int(label)
            mask = (sem_seg == label).astype(np.uint8)
            mask_rle = mask_util.encode(np.array(mask[:, :, None], order="F"))[0]
            mask_rle["counts"] = mask_rle["counts"].decode("utf-8")
            json_list.append(
                {"file_name": input_file_name, "category_id": dataset_id, "segmentation": mask_rle}
            )
        return json_list

    def _mask_to_boundary(self, mask: np.ndarray, dilation_ratio=0.02):
        assert mask.ndim == 2, "mask_to_boundary expects a 2-dimensional image"
        h, w = mask.shape
        diag_len = np.sqrt(h**2 + w**2)
        dilation = max(1, int(round(dilation_ratio * diag_len)))
        kernel = np.ones((3, 3), dtype=np.uint8)

        padded_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
        eroded_mask_with_padding = cv2.erode(padded_mask, kernel, iterations=dilation)
        eroded_mask = eroded_mask_with_padding[1:-1, 1:-1]
        boundary = mask - eroded_mask
        return boundary
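All of the metrics in `SemSegEvaluator.evaluate` are derived from the (num_classes+1)^2 confusion matrix accumulated in `process`. A toy sketch of the same arithmetic for two classes plus the ignore bin (values made up):

```python
import numpy as np

# Rows = predicted class, columns = ground-truth class; the last row/column is the ignore bin.
conf = np.array([[50.0,  5.0, 0.0],
                 [10.0, 30.0, 0.0],
                 [ 0.0,  0.0, 5.0]])

tp = conf.diagonal()[:-1]              # true positives per class
pos_gt = conf[:-1, :-1].sum(axis=0)    # ground-truth pixels per class
pos_pred = conf[:-1, :-1].sum(axis=1)  # predicted pixels per class
union = pos_gt + pos_pred - tp

iou = tp / union           # per-class IoU
miou = iou.mean()          # mIoU
pacc = tp.sum() / pos_gt.sum()  # pixel accuracy
print(iou, miou, pacc)
```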
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/evaluation/testing.py
ADDED
@@ -0,0 +1,85 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
import pprint
import sys
from collections.abc import Mapping


def print_csv_format(results):
    """
    Print main metrics in a format similar to Detectron,
    so that they are easy to copypaste into a spreadsheet.

    Args:
        results (OrderedDict[dict]): task_name -> {metric -> score}
            unordered dict can also be printed, but in arbitrary order
    """
    assert isinstance(results, Mapping) or not len(results), results
    logger = logging.getLogger(__name__)
    for task, res in results.items():
        if isinstance(res, Mapping):
            # Don't print "AP-category" metrics since they are usually not tracked.
            important_res = [(k, v) for k, v in res.items() if "-" not in k]
            logger.info("copypaste: Task: {}".format(task))
            logger.info("copypaste: " + ",".join([k[0] for k in important_res]))
            logger.info("copypaste: " + ",".join(["{0:.4f}".format(k[1]) for k in important_res]))
        else:
            logger.info(f"copypaste: {task}={res}")


def verify_results(cfg, results):
    """
    Args:
        results (OrderedDict[dict]): task_name -> {metric -> score}

    Returns:
        bool: whether the verification succeeds or not
    """
    expected_results = cfg.TEST.EXPECTED_RESULTS
    if not len(expected_results):
        return True

    ok = True
    for task, metric, expected, tolerance in expected_results:
        actual = results[task].get(metric, None)
        if actual is None:
            ok = False
            continue
        if not np.isfinite(actual):
            ok = False
            continue
        diff = abs(actual - expected)
        if diff > tolerance:
            ok = False

    logger = logging.getLogger(__name__)
    if not ok:
        logger.error("Result verification failed!")
        logger.error("Expected Results: " + str(expected_results))
        logger.error("Actual Results: " + pprint.pformat(results))

        sys.exit(1)
    else:
        logger.info("Results verification passed.")
    return ok


def flatten_results_dict(results):
    """
    Expand a hierarchical dict of scalars into a flat dict of scalars.
    If results[k1][k2][k3] = v, the returned dict will have the entry
    {"k1/k2/k3": v}.

    Args:
        results (dict):
    """
    r = {}
    for k, v in results.items():
        if isinstance(v, Mapping):
            v = flatten_results_dict(v)
            for kk, vv in v.items():
                r[k + "/" + kk] = vv
        else:
            r[k] = v
    return r
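`flatten_results_dict` is what lets nested evaluation results be logged as flat scalars. A short usage sketch, assuming the module path this diff adds:

```python
from annotator.oneformer.detectron2.evaluation.testing import flatten_results_dict

nested = {"bbox": {"AP": 40.0, "AP50": 60.0}, "segm": {"AP": 35.0}}
print(flatten_results_dict(nested))
# {'bbox/AP': 40.0, 'bbox/AP50': 60.0, 'segm/AP': 35.0}
```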
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/README.md
ADDED
@@ -0,0 +1,15 @@

This directory contains code to prepare a detectron2 model for deployment.
Currently it supports exporting a detectron2 model to TorchScript, ONNX, or (deprecated) Caffe2 format.

Please see the [documentation](https://detectron2.readthedocs.io/tutorials/deployment.html) for its usage.


### Acknowledgements

Thanks to the Mobile Vision team at Facebook for developing the Caffe2 conversion tools.

Thanks to the Computing Platform Department - PAI team at Alibaba Group (@bddpqq, @chenbohua3) who
helped export Detectron2 models to TorchScript.

Thanks to the ONNX Converter team at Microsoft who helped export Detectron2 models to ONNX.
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/__init__.py
ADDED
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-

import warnings

from .flatten import TracingAdapter
from .torchscript import dump_torchscript_IR, scripting_with_instances

try:
    from caffe2.proto import caffe2_pb2 as _tmp
    from caffe2.python import core

    # caffe2 is optional
except ImportError:
    pass
else:
    from .api import *


# TODO: Update ONNX Opset version and run tests when a newer PyTorch is supported
STABLE_ONNX_OPSET_VERSION = 11


def add_export_config(cfg):
    warnings.warn(
        "add_export_config has been deprecated and behaves as no-op function.", DeprecationWarning
    )
    return cfg


__all__ = [k for k in globals().keys() if not k.startswith("_")]
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/api.py
ADDED
@@ -0,0 +1,230 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import logging
import os
import torch
from caffe2.proto import caffe2_pb2
from torch import nn

from annotator.oneformer.detectron2.config import CfgNode
from annotator.oneformer.detectron2.utils.file_io import PathManager

from .caffe2_inference import ProtobufDetectionModel
from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph

__all__ = [
    "Caffe2Model",
    "Caffe2Tracer",
]


class Caffe2Tracer:
    """
    Make a detectron2 model traceable with Caffe2 operators.
    This class creates a traceable version of a detectron2 model which:

    1. Rewrites parts of the model using ops in Caffe2. Note that some ops do
       not have GPU implementation in Caffe2.
    2. Removes post-processing and only produces raw layer outputs

    After making a traceable model, the class provides methods to export such a
    model to different deployment formats.
    The exported graph produced by this class takes two input tensors:

    1. (1, C, H, W) float "data" which is an image (usually in [0, 255]).
       (H, W) often has to be padded to multiple of 32 (depends on the model
       architecture).
    2. 1x3 float "im_info", each row of which is (height, width, 1.0).
       Height and width are true image shapes before padding.

    The class currently only supports models using builtin meta architectures.
    Batch inference is not supported, and contributions are welcome.
    """

    def __init__(self, cfg: CfgNode, model: nn.Module, inputs):
        """
        Args:
            cfg (CfgNode): a detectron2 config used to construct caffe2-compatible model.
            model (nn.Module): An original pytorch model. Must be among a few official models
                in detectron2 that can be converted to become caffe2-compatible automatically.
                Weights have to be already loaded to this model.
            inputs: sample inputs that the given model takes for inference.
                Will be used to trace the model. For most models, random inputs with
                no detected objects will not work as they lead to wrong traces.
        """
        assert isinstance(cfg, CfgNode), cfg
        assert isinstance(model, torch.nn.Module), type(model)

        # TODO make it support custom models, by passing in c2 model directly
        C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
        self.traceable_model = C2MetaArch(cfg, copy.deepcopy(model))
        self.inputs = inputs
        self.traceable_inputs = self.traceable_model.get_caffe2_inputs(inputs)

    def export_caffe2(self):
        """
        Export the model to Caffe2's protobuf format.
        The returned object can be saved with its :meth:`.save_protobuf()` method.
        The result can be loaded and executed using Caffe2 runtime.

        Returns:
            :class:`Caffe2Model`
        """
        from .caffe2_export import export_caffe2_detection_model

        predict_net, init_net = export_caffe2_detection_model(
            self.traceable_model, self.traceable_inputs
        )
        return Caffe2Model(predict_net, init_net)

    def export_onnx(self):
        """
        Export the model to ONNX format.
        Note that the exported model contains custom ops only available in caffe2, therefore it
        cannot be directly executed by other runtime (such as onnxruntime or TensorRT).
        Post-processing or transformation passes may be applied on the model to accommodate
        different runtimes, but we currently do not provide support for them.

        Returns:
            onnx.ModelProto: an onnx model.
        """
        from .caffe2_export import export_onnx_model as export_onnx_model_impl

        return export_onnx_model_impl(self.traceable_model, (self.traceable_inputs,))

    def export_torchscript(self):
        """
        Export the model to a ``torch.jit.TracedModule`` by tracing.
        The returned object can be saved to a file by ``.save()``.

        Returns:
            torch.jit.TracedModule: a torch TracedModule
        """
        logger = logging.getLogger(__name__)
        logger.info("Tracing the model with torch.jit.trace ...")
        with torch.no_grad():
            return torch.jit.trace(self.traceable_model, (self.traceable_inputs,))


class Caffe2Model(nn.Module):
    """
    A wrapper around the traced model in Caffe2's protobuf format.
    The exported graph has different inputs/outputs from the original Pytorch
    model, as explained in :class:`Caffe2Tracer`. This class wraps around the
    exported graph to simulate the same interface as the original Pytorch model.
    It also provides functions to save/load models in Caffe2's format.

    Examples:
    ::
        c2_model = Caffe2Tracer(cfg, torch_model, inputs).export_caffe2()
        inputs = [{"image": img_tensor_CHW}]
        outputs = c2_model(inputs)
        orig_outputs = torch_model(inputs)
    """

    def __init__(self, predict_net, init_net):
        super().__init__()
        self.eval()  # always in eval mode
        self._predict_net = predict_net
        self._init_net = init_net
        self._predictor = None

    __init__.__HIDE_SPHINX_DOC__ = True

    @property
    def predict_net(self):
        """
        caffe2.core.Net: the underlying caffe2 predict net
        """
        return self._predict_net

    @property
    def init_net(self):
        """
        caffe2.core.Net: the underlying caffe2 init net
        """
        return self._init_net

    def save_protobuf(self, output_dir):
        """
        Save the model as caffe2's protobuf format.
        It saves the following files:

        * "model.pb": definition of the graph. Can be visualized with
          tools like `netron <https://github.com/lutzroeder/netron>`_.
        * "model_init.pb": model parameters
        * "model.pbtxt": human-readable definition of the graph. Not
          needed for deployment.

        Args:
            output_dir (str): the output directory to save protobuf files.
        """
        logger = logging.getLogger(__name__)
        logger.info("Saving model to {} ...".format(output_dir))
        if not PathManager.exists(output_dir):
            PathManager.mkdirs(output_dir)

        with PathManager.open(os.path.join(output_dir, "model.pb"), "wb") as f:
            f.write(self._predict_net.SerializeToString())
        with PathManager.open(os.path.join(output_dir, "model.pbtxt"), "w") as f:
            f.write(str(self._predict_net))
        with PathManager.open(os.path.join(output_dir, "model_init.pb"), "wb") as f:
            f.write(self._init_net.SerializeToString())

    def save_graph(self, output_file, inputs=None):
        """
        Save the graph as SVG format.

        Args:
            output_file (str): a SVG file
            inputs: optional inputs given to the model.
                If given, the inputs will be used to run the graph to record
                shape of every tensor. The shape information will be
                saved together with the graph.
        """
        from .caffe2_export import run_and_save_graph

        if inputs is None:
            save_graph(self._predict_net, output_file, op_only=False)
        else:
            size_divisibility = get_pb_arg_vali(self._predict_net, "size_divisibility", 0)
            device = get_pb_arg_vals(self._predict_net, "device", b"cpu").decode("ascii")
            inputs = convert_batched_inputs_to_c2_format(inputs, size_divisibility, device)
            inputs = [x.cpu().numpy() for x in inputs]
            run_and_save_graph(self._predict_net, self._init_net, inputs, output_file)

    @staticmethod
    def load_protobuf(dir):
        """
        Args:
            dir (str): a directory used to save Caffe2Model with
                :meth:`save_protobuf`.
                The files "model.pb" and "model_init.pb" are needed.

        Returns:
            Caffe2Model: the caffe2 model loaded from this directory.
        """
        predict_net = caffe2_pb2.NetDef()
        with PathManager.open(os.path.join(dir, "model.pb"), "rb") as f:
            predict_net.ParseFromString(f.read())

        init_net = caffe2_pb2.NetDef()
        with PathManager.open(os.path.join(dir, "model_init.pb"), "rb") as f:
            init_net.ParseFromString(f.read())

        return Caffe2Model(predict_net, init_net)

    def __call__(self, inputs):
        """
        An interface that wraps around a Caffe2 model and mimics detectron2's models'
        input/output format. See details about the format at :doc:`/tutorials/models`.
        This is used to compare the outputs of caffe2 model with its original torch model.

        Due to the extra conversion between Pytorch/Caffe2, this method is not meant for
        benchmark. Because of the conversion, this method also has dependency
        on detectron2 in order to convert to detectron2's output format.
        """
        if self._predictor is None:
            self._predictor = ProtobufDetectionModel(self._predict_net, self._init_net)
        return self._predictor(inputs)
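For orientation, a hedged usage sketch of the classes above. It assumes a working detectron2 install with caffe2 available, and that `cfg`, `torch_model`, and `sample_inputs` (a config, a weight-loaded model, and traceable sample inputs) are supplied by the caller; none of those objects are part of this diff.

    # Hypothetical export workflow using Caffe2Tracer / Caffe2Model from api.py above.
    tracer = Caffe2Tracer(cfg, torch_model, sample_inputs)

    c2_model = tracer.export_caffe2()          # returns a Caffe2Model wrapper
    c2_model.save_protobuf("./caffe2_export")  # writes model.pb / model_init.pb / model.pbtxt

    reloaded = Caffe2Model.load_protobuf("./caffe2_export")
    outputs = reloaded(sample_inputs)          # mimics the original detectron2 input/output format

The same tracer can instead call export_onnx() or export_torchscript() when those deployment formats are wanted.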
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/c10.py
ADDED
@@ -0,0 +1,557 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import math
from typing import Dict
import torch
import torch.nn.functional as F

from annotator.oneformer.detectron2.layers import ShapeSpec, cat
from annotator.oneformer.detectron2.layers.roi_align_rotated import ROIAlignRotated
from annotator.oneformer.detectron2.modeling import poolers
from annotator.oneformer.detectron2.modeling.proposal_generator import rpn
from annotator.oneformer.detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference
from annotator.oneformer.detectron2.structures import Boxes, ImageList, Instances, Keypoints, RotatedBoxes

from .shared import alias, to_device


"""
This file contains caffe2-compatible implementation of several detectron2 components.
"""


class Caffe2Boxes(Boxes):
    """
    Representing a list of detectron2.structures.Boxes from minibatch, each box
    is represented by a 5d vector (batch index + 4 coordinates), or a 6d vector
    (batch index + 5 coordinates) for RotatedBoxes.
    """

    def __init__(self, tensor):
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dim() == 2 and tensor.size(-1) in [4, 5, 6], tensor.size()
        # TODO: make tensor immutable when dim is Nx5 for Boxes,
        # and Nx6 for RotatedBoxes?
        self.tensor = tensor


# TODO clean up this class, maybe just extend Instances
class InstancesList(object):
    """
    Tensor representation of a list of Instances object for a batch of images.

    When dealing with a batch of images with Caffe2 ops, a list of bboxes
    (instances) is usually represented by a single Tensor with size
    (sigma(Ni), 5) or (sigma(Ni), 4) plus a batch split Tensor. This class is
    for providing common functions to convert between these two representations.
    """

    def __init__(self, im_info, indices, extra_fields=None):
        # [N, 3] -> (H, W, Scale)
        self.im_info = im_info
        # [N,] -> indice of batch to which the instance belongs
        self.indices = indices
        # [N, ...]
        self.batch_extra_fields = extra_fields or {}

        self.image_size = self.im_info

    def get_fields(self):
        """like `get_fields` in the Instances object,
        but return each field in tensor representations"""
        ret = {}
        for k, v in self.batch_extra_fields.items():
            # if isinstance(v, torch.Tensor):
            #     tensor_rep = v
            # elif isinstance(v, (Boxes, Keypoints)):
            #     tensor_rep = v.tensor
            # else:
            #     raise ValueError("Can't find tensor representation for: {}".format())
            ret[k] = v
        return ret

    def has(self, name):
        return name in self.batch_extra_fields

    def set(self, name, value):
        # len(tensor) is a bad practice that generates ONNX constants during tracing.
        # Although not a problem for the `assert` statement below, torch ONNX exporter
        # still raises a misleading warning as it does not know this call comes from `assert`
        if isinstance(value, Boxes):
            data_len = value.tensor.shape[0]
        elif isinstance(value, torch.Tensor):
            data_len = value.shape[0]
        else:
            data_len = len(value)
        if len(self.batch_extra_fields):
            assert (
                len(self) == data_len
            ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
        self.batch_extra_fields[name] = value

    def __getattr__(self, name):
        if name not in self.batch_extra_fields:
            raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
        return self.batch_extra_fields[name]

    def __len__(self):
        return len(self.indices)

    def flatten(self):
        ret = []
        for _, v in self.batch_extra_fields.items():
            if isinstance(v, (Boxes, Keypoints)):
                ret.append(v.tensor)
            else:
                ret.append(v)
        return ret

    @staticmethod
    def to_d2_instances_list(instances_list):
        """
        Convert InstancesList to List[Instances]. The input `instances_list` can
        also be a List[Instances], in this case this method is a non-op.
        """
        if not isinstance(instances_list, InstancesList):
            assert all(isinstance(x, Instances) for x in instances_list)
            return instances_list

        ret = []
        for i, info in enumerate(instances_list.im_info):
            instances = Instances(torch.Size([int(info[0].item()), int(info[1].item())]))

            ids = instances_list.indices == i
            for k, v in instances_list.batch_extra_fields.items():
                if isinstance(v, torch.Tensor):
                    instances.set(k, v[ids])
                    continue
                elif isinstance(v, Boxes):
                    instances.set(k, v[ids, -4:])
                    continue

                target_type, tensor_source = v
                assert isinstance(tensor_source, torch.Tensor)
                assert tensor_source.shape[0] == instances_list.indices.shape[0]
                tensor_source = tensor_source[ids]

                if issubclass(target_type, Boxes):
                    instances.set(k, Boxes(tensor_source[:, -4:]))
                elif issubclass(target_type, Keypoints):
                    instances.set(k, Keypoints(tensor_source))
                elif issubclass(target_type, torch.Tensor):
                    instances.set(k, tensor_source)
                else:
                    raise ValueError("Can't handle target type: {}".format(target_type))

            ret.append(instances)
        return ret


class Caffe2Compatible(object):
    """
    A model can inherit this class to indicate that it can be traced and deployed with caffe2.
    """

    def _get_tensor_mode(self):
        return self._tensor_mode

    def _set_tensor_mode(self, v):
        self._tensor_mode = v

    tensor_mode = property(_get_tensor_mode, _set_tensor_mode)
    """
    If true, the model expects C2-style tensor only inputs/outputs format.
    """


class Caffe2RPN(Caffe2Compatible, rpn.RPN):
    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        ret = super(Caffe2Compatible, cls).from_config(cfg, input_shape)
        assert tuple(cfg.MODEL.RPN.BBOX_REG_WEIGHTS) == (1.0, 1.0, 1.0, 1.0) or tuple(
            cfg.MODEL.RPN.BBOX_REG_WEIGHTS
        ) == (1.0, 1.0, 1.0, 1.0, 1.0)
        return ret

    def _generate_proposals(
        self, images, objectness_logits_pred, anchor_deltas_pred, gt_instances=None
    ):
        assert isinstance(images, ImageList)
        if self.tensor_mode:
            im_info = images.image_sizes
        else:
            im_info = torch.tensor([[im_sz[0], im_sz[1], 1.0] for im_sz in images.image_sizes]).to(
                images.tensor.device
            )
        assert isinstance(im_info, torch.Tensor)

        rpn_rois_list = []
        rpn_roi_probs_list = []
        for scores, bbox_deltas, cell_anchors_tensor, feat_stride in zip(
            objectness_logits_pred,
            anchor_deltas_pred,
            [b for (n, b) in self.anchor_generator.cell_anchors.named_buffers()],
            self.anchor_generator.strides,
        ):
            scores = scores.detach()
            bbox_deltas = bbox_deltas.detach()

            rpn_rois, rpn_roi_probs = torch.ops._caffe2.GenerateProposals(
                scores,
                bbox_deltas,
                im_info,
                cell_anchors_tensor,
                spatial_scale=1.0 / feat_stride,
                pre_nms_topN=self.pre_nms_topk[self.training],
                post_nms_topN=self.post_nms_topk[self.training],
                nms_thresh=self.nms_thresh,
                min_size=self.min_box_size,
                # correct_transform_coords=True,  # deprecated argument
                angle_bound_on=True,  # Default
                angle_bound_lo=-180,
                angle_bound_hi=180,
                clip_angle_thresh=1.0,  # Default
                legacy_plus_one=False,
            )
            rpn_rois_list.append(rpn_rois)
            rpn_roi_probs_list.append(rpn_roi_probs)

        # For FPN in D2, in RPN all proposals from different levels are concatenated
        # together, ranked and picked by top post_nms_topk. Then in ROIPooler
        # it calculates level_assignments and calls the RoIAlign from
        # the corresponding level.

        if len(objectness_logits_pred) == 1:
            rpn_rois = rpn_rois_list[0]
            rpn_roi_probs = rpn_roi_probs_list[0]
        else:
            assert len(rpn_rois_list) == len(rpn_roi_probs_list)
            rpn_post_nms_topN = self.post_nms_topk[self.training]

            device = rpn_rois_list[0].device
            input_list = [to_device(x, "cpu") for x in (rpn_rois_list + rpn_roi_probs_list)]

            # TODO remove this after confirming rpn_max_level/rpn_min_level
            # is not needed in CollectRpnProposals.
            feature_strides = list(self.anchor_generator.strides)
            rpn_min_level = int(math.log2(feature_strides[0]))
            rpn_max_level = int(math.log2(feature_strides[-1]))
            assert (rpn_max_level - rpn_min_level + 1) == len(
                rpn_rois_list
            ), "CollectRpnProposals requires continuous levels"

            rpn_rois = torch.ops._caffe2.CollectRpnProposals(
                input_list,
                # NOTE: in current implementation, rpn_max_level and rpn_min_level
                # are not needed, only the subtraction of two matters and it
                # can be inferred from the number of inputs. Keep them now for
                # consistency.
                rpn_max_level=2 + len(rpn_rois_list) - 1,
                rpn_min_level=2,
                rpn_post_nms_topN=rpn_post_nms_topN,
            )
            rpn_rois = to_device(rpn_rois, device)
            rpn_roi_probs = []

        proposals = self.c2_postprocess(im_info, rpn_rois, rpn_roi_probs, self.tensor_mode)
        return proposals, {}

    def forward(self, images, features, gt_instances=None):
        assert not self.training
        features = [features[f] for f in self.in_features]
        objectness_logits_pred, anchor_deltas_pred = self.rpn_head(features)
        return self._generate_proposals(
            images,
            objectness_logits_pred,
            anchor_deltas_pred,
            gt_instances,
        )

    @staticmethod
    def c2_postprocess(im_info, rpn_rois, rpn_roi_probs, tensor_mode):
        proposals = InstancesList(
            im_info=im_info,
            indices=rpn_rois[:, 0],
            extra_fields={
                "proposal_boxes": Caffe2Boxes(rpn_rois),
                "objectness_logits": (torch.Tensor, rpn_roi_probs),
            },
        )
        if not tensor_mode:
            proposals = InstancesList.to_d2_instances_list(proposals)
        else:
            proposals = [proposals]
        return proposals


class Caffe2ROIPooler(Caffe2Compatible, poolers.ROIPooler):
    @staticmethod
    def c2_preprocess(box_lists):
        assert all(isinstance(x, Boxes) for x in box_lists)
        if all(isinstance(x, Caffe2Boxes) for x in box_lists):
            # input is pure-tensor based
            assert len(box_lists) == 1
            pooler_fmt_boxes = box_lists[0].tensor
        else:
            pooler_fmt_boxes = poolers.convert_boxes_to_pooler_format(box_lists)
        return pooler_fmt_boxes

    def forward(self, x, box_lists):
        assert not self.training

        pooler_fmt_boxes = self.c2_preprocess(box_lists)
        num_level_assignments = len(self.level_poolers)

        if num_level_assignments == 1:
            if isinstance(self.level_poolers[0], ROIAlignRotated):
                c2_roi_align = torch.ops._caffe2.RoIAlignRotated
                aligned = True
            else:
                c2_roi_align = torch.ops._caffe2.RoIAlign
                aligned = self.level_poolers[0].aligned

            x0 = x[0]
            if x0.is_quantized:
                x0 = x0.dequantize()

            out = c2_roi_align(
                x0,
                pooler_fmt_boxes,
                order="NCHW",
                spatial_scale=float(self.level_poolers[0].spatial_scale),
                pooled_h=int(self.output_size[0]),
                pooled_w=int(self.output_size[1]),
                sampling_ratio=int(self.level_poolers[0].sampling_ratio),
                aligned=aligned,
            )
            return out

        device = pooler_fmt_boxes.device
        assert (
            self.max_level - self.min_level + 1 == 4
        ), "Currently DistributeFpnProposals only supports 4 levels"
        fpn_outputs = torch.ops._caffe2.DistributeFpnProposals(
            to_device(pooler_fmt_boxes, "cpu"),
            roi_canonical_scale=self.canonical_box_size,
            roi_canonical_level=self.canonical_level,
            roi_max_level=self.max_level,
            roi_min_level=self.min_level,
            legacy_plus_one=False,
        )
        fpn_outputs = [to_device(x, device) for x in fpn_outputs]

        rois_fpn_list = fpn_outputs[:-1]
        rois_idx_restore_int32 = fpn_outputs[-1]

        roi_feat_fpn_list = []
        for roi_fpn, x_level, pooler in zip(rois_fpn_list, x, self.level_poolers):
            if isinstance(pooler, ROIAlignRotated):
                c2_roi_align = torch.ops._caffe2.RoIAlignRotated
                aligned = True
            else:
                c2_roi_align = torch.ops._caffe2.RoIAlign
                aligned = bool(pooler.aligned)

            if x_level.is_quantized:
                x_level = x_level.dequantize()

            roi_feat_fpn = c2_roi_align(
                x_level,
                roi_fpn,
                order="NCHW",
                spatial_scale=float(pooler.spatial_scale),
                pooled_h=int(self.output_size[0]),
                pooled_w=int(self.output_size[1]),
                sampling_ratio=int(pooler.sampling_ratio),
                aligned=aligned,
            )
            roi_feat_fpn_list.append(roi_feat_fpn)

        roi_feat_shuffled = cat(roi_feat_fpn_list, dim=0)
        assert roi_feat_shuffled.numel() > 0 and rois_idx_restore_int32.numel() > 0, (
            "Caffe2 export requires tracing with a model checkpoint + input that can produce valid"
            " detections. But no detections were obtained with the given checkpoint and input!"
        )
        roi_feat = torch.ops._caffe2.BatchPermutation(roi_feat_shuffled, rois_idx_restore_int32)
        return roi_feat


class Caffe2FastRCNNOutputsInference:
    def __init__(self, tensor_mode):
        self.tensor_mode = tensor_mode  # whether the output is caffe2 tensor mode

    def __call__(self, box_predictor, predictions, proposals):
        """equivalent to FastRCNNOutputLayers.inference"""
        num_classes = box_predictor.num_classes
        score_thresh = box_predictor.test_score_thresh
        nms_thresh = box_predictor.test_nms_thresh
        topk_per_image = box_predictor.test_topk_per_image
        is_rotated = len(box_predictor.box2box_transform.weights) == 5

        if is_rotated:
            box_dim = 5
            assert box_predictor.box2box_transform.weights[4] == 1, (
                "The weights for Rotated BBoxTransform in C2 have only 4 dimensions,"
                + " thus enforcing the angle weight to be 1 for now"
            )
            box2box_transform_weights = box_predictor.box2box_transform.weights[:4]
        else:
            box_dim = 4
            box2box_transform_weights = box_predictor.box2box_transform.weights

        class_logits, box_regression = predictions
        if num_classes + 1 == class_logits.shape[1]:
            class_prob = F.softmax(class_logits, -1)
        else:
            assert num_classes == class_logits.shape[1]
            class_prob = F.sigmoid(class_logits)
            # BoxWithNMSLimit will infer num_classes from the shape of the class_prob
            # So append a zero column as placeholder for the background class
            class_prob = torch.cat((class_prob, torch.zeros(class_prob.shape[0], 1)), dim=1)

        assert box_regression.shape[1] % box_dim == 0
        cls_agnostic_bbox_reg = box_regression.shape[1] // box_dim == 1

        input_tensor_mode = proposals[0].proposal_boxes.tensor.shape[1] == box_dim + 1

        proposal_boxes = proposals[0].proposal_boxes
        if isinstance(proposal_boxes, Caffe2Boxes):
            rois = Caffe2Boxes.cat([p.proposal_boxes for p in proposals])
        elif isinstance(proposal_boxes, RotatedBoxes):
            rois = RotatedBoxes.cat([p.proposal_boxes for p in proposals])
        elif isinstance(proposal_boxes, Boxes):
            rois = Boxes.cat([p.proposal_boxes for p in proposals])
        else:
            raise NotImplementedError(
                'Expected proposals[0].proposal_boxes to be type "Boxes", '
                f"instead got {type(proposal_boxes)}"
            )

        device, dtype = rois.tensor.device, rois.tensor.dtype
        if input_tensor_mode:
            im_info = proposals[0].image_size
            rois = rois.tensor
        else:
            im_info = torch.tensor(
                [[sz[0], sz[1], 1.0] for sz in [x.image_size for x in proposals]]
            )
            batch_ids = cat(
                [
                    torch.full((b, 1), i, dtype=dtype, device=device)
                    for i, b in enumerate(len(p) for p in proposals)
                ],
                dim=0,
            )
            rois = torch.cat([batch_ids, rois.tensor], dim=1)

        roi_pred_bbox, roi_batch_splits = torch.ops._caffe2.BBoxTransform(
            to_device(rois, "cpu"),
            to_device(box_regression, "cpu"),
            to_device(im_info, "cpu"),
            weights=box2box_transform_weights,
            apply_scale=True,
            rotated=is_rotated,
            angle_bound_on=True,
            angle_bound_lo=-180,
            angle_bound_hi=180,
            clip_angle_thresh=1.0,
            legacy_plus_one=False,
        )
        roi_pred_bbox = to_device(roi_pred_bbox, device)
        roi_batch_splits = to_device(roi_batch_splits, device)

        nms_outputs = torch.ops._caffe2.BoxWithNMSLimit(
            to_device(class_prob, "cpu"),
            to_device(roi_pred_bbox, "cpu"),
            to_device(roi_batch_splits, "cpu"),
            score_thresh=float(score_thresh),
            nms=float(nms_thresh),
            detections_per_im=int(topk_per_image),
            soft_nms_enabled=False,
            soft_nms_method="linear",
            soft_nms_sigma=0.5,
            soft_nms_min_score_thres=0.001,
            rotated=is_rotated,
            cls_agnostic_bbox_reg=cls_agnostic_bbox_reg,
            input_boxes_include_bg_cls=False,
            output_classes_include_bg_cls=False,
            legacy_plus_one=False,
        )
        roi_score_nms = to_device(nms_outputs[0], device)
        roi_bbox_nms = to_device(nms_outputs[1], device)
        roi_class_nms = to_device(nms_outputs[2], device)
        roi_batch_splits_nms = to_device(nms_outputs[3], device)
        roi_keeps_nms = to_device(nms_outputs[4], device)
        roi_keeps_size_nms = to_device(nms_outputs[5], device)
        if not self.tensor_mode:
            roi_class_nms = roi_class_nms.to(torch.int64)

        roi_batch_ids = cat(
            [
                torch.full((b, 1), i, dtype=dtype, device=device)
                for i, b in enumerate(int(x.item()) for x in roi_batch_splits_nms)
            ],
            dim=0,
        )

        roi_class_nms = alias(roi_class_nms, "class_nms")
        roi_score_nms = alias(roi_score_nms, "score_nms")
        roi_bbox_nms = alias(roi_bbox_nms, "bbox_nms")
        roi_batch_splits_nms = alias(roi_batch_splits_nms, "batch_splits_nms")
        roi_keeps_nms = alias(roi_keeps_nms, "keeps_nms")
        roi_keeps_size_nms = alias(roi_keeps_size_nms, "keeps_size_nms")

        results = InstancesList(
            im_info=im_info,
            indices=roi_batch_ids[:, 0],
            extra_fields={
                "pred_boxes": Caffe2Boxes(roi_bbox_nms),
                "scores": roi_score_nms,
                "pred_classes": roi_class_nms,
            },
        )

        if not self.tensor_mode:
            results = InstancesList.to_d2_instances_list(results)
            batch_splits = roi_batch_splits_nms.int().tolist()
            kept_indices = list(roi_keeps_nms.to(torch.int64).split(batch_splits))
        else:
            results = [results]
            kept_indices = [roi_keeps_nms]

        return results, kept_indices


class Caffe2MaskRCNNInference:
    def __call__(self, pred_mask_logits, pred_instances):
        """equivalent to mask_head.mask_rcnn_inference"""
        if all(isinstance(x, InstancesList) for x in pred_instances):
            assert len(pred_instances) == 1
            mask_probs_pred = pred_mask_logits.sigmoid()
            mask_probs_pred = alias(mask_probs_pred, "mask_fcn_probs")
            pred_instances[0].set("pred_masks", mask_probs_pred)
        else:
            mask_rcnn_inference(pred_mask_logits, pred_instances)


class Caffe2KeypointRCNNInference:
    def __init__(self, use_heatmap_max_keypoint):
        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint

    def __call__(self, pred_keypoint_logits, pred_instances):
        # just return the keypoint heatmap for now,
        # there will be option to call HeatmapMaxKeypointOp
        output = alias(pred_keypoint_logits, "kps_score")
        if all(isinstance(x, InstancesList) for x in pred_instances):
            assert len(pred_instances) == 1
            if self.use_heatmap_max_keypoint:
                device = output.device
                output = torch.ops._caffe2.HeatmapMaxKeypoint(
                    to_device(output, "cpu"),
                    pred_instances[0].pred_boxes.tensor,
                    should_output_softmax=True,  # worth making it configurable?
                )
                output = to_device(output, device)
                output = alias(output, "keypoints_out")
            pred_instances[0].set("pred_keypoints", output)
        return pred_keypoint_logits
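The core idea behind InstancesList.to_d2_instances_list, slicing one batched RoI tensor back into per-image groups using the batch-index column, can be seen in isolation with plain torch. The toy values below are hypothetical and independent of the classes in this file.

    import torch

    # toy batched RoIs: column 0 is the batch index, columns 1-4 are box coordinates
    rois = torch.tensor([
        [0., 10., 10., 50., 50.],
        [0., 20., 20., 60., 60.],
        [1.,  5.,  5., 40., 40.],
    ])
    im_info = torch.tensor([[480., 640., 1.0], [480., 640., 1.0]])  # (H, W, scale) per image

    per_image_boxes = []
    for i in range(im_info.shape[0]):
        ids = rois[:, 0] == i          # same selection trick as to_d2_instances_list
        per_image_boxes.append(rois[ids, -4:])

    print([b.shape[0] for b in per_image_boxes])  # [2, 1]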
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/caffe2_export.py
ADDED
@@ -0,0 +1,203 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import copy
import io
import logging
import numpy as np
from typing import List
import onnx
import onnx.optimizer
import torch
from caffe2.proto import caffe2_pb2
from caffe2.python import core
from caffe2.python.onnx.backend import Caffe2Backend
from tabulate import tabulate
from termcolor import colored
from torch.onnx import OperatorExportTypes

from .shared import (
    ScopedWS,
    construct_init_net_from_params,
    fuse_alias_placeholder,
    fuse_copy_between_cpu_and_gpu,
    get_params_from_init_net,
    group_norm_replace_aten_with_caffe2,
    infer_device_type,
    remove_dead_end_ops,
    remove_reshape_for_fc,
    save_graph,
)

logger = logging.getLogger(__name__)


def export_onnx_model(model, inputs):
    """
    Trace and export a model to onnx format.

    Args:
        model (nn.Module):
        inputs (tuple[args]): the model will be called by `model(*inputs)`

    Returns:
        an onnx model
    """
    assert isinstance(model, torch.nn.Module)

    # make sure all modules are in eval mode, onnx may change the training state
    # of the module if the states are not consistent
    def _check_eval(module):
        assert not module.training

    model.apply(_check_eval)

    # Export the model to ONNX
    with torch.no_grad():
        with io.BytesIO() as f:
            torch.onnx.export(
                model,
                inputs,
                f,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                # verbose=True,  # NOTE: uncomment this for debugging
                # export_params=True,
            )
            onnx_model = onnx.load_from_string(f.getvalue())

    return onnx_model


def _op_stats(net_def):
    type_count = {}
    for t in [op.type for op in net_def.op]:
        type_count[t] = type_count.get(t, 0) + 1
    type_count_list = sorted(type_count.items(), key=lambda kv: kv[0])  # alphabet
    type_count_list = sorted(type_count_list, key=lambda kv: -kv[1])  # count
    return "\n".join("{:>4}x {}".format(count, name) for name, count in type_count_list)


def _assign_device_option(
    predict_net: caffe2_pb2.NetDef, init_net: caffe2_pb2.NetDef, tensor_inputs: List[torch.Tensor]
):
    """
    ONNX exported network doesn't have concept of device, assign necessary
    device option for each op in order to make it runnable on GPU runtime.
    """

    def _get_device_type(torch_tensor):
        assert torch_tensor.device.type in ["cpu", "cuda"]
        assert torch_tensor.device.index == 0
        return torch_tensor.device.type

    def _assign_op_device_option(net_proto, net_ssa, blob_device_types):
        for op, ssa_i in zip(net_proto.op, net_ssa):
            if op.type in ["CopyCPUToGPU", "CopyGPUToCPU"]:
                op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))
            else:
                devices = [blob_device_types[b] for b in ssa_i[0] + ssa_i[1]]
                assert all(d == devices[0] for d in devices)
                if devices[0] == "cuda":
                    op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))

    # update ops in predict_net
    predict_net_input_device_types = {
        (name, 0): _get_device_type(tensor)
        for name, tensor in zip(predict_net.external_input, tensor_inputs)
    }
    predict_net_device_types = infer_device_type(
        predict_net, known_status=predict_net_input_device_types, device_name_style="pytorch"
    )
    predict_net_ssa, _ = core.get_ssa(predict_net)
    _assign_op_device_option(predict_net, predict_net_ssa, predict_net_device_types)

    # update ops in init_net
    init_net_ssa, versions = core.get_ssa(init_net)
    init_net_output_device_types = {
        (name, versions[name]): predict_net_device_types[(name, 0)]
        for name in init_net.external_output
    }
    init_net_device_types = infer_device_type(
        init_net, known_status=init_net_output_device_types, device_name_style="pytorch"
    )
    _assign_op_device_option(init_net, init_net_ssa, init_net_device_types)


def export_caffe2_detection_model(model: torch.nn.Module, tensor_inputs: List[torch.Tensor]):
    """
    Export a caffe2-compatible Detectron2 model to caffe2 format via ONNX.

    Args:
        model: a caffe2-compatible version of detectron2 model, defined in caffe2_modeling.py
        tensor_inputs: a list of tensors that caffe2 model takes as input.
    """
    model = copy.deepcopy(model)
    assert isinstance(model, torch.nn.Module)
    assert hasattr(model, "encode_additional_info")

    # Export via ONNX
    logger.info(
        "Exporting a {} model via ONNX ...".format(type(model).__name__)
        + " Some warnings from ONNX are expected and are usually not to worry about."
    )
    onnx_model = export_onnx_model(model, (tensor_inputs,))
    # Convert ONNX model to Caffe2 protobuf
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
    ops_table = [[op.type, op.input, op.output] for op in predict_net.op]
    table = tabulate(ops_table, headers=["type", "input", "output"], tablefmt="pipe")
    logger.info(
        "ONNX export Done. Exported predict_net (before optimizations):\n" + colored(table, "cyan")
    )

    # Apply protobuf optimization
    fuse_alias_placeholder(predict_net, init_net)
    if any(t.device.type != "cpu" for t in tensor_inputs):
        fuse_copy_between_cpu_and_gpu(predict_net)
        remove_dead_end_ops(init_net)
        _assign_device_option(predict_net, init_net, tensor_inputs)
    params, device_options = get_params_from_init_net(init_net)
    predict_net, params = remove_reshape_for_fc(predict_net, params)
    init_net = construct_init_net_from_params(params, device_options)
    group_norm_replace_aten_with_caffe2(predict_net)

    # Record necessary information for running the pb model in Detectron2 system.
    model.encode_additional_info(predict_net, init_net)

    logger.info("Operators used in predict_net: \n{}".format(_op_stats(predict_net)))
    logger.info("Operators used in init_net: \n{}".format(_op_stats(init_net)))

    return predict_net, init_net


def run_and_save_graph(predict_net, init_net, tensor_inputs, graph_save_path):
    """
    Run the caffe2 model on given inputs, recording the shapes and drawing the graph.

    predict_net/init_net: caffe2 model.
    tensor_inputs: a list of tensors that caffe2 model takes as input.
    graph_save_path: path for saving graph of exported model.
    """

    logger.info("Saving graph of ONNX exported model to {} ...".format(graph_save_path))
    save_graph(predict_net, graph_save_path, op_only=False)

    # Run the exported Caffe2 net
    logger.info("Running ONNX exported model ...")
    with ScopedWS("__ws_tmp__", True) as ws:
        ws.RunNetOnce(init_net)
        initialized_blobs = set(ws.Blobs())
        uninitialized = [inp for inp in predict_net.external_input if inp not in initialized_blobs]
        for name, blob in zip(uninitialized, tensor_inputs):
            ws.FeedBlob(name, blob)

        try:
            ws.RunNetOnce(predict_net)
        except RuntimeError as e:
            logger.warning("Encountered RuntimeError: \n{}".format(str(e)))

        ws_blobs = {b: ws.FetchBlob(b) for b in ws.Blobs()}
        blob_sizes = {b: ws_blobs[b].shape for b in ws_blobs if isinstance(ws_blobs[b], np.ndarray)}

        logger.info("Saving graph with blob shapes to {} ...".format(graph_save_path))
        save_graph(predict_net, graph_save_path, op_only=False, blob_sizes=blob_sizes)

        return ws_blobs
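The _op_stats helper above just builds a frequency table of operator types, sorted by count and then alphabetically; the same two-pass sort works on any list of strings. The operator names below are toy values, not taken from a real export.

    type_count = {}
    for t in ["Conv", "Relu", "Conv", "RoIAlign", "Conv", "Relu"]:  # toy op types
        type_count[t] = type_count.get(t, 0) + 1

    type_count_list = sorted(type_count.items(), key=lambda kv: kv[0])   # alphabetical tie-break
    type_count_list = sorted(type_count_list, key=lambda kv: -kv[1])     # most frequent first
    print("\n".join("{:>4}x {}".format(count, name) for name, count in type_count_list))
    #    3x Conv
    #    2x Relu
    #    1x RoIAlign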
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/caffe2_inference.py
ADDED
@@ -0,0 +1,161 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import logging
import numpy as np
from itertools import count
import torch
from caffe2.proto import caffe2_pb2
from caffe2.python import core

from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
from .shared import ScopedWS, get_pb_arg_vali, get_pb_arg_vals, infer_device_type

logger = logging.getLogger(__name__)


# ===== ref: mobile-vision predictor's 'Caffe2Wrapper' class ======
class ProtobufModel(torch.nn.Module):
    """
    Wrapper of a caffe2's protobuf model.
    It works just like nn.Module, but running caffe2 under the hood.
    Input/Output are tuple[tensor] that match the caffe2 net's external_input/output.
    """

    _ids = count(0)

    def __init__(self, predict_net, init_net):
        logger.info(f"Initializing ProtobufModel for: {predict_net.name} ...")
        super().__init__()
        assert isinstance(predict_net, caffe2_pb2.NetDef)
        assert isinstance(init_net, caffe2_pb2.NetDef)
        # create unique temporary workspace for each instance
        self.ws_name = "__tmp_ProtobufModel_{}__".format(next(self._ids))
        self.net = core.Net(predict_net)

        logger.info("Running init_net once to fill the parameters ...")
        with ScopedWS(self.ws_name, is_reset=True, is_cleanup=False) as ws:
            ws.RunNetOnce(init_net)
            uninitialized_external_input = []
            for blob in self.net.Proto().external_input:
                if blob not in ws.Blobs():
                    uninitialized_external_input.append(blob)
                    ws.CreateBlob(blob)
            ws.CreateNet(self.net)

        self._error_msgs = set()
        self._input_blobs = uninitialized_external_input

    def _infer_output_devices(self, inputs):
        """
        Returns:
            list[str]: list of device for each external output
        """

        def _get_device_type(torch_tensor):
            assert torch_tensor.device.type in ["cpu", "cuda"]
            assert torch_tensor.device.index == 0
            return torch_tensor.device.type

        predict_net = self.net.Proto()
        input_device_types = {
            (name, 0): _get_device_type(tensor) for name, tensor in zip(self._input_blobs, inputs)
        }
        device_type_map = infer_device_type(
            predict_net, known_status=input_device_types, device_name_style="pytorch"
        )
        ssa, versions = core.get_ssa(predict_net)
        versioned_outputs = [(name, versions[name]) for name in predict_net.external_output]
        output_devices = [device_type_map[outp] for outp in versioned_outputs]
        return output_devices

    def forward(self, inputs):
        """
        Args:
            inputs (tuple[torch.Tensor])

        Returns:
            tuple[torch.Tensor]
        """
        assert len(inputs) == len(self._input_blobs), (
            f"Length of inputs ({len(inputs)}) "
            f"doesn't match the required input blobs: {self._input_blobs}"
        )

        with ScopedWS(self.ws_name, is_reset=False, is_cleanup=False) as ws:
            for b, tensor in zip(self._input_blobs, inputs):
                ws.FeedBlob(b, tensor)

            try:
                ws.RunNet(self.net.Proto().name)
            except RuntimeError as e:
                if not str(e) in self._error_msgs:
                    self._error_msgs.add(str(e))
                    logger.warning("Encountered new RuntimeError: \n{}".format(str(e)))
                logger.warning("Catch the error and use partial results.")

            c2_outputs = [ws.FetchBlob(b) for b in self.net.Proto().external_output]
            # Remove outputs of current run, this is necessary in order to
            # prevent fetching the result from previous run if the model fails
            # in the middle.
            for b in self.net.Proto().external_output:
                # Needs to create uninitialized blob to make the net runnable.
                # This is "equivalent" to: ws.RemoveBlob(b) then ws.CreateBlob(b),
                # but there's no such API.
                ws.FeedBlob(b, f"{b}, a C++ native class of type nullptr (uninitialized).")

        # Cast output to torch.Tensor on the desired device
        output_devices = (
            self._infer_output_devices(inputs)
            if any(t.device.type != "cpu" for t in inputs)
            else ["cpu" for _ in self.net.Proto().external_output]
        )

        outputs = []
        for name, c2_output, device in zip(
            self.net.Proto().external_output, c2_outputs, output_devices
        ):
            if not isinstance(c2_output, np.ndarray):
                raise RuntimeError(
                    "Invalid output for blob {}, received: {}".format(name, c2_output)
                )
            outputs.append(torch.tensor(c2_output).to(device=device))
        return tuple(outputs)


class ProtobufDetectionModel(torch.nn.Module):
    """
    A class that works just like a pytorch meta arch in terms of inference, but running
    caffe2 model under the hood.
    """

    def __init__(self, predict_net, init_net, *, convert_outputs=None):
        """
        Args:
            predict_net, init_net (core.Net): caffe2 nets
            convert_outputs (callable): a function that converts caffe2
                outputs to the same format of the original pytorch model.
                By default, use the one defined in the caffe2 meta_arch.
        """
        super().__init__()
        self.protobuf_model = ProtobufModel(predict_net, init_net)
        self.size_divisibility = get_pb_arg_vali(predict_net, "size_divisibility", 0)
        self.device = get_pb_arg_vals(predict_net, "device", b"cpu").decode("ascii")

        if convert_outputs is None:
            meta_arch = get_pb_arg_vals(predict_net, "meta_architecture", b"GeneralizedRCNN")
            meta_arch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[meta_arch.decode("ascii")]
            self._convert_outputs = meta_arch.get_outputs_converter(predict_net, init_net)
        else:
            self._convert_outputs = convert_outputs

    def _convert_inputs(self, batched_inputs):
        # currently all models convert inputs in the same way
        return convert_batched_inputs_to_c2_format(
            batched_inputs, self.size_divisibility, self.device
        )

    def forward(self, batched_inputs):
        c2_inputs = self._convert_inputs(batched_inputs)
        c2_results = self.protobuf_model(c2_inputs)
        c2_results = dict(zip(self.protobuf_model.net.Proto().external_output, c2_results))
        return self._convert_outputs(batched_inputs, c2_inputs, c2_results)
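ProtobufModel gives every instance its own caffe2 workspace by drawing from a class-level itertools.count; the naming pattern can be reproduced without caffe2. The class below is a stand-in for illustration only, not part of the diff.

    from itertools import count

    class WorkspaceNamer:
        _ids = count(0)  # shared across instances, like ProtobufModel._ids

        def __init__(self):
            self.ws_name = "__tmp_ProtobufModel_{}__".format(next(self._ids))

    print(WorkspaceNamer().ws_name)  # __tmp_ProtobufModel_0__
    print(WorkspaceNamer().ws_name)  # __tmp_ProtobufModel_1__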
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/caffe2_modeling.py
ADDED
@@ -0,0 +1,419 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import functools
import io
import struct
import types
import torch

from annotator.oneformer.detectron2.modeling import meta_arch
from annotator.oneformer.detectron2.modeling.box_regression import Box2BoxTransform
from annotator.oneformer.detectron2.modeling.roi_heads import keypoint_head
from annotator.oneformer.detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes

from .c10 import Caffe2Compatible
from .caffe2_patch import ROIHeadsPatcher, patch_generalized_rcnn
from .shared import (
    alias,
    check_set_pb_arg,
    get_pb_arg_floats,
    get_pb_arg_valf,
    get_pb_arg_vali,
    get_pb_arg_vals,
    mock_torch_nn_functional_interpolate,
)


def assemble_rcnn_outputs_by_name(image_sizes, tensor_outputs, force_mask_on=False):
    """
    A function to assemble caffe2 model's outputs (i.e. Dict[str, Tensor])
    to detectron2's format (i.e. list of Instances instance).
    This only works when the model follows the Caffe2 detectron's naming convention.

    Args:
        image_sizes (List[List[int, int]]): [H, W] of every image.
        tensor_outputs (Dict[str, Tensor]): external_output to its tensor.

        force_mask_on (Bool): if true, it makes sure there'll be pred_masks even
            if the mask is not found from tensor_outputs (usually due to model crash)
    """

    results = [Instances(image_size) for image_size in image_sizes]

    batch_splits = tensor_outputs.get("batch_splits", None)
    if batch_splits:
        raise NotImplementedError()
    assert len(image_sizes) == 1
    result = results[0]

    bbox_nms = tensor_outputs["bbox_nms"]
    score_nms = tensor_outputs["score_nms"]
    class_nms = tensor_outputs["class_nms"]
    # Detection will always succeed because Conv supports 0-batch
    assert bbox_nms is not None
    assert score_nms is not None
    assert class_nms is not None
    if bbox_nms.shape[1] == 5:
        result.pred_boxes = RotatedBoxes(bbox_nms)
    else:
        result.pred_boxes = Boxes(bbox_nms)
    result.scores = score_nms
    result.pred_classes = class_nms.to(torch.int64)

    mask_fcn_probs = tensor_outputs.get("mask_fcn_probs", None)
    if mask_fcn_probs is not None:
        # finish the mask pred
        mask_probs_pred = mask_fcn_probs
        num_masks = mask_probs_pred.shape[0]
        class_pred = result.pred_classes
        indices = torch.arange(num_masks, device=class_pred.device)
        mask_probs_pred = mask_probs_pred[indices, class_pred][:, None]
        result.pred_masks = mask_probs_pred
    elif force_mask_on:
        # NOTE: there's no way to know the height/width of mask here, it won't be
        # used anyway when batch size is 0, so just set them to 0.
        result.pred_masks = torch.zeros([0, 1, 0, 0], dtype=torch.uint8)

    keypoints_out = tensor_outputs.get("keypoints_out", None)
    kps_score = tensor_outputs.get("kps_score", None)
    if keypoints_out is not None:
        # keypoints_out: [N, 4, #keypoints], where 4 is in order of (x, y, score, prob)
        keypoints_tensor = keypoints_out
        # NOTE: it's possible that prob is not calculated if "should_output_softmax"
        # is set to False in HeatmapMaxKeypoint, so just using raw score, seems
        # it doesn't affect mAP. TODO: check more carefully.
        keypoint_xyp = keypoints_tensor.transpose(1, 2)[:, :, [0, 1, 2]]
        result.pred_keypoints = keypoint_xyp
    elif kps_score is not None:
        # keypoint heatmap to sparse data structure
        pred_keypoint_logits = kps_score
        keypoint_head.keypoint_rcnn_inference(pred_keypoint_logits, [result])

    return results


def _cast_to_f32(f64):
    return struct.unpack("f", struct.pack("f", f64))[0]


def set_caffe2_compatible_tensor_mode(model, enable=True):
    def _fn(m):
        if isinstance(m, Caffe2Compatible):
            m.tensor_mode = enable

    model.apply(_fn)


def convert_batched_inputs_to_c2_format(batched_inputs, size_divisibility, device):
    """
    See get_caffe2_inputs() below.
    """
    assert all(isinstance(x, dict) for x in batched_inputs)
    assert all(x["image"].dim() == 3 for x in batched_inputs)

    images = [x["image"] for x in batched_inputs]
    images = ImageList.from_tensors(images, size_divisibility)

    im_info = []
    for input_per_image, image_size in zip(batched_inputs, images.image_sizes):
        target_height = input_per_image.get("height", image_size[0])
        target_width = input_per_image.get("width", image_size[1])  # noqa
        # NOTE: The scale inside im_info is kept as convention and for providing
        # post-processing information if further processing is needed. For
        # current Caffe2 model definitions that don't include post-processing inside
        # the model, this number is not used.
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import functools
|
4 |
+
import io
|
5 |
+
import struct
|
6 |
+
import types
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from annotator.oneformer.detectron2.modeling import meta_arch
|
10 |
+
from annotator.oneformer.detectron2.modeling.box_regression import Box2BoxTransform
|
11 |
+
from annotator.oneformer.detectron2.modeling.roi_heads import keypoint_head
|
12 |
+
from annotator.oneformer.detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes
|
13 |
+
|
14 |
+
from .c10 import Caffe2Compatible
|
15 |
+
from .caffe2_patch import ROIHeadsPatcher, patch_generalized_rcnn
|
16 |
+
from .shared import (
|
17 |
+
alias,
|
18 |
+
check_set_pb_arg,
|
19 |
+
get_pb_arg_floats,
|
20 |
+
get_pb_arg_valf,
|
21 |
+
get_pb_arg_vali,
|
22 |
+
get_pb_arg_vals,
|
23 |
+
mock_torch_nn_functional_interpolate,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
def assemble_rcnn_outputs_by_name(image_sizes, tensor_outputs, force_mask_on=False):
|
28 |
+
"""
|
29 |
+
A function to assemble caffe2 model's outputs (i.e. Dict[str, Tensor])
|
30 |
+
to detectron2's format (i.e. list of Instances instance).
|
31 |
+
This only works when the model follows the Caffe2 detectron's naming convention.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
image_sizes (List[List[int, int]]): [H, W] of every image.
|
35 |
+
tensor_outputs (Dict[str, Tensor]): external_output to its tensor.
|
36 |
+
|
37 |
+
force_mask_on (Bool): if true, the it make sure there'll be pred_masks even
|
38 |
+
if the mask is not found from tensor_outputs (usually due to model crash)
|
39 |
+
"""
|
40 |
+
|
41 |
+
results = [Instances(image_size) for image_size in image_sizes]
|
42 |
+
|
43 |
+
batch_splits = tensor_outputs.get("batch_splits", None)
|
44 |
+
if batch_splits:
|
45 |
+
raise NotImplementedError()
|
46 |
+
assert len(image_sizes) == 1
|
47 |
+
result = results[0]
|
48 |
+
|
49 |
+
bbox_nms = tensor_outputs["bbox_nms"]
|
50 |
+
score_nms = tensor_outputs["score_nms"]
|
51 |
+
class_nms = tensor_outputs["class_nms"]
|
52 |
+
# Detection will always succeed because Conv supports 0-batch
|
53 |
+
assert bbox_nms is not None
|
54 |
+
assert score_nms is not None
|
55 |
+
assert class_nms is not None
|
56 |
+
if bbox_nms.shape[1] == 5:
|
57 |
+
result.pred_boxes = RotatedBoxes(bbox_nms)
|
58 |
+
else:
|
59 |
+
result.pred_boxes = Boxes(bbox_nms)
|
60 |
+
result.scores = score_nms
|
61 |
+
result.pred_classes = class_nms.to(torch.int64)
|
62 |
+
|
63 |
+
mask_fcn_probs = tensor_outputs.get("mask_fcn_probs", None)
|
64 |
+
if mask_fcn_probs is not None:
|
65 |
+
# finish the mask pred
|
66 |
+
mask_probs_pred = mask_fcn_probs
|
67 |
+
num_masks = mask_probs_pred.shape[0]
|
68 |
+
class_pred = result.pred_classes
|
69 |
+
indices = torch.arange(num_masks, device=class_pred.device)
|
70 |
+
mask_probs_pred = mask_probs_pred[indices, class_pred][:, None]
|
71 |
+
result.pred_masks = mask_probs_pred
|
72 |
+
elif force_mask_on:
|
73 |
+
# NOTE: there's no way to know the height/width of mask here, it won't be
|
74 |
+
# used anyway when batch size is 0, so just set them to 0.
|
75 |
+
result.pred_masks = torch.zeros([0, 1, 0, 0], dtype=torch.uint8)
|
76 |
+
|
77 |
+
keypoints_out = tensor_outputs.get("keypoints_out", None)
|
78 |
+
kps_score = tensor_outputs.get("kps_score", None)
|
79 |
+
if keypoints_out is not None:
|
80 |
+
# keypoints_out: [N, 4, #keypoints], where 4 is in order of (x, y, score, prob)
|
81 |
+
keypoints_tensor = keypoints_out
|
82 |
+
# NOTE: it's possible that prob is not calculated if "should_output_softmax"
|
83 |
+
# is set to False in HeatmapMaxKeypoint, so just using raw score, seems
|
84 |
+
# it doesn't affect mAP. TODO: check more carefully.
|
85 |
+
keypoint_xyp = keypoints_tensor.transpose(1, 2)[:, :, [0, 1, 2]]
|
86 |
+
result.pred_keypoints = keypoint_xyp
|
87 |
+
elif kps_score is not None:
|
88 |
+
# keypoint heatmap to sparse data structure
|
89 |
+
pred_keypoint_logits = kps_score
|
90 |
+
keypoint_head.keypoint_rcnn_inference(pred_keypoint_logits, [result])
|
91 |
+
|
92 |
+
return results
|
93 |
+
|
94 |
+
|
95 |
+
def _cast_to_f32(f64):
|
96 |
+
return struct.unpack("f", struct.pack("f", f64))[0]
|
97 |
+
|
98 |
+
|
99 |
+
def set_caffe2_compatible_tensor_mode(model, enable=True):
|
100 |
+
def _fn(m):
|
101 |
+
if isinstance(m, Caffe2Compatible):
|
102 |
+
m.tensor_mode = enable
|
103 |
+
|
104 |
+
model.apply(_fn)
|
105 |
+
|
106 |
+
|
107 |
+
def convert_batched_inputs_to_c2_format(batched_inputs, size_divisibility, device):
|
108 |
+
"""
|
109 |
+
See get_caffe2_inputs() below.
|
110 |
+
"""
|
111 |
+
assert all(isinstance(x, dict) for x in batched_inputs)
|
112 |
+
assert all(x["image"].dim() == 3 for x in batched_inputs)
|
113 |
+
|
114 |
+
images = [x["image"] for x in batched_inputs]
|
115 |
+
images = ImageList.from_tensors(images, size_divisibility)
|
116 |
+
|
117 |
+
im_info = []
|
118 |
+
for input_per_image, image_size in zip(batched_inputs, images.image_sizes):
|
119 |
+
target_height = input_per_image.get("height", image_size[0])
|
120 |
+
target_width = input_per_image.get("width", image_size[1]) # noqa
|
121 |
+
# NOTE: The scale inside im_info is kept as convention and for providing
|
122 |
+
# post-processing information if further processing is needed. For
|
123 |
+
# current Caffe2 model definitions that don't include post-processing inside
|
124 |
+
# the model, this number is not used.
|
125 |
+
# NOTE: There can be a slight difference between width and height
|
126 |
+
# scales, using a single number can result in numerical differences
|
127 |
+
# compared with D2's post-processing.
|
128 |
+
scale = target_height / image_size[0]
|
129 |
+
im_info.append([image_size[0], image_size[1], scale])
|
130 |
+
im_info = torch.Tensor(im_info)
|
131 |
+
|
132 |
+
return images.tensor.to(device), im_info.to(device)
|
133 |
+
|
134 |
+
|
135 |
+
class Caffe2MetaArch(Caffe2Compatible, torch.nn.Module):
|
136 |
+
"""
|
137 |
+
Base class for caffe2-compatible implementation of a meta architecture.
|
138 |
+
The forward is traceable and its traced graph can be converted to caffe2
|
139 |
+
graph through ONNX.
|
140 |
+
"""
|
141 |
+
|
142 |
+
def __init__(self, cfg, torch_model):
|
143 |
+
"""
|
144 |
+
Args:
|
145 |
+
cfg (CfgNode):
|
146 |
+
torch_model (nn.Module): the detectron2 model (meta_arch) to be
|
147 |
+
converted.
|
148 |
+
"""
|
149 |
+
super().__init__()
|
150 |
+
self._wrapped_model = torch_model
|
151 |
+
self.eval()
|
152 |
+
set_caffe2_compatible_tensor_mode(self, True)
|
153 |
+
|
154 |
+
def get_caffe2_inputs(self, batched_inputs):
|
155 |
+
"""
|
156 |
+
Convert pytorch-style structured inputs to caffe2-style inputs that
|
157 |
+
are tuples of tensors.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
batched_inputs (list[dict]): inputs to a detectron2 model
|
161 |
+
in its standard format. Each dict has "image" (CHW tensor), and optionally
|
162 |
+
"height" and "width".
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
tuple[Tensor]:
|
166 |
+
tuple of tensors that will be the inputs to the
|
167 |
+
:meth:`forward` method. For existing models, the first
|
168 |
+
is an NCHW tensor (padded and batched); the second is
|
169 |
+
a im_info Nx3 tensor, where the rows are
|
170 |
+
(height, width, unused legacy parameter)
|
171 |
+
"""
|
172 |
+
return convert_batched_inputs_to_c2_format(
|
173 |
+
batched_inputs,
|
174 |
+
self._wrapped_model.backbone.size_divisibility,
|
175 |
+
self._wrapped_model.device,
|
176 |
+
)
|
177 |
+
|
178 |
+
def encode_additional_info(self, predict_net, init_net):
|
179 |
+
"""
|
180 |
+
Save extra metadata that will be used by inference in the output protobuf.
|
181 |
+
"""
|
182 |
+
pass
|
183 |
+
|
184 |
+
def forward(self, inputs):
|
185 |
+
"""
|
186 |
+
Run the forward in caffe2-style. It has to use caffe2-compatible ops
|
187 |
+
and the method will be used for tracing.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
inputs (tuple[Tensor]): inputs defined by :meth:`get_caffe2_input`.
|
191 |
+
They will be the inputs of the converted caffe2 graph.
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
tuple[Tensor]: output tensors. They will be the outputs of the
|
195 |
+
converted caffe2 graph.
|
196 |
+
"""
|
197 |
+
raise NotImplementedError
|
198 |
+
|
199 |
+
def _caffe2_preprocess_image(self, inputs):
|
200 |
+
"""
|
201 |
+
Caffe2 implementation of preprocess_image, which is called inside each MetaArch's forward.
|
202 |
+
It normalizes the input images, and the final caffe2 graph assumes the
|
203 |
+
inputs have been batched already.
|
204 |
+
"""
|
205 |
+
data, im_info = inputs
|
206 |
+
data = alias(data, "data")
|
207 |
+
im_info = alias(im_info, "im_info")
|
208 |
+
mean, std = self._wrapped_model.pixel_mean, self._wrapped_model.pixel_std
|
209 |
+
normalized_data = (data - mean) / std
|
210 |
+
normalized_data = alias(normalized_data, "normalized_data")
|
211 |
+
|
212 |
+
# Pack (data, im_info) into ImageList which is recognized by self.inference.
|
213 |
+
images = ImageList(tensor=normalized_data, image_sizes=im_info)
|
214 |
+
return images
|
215 |
+
|
216 |
+
@staticmethod
|
217 |
+
def get_outputs_converter(predict_net, init_net):
|
218 |
+
"""
|
219 |
+
Creates a function that converts outputs of the caffe2 model to
|
220 |
+
detectron2's standard format.
|
221 |
+
The function uses information in `predict_net` and `init_net` that are
|
222 |
+
available at inference time. Therefore the function logic can be used in inference.
|
223 |
+
|
224 |
+
The returned function has the following signature:
|
225 |
+
|
226 |
+
def convert(batched_inputs, c2_inputs, c2_results) -> detectron2_outputs
|
227 |
+
|
228 |
+
Where
|
229 |
+
|
230 |
+
* batched_inputs (list[dict]): the original input format of the meta arch
|
231 |
+
* c2_inputs (tuple[Tensor]): the caffe2 inputs.
|
232 |
+
* c2_results (dict[str, Tensor]): the caffe2 output format,
|
233 |
+
corresponding to the outputs of the :meth:`forward` function.
|
234 |
+
* detectron2_outputs: the original output format of the meta arch.
|
235 |
+
|
236 |
+
This function can be used to compare the outputs of the original meta arch and
|
237 |
+
the converted caffe2 graph.
|
238 |
+
|
239 |
+
Returns:
|
240 |
+
callable: a callable of the above signature.
|
241 |
+
"""
|
242 |
+
raise NotImplementedError
|
243 |
+
|
244 |
+
|
245 |
+
class Caffe2GeneralizedRCNN(Caffe2MetaArch):
|
246 |
+
def __init__(self, cfg, torch_model):
|
247 |
+
assert isinstance(torch_model, meta_arch.GeneralizedRCNN)
|
248 |
+
torch_model = patch_generalized_rcnn(torch_model)
|
249 |
+
super().__init__(cfg, torch_model)
|
250 |
+
|
251 |
+
try:
|
252 |
+
use_heatmap_max_keypoint = cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT
|
253 |
+
except AttributeError:
|
254 |
+
use_heatmap_max_keypoint = False
|
255 |
+
self.roi_heads_patcher = ROIHeadsPatcher(
|
256 |
+
self._wrapped_model.roi_heads, use_heatmap_max_keypoint
|
257 |
+
)
|
258 |
+
|
259 |
+
def encode_additional_info(self, predict_net, init_net):
|
260 |
+
size_divisibility = self._wrapped_model.backbone.size_divisibility
|
261 |
+
check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
|
262 |
+
check_set_pb_arg(
|
263 |
+
predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
|
264 |
+
)
|
265 |
+
check_set_pb_arg(predict_net, "meta_architecture", "s", b"GeneralizedRCNN")
|
266 |
+
|
267 |
+
@mock_torch_nn_functional_interpolate()
|
268 |
+
def forward(self, inputs):
|
269 |
+
if not self.tensor_mode:
|
270 |
+
return self._wrapped_model.inference(inputs)
|
271 |
+
images = self._caffe2_preprocess_image(inputs)
|
272 |
+
features = self._wrapped_model.backbone(images.tensor)
|
273 |
+
proposals, _ = self._wrapped_model.proposal_generator(images, features)
|
274 |
+
with self.roi_heads_patcher.mock_roi_heads():
|
275 |
+
detector_results, _ = self._wrapped_model.roi_heads(images, features, proposals)
|
276 |
+
return tuple(detector_results[0].flatten())
|
277 |
+
|
278 |
+
@staticmethod
|
279 |
+
def get_outputs_converter(predict_net, init_net):
|
280 |
+
def f(batched_inputs, c2_inputs, c2_results):
|
281 |
+
_, im_info = c2_inputs
|
282 |
+
image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
|
283 |
+
results = assemble_rcnn_outputs_by_name(image_sizes, c2_results)
|
284 |
+
return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)
|
285 |
+
|
286 |
+
return f
|
287 |
+
|
288 |
+
|
289 |
+
class Caffe2RetinaNet(Caffe2MetaArch):
|
290 |
+
def __init__(self, cfg, torch_model):
|
291 |
+
assert isinstance(torch_model, meta_arch.RetinaNet)
|
292 |
+
super().__init__(cfg, torch_model)
|
293 |
+
|
294 |
+
@mock_torch_nn_functional_interpolate()
|
295 |
+
def forward(self, inputs):
|
296 |
+
assert self.tensor_mode
|
297 |
+
images = self._caffe2_preprocess_image(inputs)
|
298 |
+
|
299 |
+
# explicitly return the image sizes to avoid "im_info" being removed by ONNX
|
300 |
+
# since it's not used in the forward path
|
301 |
+
return_tensors = [images.image_sizes]
|
302 |
+
|
303 |
+
features = self._wrapped_model.backbone(images.tensor)
|
304 |
+
features = [features[f] for f in self._wrapped_model.head_in_features]
|
305 |
+
for i, feature_i in enumerate(features):
|
306 |
+
features[i] = alias(feature_i, "feature_{}".format(i), is_backward=True)
|
307 |
+
return_tensors.append(features[i])
|
308 |
+
|
309 |
+
pred_logits, pred_anchor_deltas = self._wrapped_model.head(features)
|
310 |
+
for i, (box_cls_i, box_delta_i) in enumerate(zip(pred_logits, pred_anchor_deltas)):
|
311 |
+
return_tensors.append(alias(box_cls_i, "box_cls_{}".format(i)))
|
312 |
+
return_tensors.append(alias(box_delta_i, "box_delta_{}".format(i)))
|
313 |
+
|
314 |
+
return tuple(return_tensors)
|
315 |
+
|
316 |
+
def encode_additional_info(self, predict_net, init_net):
|
317 |
+
size_divisibility = self._wrapped_model.backbone.size_divisibility
|
318 |
+
check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
|
319 |
+
check_set_pb_arg(
|
320 |
+
predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
|
321 |
+
)
|
322 |
+
check_set_pb_arg(predict_net, "meta_architecture", "s", b"RetinaNet")
|
323 |
+
|
324 |
+
# Inference parameters:
|
325 |
+
check_set_pb_arg(
|
326 |
+
predict_net, "score_threshold", "f", _cast_to_f32(self._wrapped_model.test_score_thresh)
|
327 |
+
)
|
328 |
+
check_set_pb_arg(
|
329 |
+
predict_net, "topk_candidates", "i", self._wrapped_model.test_topk_candidates
|
330 |
+
)
|
331 |
+
check_set_pb_arg(
|
332 |
+
predict_net, "nms_threshold", "f", _cast_to_f32(self._wrapped_model.test_nms_thresh)
|
333 |
+
)
|
334 |
+
check_set_pb_arg(
|
335 |
+
predict_net,
|
336 |
+
"max_detections_per_image",
|
337 |
+
"i",
|
338 |
+
self._wrapped_model.max_detections_per_image,
|
339 |
+
)
|
340 |
+
|
341 |
+
check_set_pb_arg(
|
342 |
+
predict_net,
|
343 |
+
"bbox_reg_weights",
|
344 |
+
"floats",
|
345 |
+
[_cast_to_f32(w) for w in self._wrapped_model.box2box_transform.weights],
|
346 |
+
)
|
347 |
+
self._encode_anchor_generator_cfg(predict_net)
|
348 |
+
|
349 |
+
def _encode_anchor_generator_cfg(self, predict_net):
|
350 |
+
# serialize anchor_generator for future use
|
351 |
+
serialized_anchor_generator = io.BytesIO()
|
352 |
+
torch.save(self._wrapped_model.anchor_generator, serialized_anchor_generator)
|
353 |
+
# Ideally we can put anchor generating inside the model, then we don't
|
354 |
+
# need to store this information.
|
355 |
+
bytes = serialized_anchor_generator.getvalue()
|
356 |
+
check_set_pb_arg(predict_net, "serialized_anchor_generator", "s", bytes)
|
357 |
+
|
358 |
+
@staticmethod
|
359 |
+
def get_outputs_converter(predict_net, init_net):
|
360 |
+
self = types.SimpleNamespace()
|
361 |
+
serialized_anchor_generator = io.BytesIO(
|
362 |
+
get_pb_arg_vals(predict_net, "serialized_anchor_generator", None)
|
363 |
+
)
|
364 |
+
self.anchor_generator = torch.load(serialized_anchor_generator)
|
365 |
+
bbox_reg_weights = get_pb_arg_floats(predict_net, "bbox_reg_weights", None)
|
366 |
+
self.box2box_transform = Box2BoxTransform(weights=tuple(bbox_reg_weights))
|
367 |
+
self.test_score_thresh = get_pb_arg_valf(predict_net, "score_threshold", None)
|
368 |
+
self.test_topk_candidates = get_pb_arg_vali(predict_net, "topk_candidates", None)
|
369 |
+
self.test_nms_thresh = get_pb_arg_valf(predict_net, "nms_threshold", None)
|
370 |
+
self.max_detections_per_image = get_pb_arg_vali(
|
371 |
+
predict_net, "max_detections_per_image", None
|
372 |
+
)
|
373 |
+
|
374 |
+
# hack to reuse inference code from RetinaNet
|
375 |
+
for meth in [
|
376 |
+
"forward_inference",
|
377 |
+
"inference_single_image",
|
378 |
+
"_transpose_dense_predictions",
|
379 |
+
"_decode_multi_level_predictions",
|
380 |
+
"_decode_per_level_predictions",
|
381 |
+
]:
|
382 |
+
setattr(self, meth, functools.partial(getattr(meta_arch.RetinaNet, meth), self))
|
383 |
+
|
384 |
+
def f(batched_inputs, c2_inputs, c2_results):
|
385 |
+
_, im_info = c2_inputs
|
386 |
+
image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
|
387 |
+
dummy_images = ImageList(
|
388 |
+
torch.randn(
|
389 |
+
(
|
390 |
+
len(im_info),
|
391 |
+
3,
|
392 |
+
)
|
393 |
+
+ tuple(image_sizes[0])
|
394 |
+
),
|
395 |
+
image_sizes,
|
396 |
+
)
|
397 |
+
|
398 |
+
num_features = len([x for x in c2_results.keys() if x.startswith("box_cls_")])
|
399 |
+
pred_logits = [c2_results["box_cls_{}".format(i)] for i in range(num_features)]
|
400 |
+
pred_anchor_deltas = [c2_results["box_delta_{}".format(i)] for i in range(num_features)]
|
401 |
+
|
402 |
+
# For each feature level, feature should have the same batch size and
|
403 |
+
# spatial dimension as the box_cls and box_delta.
|
404 |
+
dummy_features = [x.clone()[:, 0:0, :, :] for x in pred_logits]
|
405 |
+
# self.num_classes can be inferred
|
406 |
+
self.num_classes = pred_logits[0].shape[1] // (pred_anchor_deltas[0].shape[1] // 4)
|
407 |
+
|
408 |
+
results = self.forward_inference(
|
409 |
+
dummy_images, dummy_features, [pred_logits, pred_anchor_deltas]
|
410 |
+
)
|
411 |
+
return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)
|
412 |
+
|
413 |
+
return f
|
414 |
+
|
415 |
+
|
416 |
+
META_ARCH_CAFFE2_EXPORT_TYPE_MAP = {
|
417 |
+
"GeneralizedRCNN": Caffe2GeneralizedRCNN,
|
418 |
+
"RetinaNet": Caffe2RetinaNet,
|
419 |
+
}
|
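
For orientation (not part of the commit itself), here is a minimal, hedged sketch of how the wrappers defined above are typically driven. It assumes a detectron2 `cfg`, an already-built `torch_model`, and standard detectron2-format `batched_inputs`; `wrap_for_caffe2_export` is a hypothetical helper name, not something defined in this file.

# Hedged usage sketch: pick the caffe2-compatible wrapper for a meta
# architecture and run its traceable forward once.
# Assumptions: `cfg`, `torch_model`, `batched_inputs` already exist.
import torch

from annotator.oneformer.detectron2.export.caffe2_modeling import (
    META_ARCH_CAFFE2_EXPORT_TYPE_MAP,
)


def wrap_for_caffe2_export(cfg, torch_model, batched_inputs):
    # e.g. "GeneralizedRCNN" or "RetinaNet"
    c2_compatible_cls = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
    c2_model = c2_compatible_cls(cfg, torch_model)

    # Convert list-of-dict inputs into the (data, im_info) tensor pair that
    # the traceable forward() expects.
    c2_inputs = c2_model.get_caffe2_inputs(batched_inputs)
    with torch.no_grad():
        c2_outputs = c2_model(c2_inputs)
    return c2_model, c2_inputs, c2_outputs
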
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/caffe2_patch.py
ADDED
@@ -0,0 +1,152 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import contextlib
|
4 |
+
from unittest import mock
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from annotator.oneformer.detectron2.modeling import poolers
|
8 |
+
from annotator.oneformer.detectron2.modeling.proposal_generator import rpn
|
9 |
+
from annotator.oneformer.detectron2.modeling.roi_heads import keypoint_head, mask_head
|
10 |
+
from annotator.oneformer.detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
|
11 |
+
|
12 |
+
from .c10 import (
|
13 |
+
Caffe2Compatible,
|
14 |
+
Caffe2FastRCNNOutputsInference,
|
15 |
+
Caffe2KeypointRCNNInference,
|
16 |
+
Caffe2MaskRCNNInference,
|
17 |
+
Caffe2ROIPooler,
|
18 |
+
Caffe2RPN,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
class GenericMixin(object):
|
23 |
+
pass
|
24 |
+
|
25 |
+
|
26 |
+
class Caffe2CompatibleConverter(object):
|
27 |
+
"""
|
28 |
+
A GenericUpdater which implements the `create_from` interface, by modifying the
|
29 |
+
module object and assigning it another class, replaceCls.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, replaceCls):
|
33 |
+
self.replaceCls = replaceCls
|
34 |
+
|
35 |
+
def create_from(self, module):
|
36 |
+
# update module's class to the new class
|
37 |
+
assert isinstance(module, torch.nn.Module)
|
38 |
+
if issubclass(self.replaceCls, GenericMixin):
|
39 |
+
# replaceCls should act as mixin, create a new class on-the-fly
|
40 |
+
new_class = type(
|
41 |
+
"{}MixedWith{}".format(self.replaceCls.__name__, module.__class__.__name__),
|
42 |
+
(self.replaceCls, module.__class__),
|
43 |
+
{}, # {"new_method": lambda self: ...},
|
44 |
+
)
|
45 |
+
module.__class__ = new_class
|
46 |
+
else:
|
47 |
+
# replaceCls is a complete class, this allows arbitrary class swaps
|
48 |
+
module.__class__ = self.replaceCls
|
49 |
+
|
50 |
+
# initialize Caffe2Compatible
|
51 |
+
if isinstance(module, Caffe2Compatible):
|
52 |
+
module.tensor_mode = False
|
53 |
+
|
54 |
+
return module
|
55 |
+
|
56 |
+
|
57 |
+
def patch(model, target, updater, *args, **kwargs):
|
58 |
+
"""
|
59 |
+
recursively (post-order) update all modules with the target type and its
|
60 |
+
subclasses, making an initialization/composition/inheritance/... via the
|
61 |
+
updater.create_from.
|
62 |
+
"""
|
63 |
+
for name, module in model.named_children():
|
64 |
+
model._modules[name] = patch(module, target, updater, *args, **kwargs)
|
65 |
+
if isinstance(model, target):
|
66 |
+
return updater.create_from(model, *args, **kwargs)
|
67 |
+
return model
|
68 |
+
|
69 |
+
|
70 |
+
def patch_generalized_rcnn(model):
|
71 |
+
ccc = Caffe2CompatibleConverter
|
72 |
+
model = patch(model, rpn.RPN, ccc(Caffe2RPN))
|
73 |
+
model = patch(model, poolers.ROIPooler, ccc(Caffe2ROIPooler))
|
74 |
+
|
75 |
+
return model
|
76 |
+
|
77 |
+
|
78 |
+
@contextlib.contextmanager
|
79 |
+
def mock_fastrcnn_outputs_inference(
|
80 |
+
tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
|
81 |
+
):
|
82 |
+
with mock.patch.object(
|
83 |
+
box_predictor_type,
|
84 |
+
"inference",
|
85 |
+
autospec=True,
|
86 |
+
side_effect=Caffe2FastRCNNOutputsInference(tensor_mode),
|
87 |
+
) as mocked_func:
|
88 |
+
yield
|
89 |
+
if check:
|
90 |
+
assert mocked_func.call_count > 0
|
91 |
+
|
92 |
+
|
93 |
+
@contextlib.contextmanager
|
94 |
+
def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):
|
95 |
+
with mock.patch(
|
96 |
+
"{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference()
|
97 |
+
) as mocked_func:
|
98 |
+
yield
|
99 |
+
if check:
|
100 |
+
assert mocked_func.call_count > 0
|
101 |
+
|
102 |
+
|
103 |
+
@contextlib.contextmanager
|
104 |
+
def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
|
105 |
+
with mock.patch(
|
106 |
+
"{}.keypoint_rcnn_inference".format(patched_module),
|
107 |
+
side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint),
|
108 |
+
) as mocked_func:
|
109 |
+
yield
|
110 |
+
if check:
|
111 |
+
assert mocked_func.call_count > 0
|
112 |
+
|
113 |
+
|
114 |
+
class ROIHeadsPatcher:
|
115 |
+
def __init__(self, heads, use_heatmap_max_keypoint):
|
116 |
+
self.heads = heads
|
117 |
+
self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
|
118 |
+
|
119 |
+
@contextlib.contextmanager
|
120 |
+
def mock_roi_heads(self, tensor_mode=True):
|
121 |
+
"""
|
122 |
+
Patching several inference functions inside ROIHeads and its subclasses
|
123 |
+
|
124 |
+
Args:
|
125 |
+
tensor_mode (bool): whether the inputs/outputs are caffe2's tensor
|
126 |
+
format or not. Defaults to True.
|
127 |
+
"""
|
128 |
+
# NOTE: this requires the `keypoint_rcnn_inference` and `mask_rcnn_inference`
|
129 |
+
# are called inside the same file as BaseXxxHead due to using mock.patch.
|
130 |
+
kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__
|
131 |
+
mask_head_mod = mask_head.BaseMaskRCNNHead.__module__
|
132 |
+
|
133 |
+
mock_ctx_managers = [
|
134 |
+
mock_fastrcnn_outputs_inference(
|
135 |
+
tensor_mode=tensor_mode,
|
136 |
+
check=True,
|
137 |
+
box_predictor_type=type(self.heads.box_predictor),
|
138 |
+
)
|
139 |
+
]
|
140 |
+
if getattr(self.heads, "keypoint_on", False):
|
141 |
+
mock_ctx_managers += [
|
142 |
+
mock_keypoint_rcnn_inference(
|
143 |
+
tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint
|
144 |
+
)
|
145 |
+
]
|
146 |
+
if getattr(self.heads, "mask_on", False):
|
147 |
+
mock_ctx_managers += [mock_mask_rcnn_inference(tensor_mode, mask_head_mod)]
|
148 |
+
|
149 |
+
with contextlib.ExitStack() as stack: # python 3.3+
|
150 |
+
for mgr in mock_ctx_managers:
|
151 |
+
stack.enter_context(mgr)
|
152 |
+
yield
|
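
As an aside, the on-the-fly class swap performed by Caffe2CompatibleConverter.create_from above can be illustrated without any detectron2 classes. The sketch below uses made-up TracingFriendlyMixin / TinyHead classes purely to show the `__class__` reassignment pattern; it is not part of the diff.

import torch
from torch import nn


class TracingFriendlyMixin:
    # Hypothetical mixin that overrides forward for export purposes.
    def forward(self, x):
        return super().forward(x) * 1.0  # placeholder tweak


class TinyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


module = TinyHead()
# Same trick as create_from(): build "<Mixin>MixedWith<Original>" on the fly
# and reassign __class__, so the instance keeps its parameters but picks up
# the mixin's forward().
module.__class__ = type(
    "TracingFriendlyMixinMixedWithTinyHead",
    (TracingFriendlyMixin, TinyHead),
    {},
)
print(module(torch.randn(1, 4)).shape)  # torch.Size([1, 2])
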
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/flatten.py
ADDED
@@ -0,0 +1,330 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import collections
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Callable, List, Optional, Tuple
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from annotator.oneformer.detectron2.structures import Boxes, Instances, ROIMasks
|
9 |
+
from annotator.oneformer.detectron2.utils.registry import _convert_target_to_string, locate
|
10 |
+
|
11 |
+
from .torchscript_patch import patch_builtin_len
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class Schema:
|
16 |
+
"""
|
17 |
+
A Schema defines how to flatten a possibly hierarchical object into tuple of
|
18 |
+
primitive objects, so it can be used as inputs/outputs of PyTorch's tracing.
|
19 |
+
|
20 |
+
PyTorch does not support tracing a function that produces rich output
|
21 |
+
structures (e.g. dict, Instances, Boxes). To trace such a function, we
|
22 |
+
flatten the rich object into tuple of tensors, and return this tuple of tensors
|
23 |
+
instead. Meanwhile, we also need to know how to "rebuild" the original object
|
24 |
+
from the flattened results, so we can evaluate the flattened results.
|
25 |
+
A Schema defines how to flatten an object, and while flattening it, it records
|
26 |
+
necessary schemas so that the object can be rebuilt using the flattened outputs.
|
27 |
+
|
28 |
+
The flattened object and the schema object is returned by ``.flatten`` classmethod.
|
29 |
+
Then the original object can be rebuilt with the ``__call__`` method of schema.
|
30 |
+
|
31 |
+
A Schema is a dataclass that can be serialized easily.
|
32 |
+
"""
|
33 |
+
|
34 |
+
# inspired by FetchMapper in tensorflow/python/client/session.py
|
35 |
+
|
36 |
+
@classmethod
|
37 |
+
def flatten(cls, obj):
|
38 |
+
raise NotImplementedError
|
39 |
+
|
40 |
+
def __call__(self, values):
|
41 |
+
raise NotImplementedError
|
42 |
+
|
43 |
+
@staticmethod
|
44 |
+
def _concat(values):
|
45 |
+
ret = ()
|
46 |
+
sizes = []
|
47 |
+
for v in values:
|
48 |
+
assert isinstance(v, tuple), "Flattened results must be a tuple"
|
49 |
+
ret = ret + v
|
50 |
+
sizes.append(len(v))
|
51 |
+
return ret, sizes
|
52 |
+
|
53 |
+
@staticmethod
|
54 |
+
def _split(values, sizes):
|
55 |
+
if len(sizes):
|
56 |
+
expected_len = sum(sizes)
|
57 |
+
assert (
|
58 |
+
len(values) == expected_len
|
59 |
+
), f"Values has length {len(values)} but expect length {expected_len}."
|
60 |
+
ret = []
|
61 |
+
for k in range(len(sizes)):
|
62 |
+
begin, end = sum(sizes[:k]), sum(sizes[: k + 1])
|
63 |
+
ret.append(values[begin:end])
|
64 |
+
return ret
|
65 |
+
|
66 |
+
|
67 |
+
@dataclass
|
68 |
+
class ListSchema(Schema):
|
69 |
+
schemas: List[Schema] # the schemas that define how to flatten each element in the list
|
70 |
+
sizes: List[int] # the flattened length of each element
|
71 |
+
|
72 |
+
def __call__(self, values):
|
73 |
+
values = self._split(values, self.sizes)
|
74 |
+
if len(values) != len(self.schemas):
|
75 |
+
raise ValueError(
|
76 |
+
f"Values has length {len(values)} but schemas " f"has length {len(self.schemas)}!"
|
77 |
+
)
|
78 |
+
values = [m(v) for m, v in zip(self.schemas, values)]
|
79 |
+
return list(values)
|
80 |
+
|
81 |
+
@classmethod
|
82 |
+
def flatten(cls, obj):
|
83 |
+
res = [flatten_to_tuple(k) for k in obj]
|
84 |
+
values, sizes = cls._concat([k[0] for k in res])
|
85 |
+
return values, cls([k[1] for k in res], sizes)
|
86 |
+
|
87 |
+
|
88 |
+
@dataclass
|
89 |
+
class TupleSchema(ListSchema):
|
90 |
+
def __call__(self, values):
|
91 |
+
return tuple(super().__call__(values))
|
92 |
+
|
93 |
+
|
94 |
+
@dataclass
|
95 |
+
class IdentitySchema(Schema):
|
96 |
+
def __call__(self, values):
|
97 |
+
return values[0]
|
98 |
+
|
99 |
+
@classmethod
|
100 |
+
def flatten(cls, obj):
|
101 |
+
return (obj,), cls()
|
102 |
+
|
103 |
+
|
104 |
+
@dataclass
|
105 |
+
class DictSchema(ListSchema):
|
106 |
+
keys: List[str]
|
107 |
+
|
108 |
+
def __call__(self, values):
|
109 |
+
values = super().__call__(values)
|
110 |
+
return dict(zip(self.keys, values))
|
111 |
+
|
112 |
+
@classmethod
|
113 |
+
def flatten(cls, obj):
|
114 |
+
for k in obj.keys():
|
115 |
+
if not isinstance(k, str):
|
116 |
+
raise KeyError("Only support flattening dictionaries if keys are str.")
|
117 |
+
keys = sorted(obj.keys())
|
118 |
+
values = [obj[k] for k in keys]
|
119 |
+
ret, schema = ListSchema.flatten(values)
|
120 |
+
return ret, cls(schema.schemas, schema.sizes, keys)
|
121 |
+
|
122 |
+
|
123 |
+
@dataclass
|
124 |
+
class InstancesSchema(DictSchema):
|
125 |
+
def __call__(self, values):
|
126 |
+
image_size, fields = values[-1], values[:-1]
|
127 |
+
fields = super().__call__(fields)
|
128 |
+
return Instances(image_size, **fields)
|
129 |
+
|
130 |
+
@classmethod
|
131 |
+
def flatten(cls, obj):
|
132 |
+
ret, schema = super().flatten(obj.get_fields())
|
133 |
+
size = obj.image_size
|
134 |
+
if not isinstance(size, torch.Tensor):
|
135 |
+
size = torch.tensor(size)
|
136 |
+
return ret + (size,), schema
|
137 |
+
|
138 |
+
|
139 |
+
@dataclass
|
140 |
+
class TensorWrapSchema(Schema):
|
141 |
+
"""
|
142 |
+
For classes that are simple wrapper of tensors, e.g.
|
143 |
+
Boxes, RotatedBoxes, BitMasks
|
144 |
+
"""
|
145 |
+
|
146 |
+
class_name: str
|
147 |
+
|
148 |
+
def __call__(self, values):
|
149 |
+
return locate(self.class_name)(values[0])
|
150 |
+
|
151 |
+
@classmethod
|
152 |
+
def flatten(cls, obj):
|
153 |
+
return (obj.tensor,), cls(_convert_target_to_string(type(obj)))
|
154 |
+
|
155 |
+
|
156 |
+
# if more custom structures needed in the future, can allow
|
157 |
+
# passing in extra schemas for custom types
|
158 |
+
def flatten_to_tuple(obj):
|
159 |
+
"""
|
160 |
+
Flatten an object so it can be used for PyTorch tracing.
|
161 |
+
Also returns how to rebuild the original object from the flattened outputs.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
res (tuple): the flattened results that can be used as tracing outputs
|
165 |
+
schema: an object with a ``__call__`` method such that ``schema(res) == obj``.
|
166 |
+
It is a pure dataclass that can be serialized.
|
167 |
+
"""
|
168 |
+
schemas = [
|
169 |
+
((str, bytes), IdentitySchema),
|
170 |
+
(list, ListSchema),
|
171 |
+
(tuple, TupleSchema),
|
172 |
+
(collections.abc.Mapping, DictSchema),
|
173 |
+
(Instances, InstancesSchema),
|
174 |
+
((Boxes, ROIMasks), TensorWrapSchema),
|
175 |
+
]
|
176 |
+
for klass, schema in schemas:
|
177 |
+
if isinstance(obj, klass):
|
178 |
+
F = schema
|
179 |
+
break
|
180 |
+
else:
|
181 |
+
F = IdentitySchema
|
182 |
+
|
183 |
+
return F.flatten(obj)
|
184 |
+
|
185 |
+
|
186 |
+
class TracingAdapter(nn.Module):
|
187 |
+
"""
|
188 |
+
A model may take rich input/output format (e.g. dict or custom classes),
|
189 |
+
but `torch.jit.trace` requires tuple of tensors as input/output.
|
190 |
+
This adapter flattens input/output format of a model so it becomes traceable.
|
191 |
+
|
192 |
+
It also records the necessary schema to rebuild model's inputs/outputs from flattened
|
193 |
+
inputs/outputs.
|
194 |
+
|
195 |
+
Example:
|
196 |
+
::
|
197 |
+
outputs = model(inputs) # inputs/outputs may be rich structure
|
198 |
+
adapter = TracingAdapter(model, inputs)
|
199 |
+
|
200 |
+
# can now trace the model, with adapter.flattened_inputs, or another
|
201 |
+
# tuple of tensors with the same length and meaning
|
202 |
+
traced = torch.jit.trace(adapter, adapter.flattened_inputs)
|
203 |
+
|
204 |
+
# traced model can only produce flattened outputs (tuple of tensors)
|
205 |
+
flattened_outputs = traced(*adapter.flattened_inputs)
|
206 |
+
# adapter knows the schema to convert it back (new_outputs == outputs)
|
207 |
+
new_outputs = adapter.outputs_schema(flattened_outputs)
|
208 |
+
"""
|
209 |
+
|
210 |
+
flattened_inputs: Tuple[torch.Tensor] = None
|
211 |
+
"""
|
212 |
+
Flattened version of inputs given to this class's constructor.
|
213 |
+
"""
|
214 |
+
|
215 |
+
inputs_schema: Schema = None
|
216 |
+
"""
|
217 |
+
Schema of the inputs given to this class's constructor.
|
218 |
+
"""
|
219 |
+
|
220 |
+
outputs_schema: Schema = None
|
221 |
+
"""
|
222 |
+
Schema of the output produced by calling the given model with inputs.
|
223 |
+
"""
|
224 |
+
|
225 |
+
def __init__(
|
226 |
+
self,
|
227 |
+
model: nn.Module,
|
228 |
+
inputs,
|
229 |
+
inference_func: Optional[Callable] = None,
|
230 |
+
allow_non_tensor: bool = False,
|
231 |
+
):
|
232 |
+
"""
|
233 |
+
Args:
|
234 |
+
model: an nn.Module
|
235 |
+
inputs: An input argument or a tuple of input arguments used to call model.
|
236 |
+
After flattening, it has to only consist of tensors.
|
237 |
+
inference_func: a callable that takes (model, *inputs), calls the
|
238 |
+
model with inputs, and return outputs. By default it
|
239 |
+
is ``lambda model, *inputs: model(*inputs)``. Can be overridden
|
240 |
+
if you need to call the model differently.
|
241 |
+
allow_non_tensor: allow inputs/outputs to contain non-tensor objects.
|
242 |
+
This option will filter out non-tensor objects to make the
|
243 |
+
model traceable, but ``inputs_schema``/``outputs_schema`` cannot be
|
244 |
+
used anymore because inputs/outputs cannot be rebuilt from pure tensors.
|
245 |
+
This is useful when you're only interested in the single trace of
|
246 |
+
execution (e.g. for flop count), but not interested in
|
247 |
+
generalizing the traced graph to new inputs.
|
248 |
+
"""
|
249 |
+
super().__init__()
|
250 |
+
if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
|
251 |
+
model = model.module
|
252 |
+
self.model = model
|
253 |
+
if not isinstance(inputs, tuple):
|
254 |
+
inputs = (inputs,)
|
255 |
+
self.inputs = inputs
|
256 |
+
self.allow_non_tensor = allow_non_tensor
|
257 |
+
|
258 |
+
if inference_func is None:
|
259 |
+
inference_func = lambda model, *inputs: model(*inputs) # noqa
|
260 |
+
self.inference_func = inference_func
|
261 |
+
|
262 |
+
self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs)
|
263 |
+
|
264 |
+
if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs):
|
265 |
+
return
|
266 |
+
if self.allow_non_tensor:
|
267 |
+
self.flattened_inputs = tuple(
|
268 |
+
[x for x in self.flattened_inputs if isinstance(x, torch.Tensor)]
|
269 |
+
)
|
270 |
+
self.inputs_schema = None
|
271 |
+
else:
|
272 |
+
for input in self.flattened_inputs:
|
273 |
+
if not isinstance(input, torch.Tensor):
|
274 |
+
raise ValueError(
|
275 |
+
"Inputs for tracing must only contain tensors. "
|
276 |
+
f"Got a {type(input)} instead."
|
277 |
+
)
|
278 |
+
|
279 |
+
def forward(self, *args: torch.Tensor):
|
280 |
+
with torch.no_grad(), patch_builtin_len():
|
281 |
+
if self.inputs_schema is not None:
|
282 |
+
inputs_orig_format = self.inputs_schema(args)
|
283 |
+
else:
|
284 |
+
if len(args) != len(self.flattened_inputs) or any(
|
285 |
+
x is not y for x, y in zip(args, self.flattened_inputs)
|
286 |
+
):
|
287 |
+
raise ValueError(
|
288 |
+
"TracingAdapter does not contain valid inputs_schema."
|
289 |
+
" So it cannot generalize to other inputs and must be"
|
290 |
+
" traced with `.flattened_inputs`."
|
291 |
+
)
|
292 |
+
inputs_orig_format = self.inputs
|
293 |
+
|
294 |
+
outputs = self.inference_func(self.model, *inputs_orig_format)
|
295 |
+
flattened_outputs, schema = flatten_to_tuple(outputs)
|
296 |
+
|
297 |
+
flattened_output_tensors = tuple(
|
298 |
+
[x for x in flattened_outputs if isinstance(x, torch.Tensor)]
|
299 |
+
)
|
300 |
+
if len(flattened_output_tensors) < len(flattened_outputs):
|
301 |
+
if self.allow_non_tensor:
|
302 |
+
flattened_outputs = flattened_output_tensors
|
303 |
+
self.outputs_schema = None
|
304 |
+
else:
|
305 |
+
raise ValueError(
|
306 |
+
"Model cannot be traced because some model outputs "
|
307 |
+
"cannot flatten to tensors."
|
308 |
+
)
|
309 |
+
else: # schema is valid
|
310 |
+
if self.outputs_schema is None:
|
311 |
+
self.outputs_schema = schema
|
312 |
+
else:
|
313 |
+
assert self.outputs_schema == schema, (
|
314 |
+
"Model should always return outputs with the same "
|
315 |
+
"structure so it can be traced!"
|
316 |
+
)
|
317 |
+
return flattened_outputs
|
318 |
+
|
319 |
+
def _create_wrapper(self, traced_model):
|
320 |
+
"""
|
321 |
+
Return a function that has an input/output interface the same as the
|
322 |
+
original model, but it calls the given traced model under the hood.
|
323 |
+
"""
|
324 |
+
|
325 |
+
def forward(*args):
|
326 |
+
flattened_inputs, _ = flatten_to_tuple(args)
|
327 |
+
flattened_outputs = traced_model(*flattened_inputs)
|
328 |
+
return self.outputs_schema(flattened_outputs)
|
329 |
+
|
330 |
+
return forward
|
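
A small, hedged usage sketch of flatten_to_tuple (assuming the package path used throughout this extension is importable): flatten a nested structure into a tuple of tensors, then rebuild it from the recorded schema.

import torch

from annotator.oneformer.detectron2.export.flatten import flatten_to_tuple

nested = {
    "scores": torch.tensor([0.9, 0.3]),
    "boxes": [torch.zeros(2, 4), torch.ones(1, 4)],
}
flat, schema = flatten_to_tuple(nested)
assert all(isinstance(t, torch.Tensor) for t in flat)

rebuilt = schema(flat)  # same structure and values as `nested`
assert torch.equal(rebuilt["scores"], nested["scores"])
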
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/shared.py
ADDED
@@ -0,0 +1,1039 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import collections
|
4 |
+
import copy
|
5 |
+
import functools
|
6 |
+
import logging
|
7 |
+
import numpy as np
|
8 |
+
import os
|
9 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
10 |
+
from unittest import mock
|
11 |
+
import caffe2.python.utils as putils
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from caffe2.proto import caffe2_pb2
|
15 |
+
from caffe2.python import core, net_drawer, workspace
|
16 |
+
from torch.nn.functional import interpolate as interp
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
# ==== torch/utils_toffee/cast.py =======================================
|
22 |
+
|
23 |
+
|
24 |
+
def to_device(t, device_str):
|
25 |
+
"""
|
26 |
+
This function is a replacement of .to(another_device) such that it allows the
|
27 |
+
casting to be traced properly by explicitly calling the underlying copy ops.
|
28 |
+
It also avoids introducing an unnecessary op when casting to the same device.
|
29 |
+
"""
|
30 |
+
src = t.device
|
31 |
+
dst = torch.device(device_str)
|
32 |
+
|
33 |
+
if src == dst:
|
34 |
+
return t
|
35 |
+
elif src.type == "cuda" and dst.type == "cpu":
|
36 |
+
return torch.ops._caffe2.CopyGPUToCPU(t)
|
37 |
+
elif src.type == "cpu" and dst.type == "cuda":
|
38 |
+
return torch.ops._caffe2.CopyCPUToGPU(t)
|
39 |
+
else:
|
40 |
+
raise RuntimeError("Can't cast tensor from device {} to device {}".format(src, dst))
|
41 |
+
|
42 |
+
|
43 |
+
# ==== torch/utils_toffee/interpolate.py =======================================
|
44 |
+
|
45 |
+
|
46 |
+
# Note: borrowed from vision/detection/fair/detectron/detectron/modeling/detector.py
|
47 |
+
def BilinearInterpolation(tensor_in, up_scale):
|
48 |
+
assert up_scale % 2 == 0, "Scale should be even"
|
49 |
+
|
50 |
+
def upsample_filt(size):
|
51 |
+
factor = (size + 1) // 2
|
52 |
+
if size % 2 == 1:
|
53 |
+
center = factor - 1
|
54 |
+
else:
|
55 |
+
center = factor - 0.5
|
56 |
+
|
57 |
+
og = np.ogrid[:size, :size]
|
58 |
+
return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
|
59 |
+
|
60 |
+
kernel_size = int(up_scale) * 2
|
61 |
+
bil_filt = upsample_filt(kernel_size)
|
62 |
+
|
63 |
+
dim = int(tensor_in.shape[1])
|
64 |
+
kernel = np.zeros((dim, dim, kernel_size, kernel_size), dtype=np.float32)
|
65 |
+
kernel[range(dim), range(dim), :, :] = bil_filt
|
66 |
+
|
67 |
+
tensor_out = F.conv_transpose2d(
|
68 |
+
tensor_in,
|
69 |
+
weight=to_device(torch.Tensor(kernel), tensor_in.device),
|
70 |
+
bias=None,
|
71 |
+
stride=int(up_scale),
|
72 |
+
padding=int(up_scale / 2),
|
73 |
+
)
|
74 |
+
|
75 |
+
return tensor_out
|
76 |
+
|
77 |
+
|
78 |
+
# NOTE: ONNX is incompatible with traced torch.nn.functional.interpolate if
|
79 |
+
# using dynamic `scale_factor` rather than static `size`. (T43166860)
|
80 |
+
# NOTE: Caffe2 Int8 conversion might not be able to quantize `size` properly.
|
81 |
+
def onnx_compatibale_interpolate(
|
82 |
+
input, size=None, scale_factor=None, mode="nearest", align_corners=None
|
83 |
+
):
|
84 |
+
# NOTE: The input dimensions are interpreted in the form:
|
85 |
+
# `mini-batch x channels x [optional depth] x [optional height] x width`.
|
86 |
+
if size is None and scale_factor is not None:
|
87 |
+
if input.dim() == 4:
|
88 |
+
if isinstance(scale_factor, (int, float)):
|
89 |
+
height_scale, width_scale = (scale_factor, scale_factor)
|
90 |
+
else:
|
91 |
+
assert isinstance(scale_factor, (tuple, list))
|
92 |
+
assert len(scale_factor) == 2
|
93 |
+
height_scale, width_scale = scale_factor
|
94 |
+
|
95 |
+
assert not align_corners, "No matching C2 op for align_corners == True"
|
96 |
+
if mode == "nearest":
|
97 |
+
return torch.ops._caffe2.ResizeNearest(
|
98 |
+
input, order="NCHW", width_scale=width_scale, height_scale=height_scale
|
99 |
+
)
|
100 |
+
elif mode == "bilinear":
|
101 |
+
logger.warning(
|
102 |
+
"Use F.conv_transpose2d for bilinear interpolate"
|
103 |
+
" because there's no such C2 op, this may cause significant"
|
104 |
+
" slowdown and the boundary pixels won't be as same as"
|
105 |
+
" using F.interpolate due to padding."
|
106 |
+
)
|
107 |
+
assert height_scale == width_scale
|
108 |
+
return BilinearInterpolation(input, up_scale=height_scale)
|
109 |
+
logger.warning("Output size is not static, it might cause ONNX conversion issue")
|
110 |
+
|
111 |
+
return interp(input, size, scale_factor, mode, align_corners)
|
112 |
+
|
113 |
+
|
114 |
+
def mock_torch_nn_functional_interpolate():
|
115 |
+
def decorator(func):
|
116 |
+
@functools.wraps(func)
|
117 |
+
def _mock_torch_nn_functional_interpolate(*args, **kwargs):
|
118 |
+
if torch.onnx.is_in_onnx_export():
|
119 |
+
with mock.patch(
|
120 |
+
"torch.nn.functional.interpolate", side_effect=onnx_compatibale_interpolate
|
121 |
+
):
|
122 |
+
return func(*args, **kwargs)
|
123 |
+
else:
|
124 |
+
return func(*args, **kwargs)
|
125 |
+
|
126 |
+
return _mock_torch_nn_functional_interpolate
|
127 |
+
|
128 |
+
return decorator
|
129 |
+
|
130 |
+
|
131 |
+
# ==== torch/utils_caffe2/ws_utils.py ==========================================
|
132 |
+
|
133 |
+
|
134 |
+
class ScopedWS(object):
|
135 |
+
def __init__(self, ws_name, is_reset, is_cleanup=False):
|
136 |
+
self.ws_name = ws_name
|
137 |
+
self.is_reset = is_reset
|
138 |
+
self.is_cleanup = is_cleanup
|
139 |
+
self.org_ws = ""
|
140 |
+
|
141 |
+
def __enter__(self):
|
142 |
+
self.org_ws = workspace.CurrentWorkspace()
|
143 |
+
if self.ws_name is not None:
|
144 |
+
workspace.SwitchWorkspace(self.ws_name, True)
|
145 |
+
if self.is_reset:
|
146 |
+
workspace.ResetWorkspace()
|
147 |
+
|
148 |
+
return workspace
|
149 |
+
|
150 |
+
def __exit__(self, *args):
|
151 |
+
if self.is_cleanup:
|
152 |
+
workspace.ResetWorkspace()
|
153 |
+
if self.ws_name is not None:
|
154 |
+
workspace.SwitchWorkspace(self.org_ws)
|
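
A hedged sketch (requires a Caffe2-enabled PyTorch build; the workspace name is illustrative) of the intended ScopedWS usage: feed and fetch blobs in a temporary, isolated workspace and restore the previous one on exit.

import numpy as np

from annotator.oneformer.detectron2.export.shared import ScopedWS

with ScopedWS("__tmp_ws__", is_reset=True, is_cleanup=True) as ws:
    ws.FeedBlob("x", np.ones((2, 2), dtype=np.float32))
    print(ws.FetchBlob("x").sum())  # 4.0
# back in the original workspace here; "__tmp_ws__" was reset on exit
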
155 |
+
|
156 |
+
|
157 |
+
def fetch_any_blob(name):
|
158 |
+
bb = None
|
159 |
+
try:
|
160 |
+
bb = workspace.FetchBlob(name)
|
161 |
+
except TypeError:
|
162 |
+
bb = workspace.FetchInt8Blob(name)
|
163 |
+
except Exception as e:
|
164 |
+
logger.error("Get blob {} error: {}".format(name, e))
|
165 |
+
|
166 |
+
return bb
|
167 |
+
|
168 |
+
|
169 |
+
# ==== torch/utils_caffe2/protobuf.py ==========================================
|
170 |
+
|
171 |
+
|
172 |
+
def get_pb_arg(pb, arg_name):
|
173 |
+
for x in pb.arg:
|
174 |
+
if x.name == arg_name:
|
175 |
+
return x
|
176 |
+
return None
|
177 |
+
|
178 |
+
|
179 |
+
def get_pb_arg_valf(pb, arg_name, default_val):
|
180 |
+
arg = get_pb_arg(pb, arg_name)
|
181 |
+
return arg.f if arg is not None else default_val
|
182 |
+
|
183 |
+
|
184 |
+
def get_pb_arg_floats(pb, arg_name, default_val):
|
185 |
+
arg = get_pb_arg(pb, arg_name)
|
186 |
+
return list(map(float, arg.floats)) if arg is not None else default_val
|
187 |
+
|
188 |
+
|
189 |
+
def get_pb_arg_ints(pb, arg_name, default_val):
|
190 |
+
arg = get_pb_arg(pb, arg_name)
|
191 |
+
return list(map(int, arg.ints)) if arg is not None else default_val
|
192 |
+
|
193 |
+
|
194 |
+
def get_pb_arg_vali(pb, arg_name, default_val):
|
195 |
+
arg = get_pb_arg(pb, arg_name)
|
196 |
+
return arg.i if arg is not None else default_val
|
197 |
+
|
198 |
+
|
199 |
+
def get_pb_arg_vals(pb, arg_name, default_val):
|
200 |
+
arg = get_pb_arg(pb, arg_name)
|
201 |
+
return arg.s if arg is not None else default_val
|
202 |
+
|
203 |
+
|
204 |
+
def get_pb_arg_valstrings(pb, arg_name, default_val):
|
205 |
+
arg = get_pb_arg(pb, arg_name)
|
206 |
+
return list(arg.strings) if arg is not None else default_val
|
207 |
+
|
208 |
+
|
209 |
+
def check_set_pb_arg(pb, arg_name, arg_attr, arg_value, allow_override=False):
|
210 |
+
arg = get_pb_arg(pb, arg_name)
|
211 |
+
if arg is None:
|
212 |
+
arg = putils.MakeArgument(arg_name, arg_value)
|
213 |
+
assert hasattr(arg, arg_attr)
|
214 |
+
pb.arg.extend([arg])
|
215 |
+
if allow_override and getattr(arg, arg_attr) != arg_value:
|
216 |
+
logger.warning(
|
217 |
+
"Override argument {}: {} -> {}".format(arg_name, getattr(arg, arg_attr), arg_value)
|
218 |
+
)
|
219 |
+
setattr(arg, arg_attr, arg_value)
|
220 |
+
else:
|
221 |
+
assert arg is not None
|
222 |
+
assert getattr(arg, arg_attr) == arg_value, "Existing value {}, new value {}".format(
|
223 |
+
getattr(arg, arg_attr), arg_value
|
224 |
+
)
|
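
A hedged sketch (requires a Caffe2-enabled PyTorch build; values are illustrative) of how the protobuf-arg helpers above round-trip metadata on a predict_net, mirroring what encode_additional_info() does with "size_divisibility" and "meta_architecture".

from caffe2.proto import caffe2_pb2

from annotator.oneformer.detectron2.export.shared import (
    check_set_pb_arg,
    get_pb_arg_vali,
    get_pb_arg_vals,
)

predict_net = caffe2_pb2.NetDef()
check_set_pb_arg(predict_net, "size_divisibility", "i", 32)
check_set_pb_arg(predict_net, "meta_architecture", "s", b"GeneralizedRCNN")

assert get_pb_arg_vali(predict_net, "size_divisibility", None) == 32
assert get_pb_arg_vals(predict_net, "meta_architecture", None) == b"GeneralizedRCNN"
# Re-setting the same value is a no-op; a different value asserts unless
# allow_override=True is passed.
check_set_pb_arg(predict_net, "size_divisibility", "i", 32)
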
225 |
+
|
226 |
+
|
227 |
+
def _create_const_fill_op_from_numpy(name, tensor, device_option=None):
|
228 |
+
assert type(tensor) == np.ndarray
|
229 |
+
kTypeNameMapper = {
|
230 |
+
np.dtype("float32"): "GivenTensorFill",
|
231 |
+
np.dtype("int32"): "GivenTensorIntFill",
|
232 |
+
np.dtype("int64"): "GivenTensorInt64Fill",
|
233 |
+
np.dtype("uint8"): "GivenTensorStringFill",
|
234 |
+
}
|
235 |
+
|
236 |
+
args_dict = {}
|
237 |
+
if tensor.dtype == np.dtype("uint8"):
|
238 |
+
args_dict.update({"values": [str(tensor.data)], "shape": [1]})
|
239 |
+
else:
|
240 |
+
args_dict.update({"values": tensor, "shape": tensor.shape})
|
241 |
+
|
242 |
+
if device_option is not None:
|
243 |
+
args_dict["device_option"] = device_option
|
244 |
+
|
245 |
+
return core.CreateOperator(kTypeNameMapper[tensor.dtype], [], [name], **args_dict)
|
246 |
+
|
247 |
+
|
248 |
+
def _create_const_fill_op_from_c2_int8_tensor(name, int8_tensor):
|
249 |
+
assert type(int8_tensor) == workspace.Int8Tensor
|
250 |
+
kTypeNameMapper = {
|
251 |
+
np.dtype("int32"): "Int8GivenIntTensorFill",
|
252 |
+
np.dtype("uint8"): "Int8GivenTensorFill",
|
253 |
+
}
|
254 |
+
|
255 |
+
tensor = int8_tensor.data
|
256 |
+
assert tensor.dtype in [np.dtype("uint8"), np.dtype("int32")]
|
257 |
+
values = tensor.tobytes() if tensor.dtype == np.dtype("uint8") else tensor
|
258 |
+
|
259 |
+
return core.CreateOperator(
|
260 |
+
kTypeNameMapper[tensor.dtype],
|
261 |
+
[],
|
262 |
+
[name],
|
263 |
+
values=values,
|
264 |
+
shape=tensor.shape,
|
265 |
+
Y_scale=int8_tensor.scale,
|
266 |
+
Y_zero_point=int8_tensor.zero_point,
|
267 |
+
)
|
268 |
+
|
269 |
+
|
270 |
+
def create_const_fill_op(
|
271 |
+
name: str,
|
272 |
+
blob: Union[np.ndarray, workspace.Int8Tensor],
|
273 |
+
device_option: Optional[caffe2_pb2.DeviceOption] = None,
|
274 |
+
) -> caffe2_pb2.OperatorDef:
|
275 |
+
"""
|
276 |
+
Given a blob object, return the Caffe2 operator that creates this blob
|
277 |
+
as a constant. Currently supports NumPy tensors and Caffe2 Int8Tensor.
|
278 |
+
"""
|
279 |
+
|
280 |
+
tensor_type = type(blob)
|
281 |
+
assert tensor_type in [
|
282 |
+
np.ndarray,
|
283 |
+
workspace.Int8Tensor,
|
284 |
+
], 'Error when creating const fill op for "{}", unsupported blob type: {}'.format(
|
285 |
+
name, type(blob)
|
286 |
+
)
|
287 |
+
|
288 |
+
if tensor_type == np.ndarray:
|
289 |
+
return _create_const_fill_op_from_numpy(name, blob, device_option)
|
290 |
+
elif tensor_type == workspace.Int8Tensor:
|
291 |
+
assert device_option is None
|
292 |
+
return _create_const_fill_op_from_c2_int8_tensor(name, blob)
|
293 |
+
|
294 |
+
|
295 |
+
def construct_init_net_from_params(
|
296 |
+
params: Dict[str, Any], device_options: Optional[Dict[str, caffe2_pb2.DeviceOption]] = None
|
297 |
+
) -> caffe2_pb2.NetDef:
|
298 |
+
"""
|
299 |
+
Construct the init_net from params dictionary
|
300 |
+
"""
|
301 |
+
init_net = caffe2_pb2.NetDef()
|
302 |
+
device_options = device_options or {}
|
303 |
+
for name, blob in params.items():
|
304 |
+
if isinstance(blob, str):
|
305 |
+
logger.warning(
|
306 |
+
(
|
307 |
+
"Blob {} with type {} is not supported in generating init net,"
|
308 |
+
" skipped.".format(name, type(blob))
|
309 |
+
)
|
310 |
+
)
|
311 |
+
continue
|
312 |
+
init_net.op.extend(
|
313 |
+
[create_const_fill_op(name, blob, device_option=device_options.get(name, None))]
|
314 |
+
)
|
315 |
+
init_net.external_output.append(name)
|
316 |
+
return init_net
|
317 |
+
|
318 |
+
|
319 |
+
def get_producer_map(ssa):
|
320 |
+
"""
|
321 |
+
Return dict from versioned blob to (i, j),
|
322 |
+
where i is index of producer op, j is the index of output of that op.
|
323 |
+
"""
|
324 |
+
producer_map = {}
|
325 |
+
for i in range(len(ssa)):
|
326 |
+
outputs = ssa[i][1]
|
327 |
+
for j, outp in enumerate(outputs):
|
328 |
+
producer_map[outp] = (i, j)
|
329 |
+
return producer_map
|
330 |
+
|
331 |
+
|
332 |
+
def get_consumer_map(ssa):
|
333 |
+
"""
|
334 |
+
Return dict from versioned blob to list of (i, j),
|
335 |
+
where i is index of consumer op, j is the index of input of that op.
|
336 |
+
"""
|
337 |
+
consumer_map = collections.defaultdict(list)
|
338 |
+
for i in range(len(ssa)):
|
339 |
+
inputs = ssa[i][0]
|
340 |
+
for j, inp in enumerate(inputs):
|
341 |
+
consumer_map[inp].append((i, j))
|
342 |
+
return consumer_map
|
343 |
+
|
344 |
+
|
345 |
+
def get_params_from_init_net(
|
346 |
+
init_net: caffe2_pb2.NetDef,
|
347 |
+
) -> [Dict[str, Any], Dict[str, caffe2_pb2.DeviceOption]]:
|
348 |
+
"""
|
349 |
+
Take the output blobs from init_net by running it.
|
350 |
+
Outputs:
|
351 |
+
params: dict from blob name to numpy array
|
352 |
+
device_options: dict from blob name to the device option of its creating op
|
353 |
+
"""
|
354 |
+
# NOTE: this assumes that the params is determined by producer op with the
|
355 |
+
# only exception be CopyGPUToCPU which is CUDA op but returns CPU tensor.
|
356 |
+
def _get_device_option(producer_op):
|
357 |
+
if producer_op.type == "CopyGPUToCPU":
|
358 |
+
return caffe2_pb2.DeviceOption()
|
359 |
+
else:
|
360 |
+
return producer_op.device_option
|
361 |
+
|
362 |
+
with ScopedWS("__get_params_from_init_net__", is_reset=True, is_cleanup=True) as ws:
|
363 |
+
ws.RunNetOnce(init_net)
|
364 |
+
params = {b: fetch_any_blob(b) for b in init_net.external_output}
|
365 |
+
ssa, versions = core.get_ssa(init_net)
|
366 |
+
producer_map = get_producer_map(ssa)
|
367 |
+
device_options = {
|
368 |
+
b: _get_device_option(init_net.op[producer_map[(b, versions[b])][0]])
|
369 |
+
for b in init_net.external_output
|
370 |
+
}
|
371 |
+
return params, device_options
|
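
A hedged round-trip sketch (again assuming a Caffe2-enabled PyTorch build): construct_init_net_from_params serializes a params dict into an init_net, and get_params_from_init_net reads it back by running that net in a scratch workspace.

import numpy as np

from annotator.oneformer.detectron2.export.shared import (
    construct_init_net_from_params,
    get_params_from_init_net,
)

params = {"conv1_w": np.random.rand(8, 3, 3, 3).astype(np.float32)}
init_net = construct_init_net_from_params(params)
restored, device_options = get_params_from_init_net(init_net)
assert np.allclose(restored["conv1_w"], params["conv1_w"])
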
372 |
+
|
373 |
+
|
374 |
+
def _updater_raise(op, input_types, output_types):
|
375 |
+
raise RuntimeError(
|
376 |
+
"Failed to apply updater for op {} given input_types {} and"
|
377 |
+
" output_types {}".format(op, input_types, output_types)
|
378 |
+
)
|
379 |
+
|
380 |
+
|
381 |
+
def _generic_status_identifier(
|
382 |
+
predict_net: caffe2_pb2.NetDef,
|
383 |
+
status_updater: Callable,
|
384 |
+
known_status: Dict[Tuple[str, int], Any],
|
385 |
+
) -> Dict[Tuple[str, int], Any]:
|
386 |
+
"""
|
387 |
+
Statically infer the status of each blob; the status can be, e.g., device type
|
388 |
+
(CPU/GPU), layout (NCHW/NHWC), data type (float32/int8), etc. "Blob" here
|
389 |
+
is versioned blob (Tuple[str, int]) in the format compatible with ssa.
|
390 |
+
Inputs:
|
391 |
+
predict_net: the caffe2 network
|
392 |
+
status_updater: a callable, given an op and the status of its input/output,
|
393 |
+
it returns the updated status of input/output. `None` is used for
|
394 |
+
representing unknown status.
|
395 |
+
known_status: a dict containing known status, used as initialization.
|
396 |
+
Outputs:
|
397 |
+
A dict mapping from versioned blob to its status
|
398 |
+
"""
|
399 |
+
ssa, versions = core.get_ssa(predict_net)
|
400 |
+
versioned_ext_input = [(b, 0) for b in predict_net.external_input]
|
401 |
+
versioned_ext_output = [(b, versions[b]) for b in predict_net.external_output]
|
402 |
+
all_versioned_blobs = set().union(*[set(x[0] + x[1]) for x in ssa])
|
403 |
+
|
404 |
+
allowed_vbs = all_versioned_blobs.union(versioned_ext_input).union(versioned_ext_output)
|
405 |
+
assert all(k in allowed_vbs for k in known_status)
|
406 |
+
assert all(v is not None for v in known_status.values())
|
407 |
+
_known_status = copy.deepcopy(known_status)
|
408 |
+
|
409 |
+
def _check_and_update(key, value):
|
410 |
+
assert value is not None
|
411 |
+
if key in _known_status:
|
412 |
+
if not _known_status[key] == value:
|
413 |
+
raise RuntimeError(
|
414 |
+
"Confilict status for {}, existing status {}, new status {}".format(
|
415 |
+
key, _known_status[key], value
|
416 |
+
)
|
417 |
+
)
|
418 |
+
_known_status[key] = value
|
419 |
+
|
420 |
+
def _update_i(op, ssa_i):
|
421 |
+
versioned_inputs = ssa_i[0]
|
422 |
+
versioned_outputs = ssa_i[1]
|
423 |
+
|
424 |
+
inputs_status = [_known_status.get(b, None) for b in versioned_inputs]
|
425 |
+
outputs_status = [_known_status.get(b, None) for b in versioned_outputs]
|
426 |
+
|
427 |
+
new_inputs_status, new_outputs_status = status_updater(op, inputs_status, outputs_status)
|
428 |
+
|
429 |
+
for versioned_blob, status in zip(
|
430 |
+
versioned_inputs + versioned_outputs, new_inputs_status + new_outputs_status
|
431 |
+
):
|
432 |
+
if status is not None:
|
433 |
+
_check_and_update(versioned_blob, status)
|
434 |
+
|
435 |
+
for op, ssa_i in zip(predict_net.op, ssa):
|
436 |
+
_update_i(op, ssa_i)
|
437 |
+
for op, ssa_i in zip(reversed(predict_net.op), reversed(ssa)):
|
438 |
+
_update_i(op, ssa_i)
|
439 |
+
|
440 |
+
# NOTE: This strictly checks that every blob from predict_net must be assigned
|
441 |
+
# a known status. However sometimes it's impossible (eg. having deadend op),
|
442 |
+
# we may relax this constraint if needed.
|
443 |
+
for k in all_versioned_blobs:
|
444 |
+
if k not in _known_status:
|
445 |
+
raise NotImplementedError(
|
446 |
+
"Can not infer the status for {}. Currently only support the case where"
|
447 |
+
" a single forward and backward pass can identify status for all blobs.".format(k)
|
448 |
+
)
|
449 |
+
|
450 |
+
return _known_status
|
451 |
+
|
452 |
+
|
453 |
+
def infer_device_type(
|
454 |
+
predict_net: caffe2_pb2.NetDef,
|
455 |
+
known_status: Dict[Tuple[str, int], Any],
|
456 |
+
device_name_style: str = "caffe2",
|
457 |
+
) -> Dict[Tuple[str, int], str]:
|
458 |
+
"""Return the device type ("cpu" or "gpu"/"cuda") of each (versioned) blob"""
|
459 |
+
|
460 |
+
assert device_name_style in ["caffe2", "pytorch"]
|
461 |
+
_CPU_STR = "cpu"
|
462 |
+
_GPU_STR = "gpu" if device_name_style == "caffe2" else "cuda"
|
463 |
+
|
464 |
+
def _copy_cpu_to_gpu_updater(op, input_types, output_types):
|
465 |
+
if input_types[0] == _GPU_STR or output_types[0] == _CPU_STR:
|
466 |
+
_updater_raise(op, input_types, output_types)
|
467 |
+
return ([_CPU_STR], [_GPU_STR])
|
468 |
+
|
469 |
+
def _copy_gpu_to_cpu_updater(op, input_types, output_types):
|
470 |
+
if input_types[0] == _CPU_STR or output_types[0] == _GPU_STR:
|
471 |
+
_updater_raise(op, input_types, output_types)
|
472 |
+
return ([_GPU_STR], [_CPU_STR])
|
473 |
+
|
474 |
+
def _other_ops_updater(op, input_types, output_types):
|
475 |
+
non_none_types = [x for x in input_types + output_types if x is not None]
|
476 |
+
if len(non_none_types) > 0:
|
477 |
+
the_type = non_none_types[0]
|
478 |
+
if not all(x == the_type for x in non_none_types):
|
479 |
+
_updater_raise(op, input_types, output_types)
|
480 |
+
else:
|
481 |
+
the_type = None
|
482 |
+
return ([the_type for _ in op.input], [the_type for _ in op.output])
|
483 |
+
|
484 |
+
def _device_updater(op, *args, **kwargs):
|
485 |
+
return {
|
486 |
+
"CopyCPUToGPU": _copy_cpu_to_gpu_updater,
|
487 |
+
"CopyGPUToCPU": _copy_gpu_to_cpu_updater,
|
488 |
+
}.get(op.type, _other_ops_updater)(op, *args, **kwargs)
|
489 |
+
|
490 |
+
return _generic_status_identifier(predict_net, _device_updater, known_status)
|
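A minimal sketch of calling the device inference above; `predict_net` and the blob name "data" are hypothetical:

# Seed the inference with the known placement of the network input, then
# propagate it through every op via the updaters defined above.
known = {("data", 0): "cpu"}
blob_device = infer_device_type(predict_net, known_status=known, device_name_style="pytorch")
# blob_device maps every versioned blob to "cpu" or "cuda".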
491 |
+
|
492 |
+
|
493 |
+
# ==== torch/utils_caffe2/vis.py ===============================================
|
494 |
+
|
495 |
+
|
496 |
+
def _modify_blob_names(ops, blob_rename_f):
|
497 |
+
ret = []
|
498 |
+
|
499 |
+
def _replace_list(blob_list, replaced_list):
|
500 |
+
del blob_list[:]
|
501 |
+
blob_list.extend(replaced_list)
|
502 |
+
|
503 |
+
for x in ops:
|
504 |
+
cur = copy.deepcopy(x)
|
505 |
+
_replace_list(cur.input, list(map(blob_rename_f, cur.input)))
|
506 |
+
_replace_list(cur.output, list(map(blob_rename_f, cur.output)))
|
507 |
+
ret.append(cur)
|
508 |
+
|
509 |
+
return ret
|
510 |
+
|
511 |
+
|
512 |
+
def _rename_blob(name, blob_sizes, blob_ranges):
|
513 |
+
def _list_to_str(bsize):
|
514 |
+
ret = ", ".join([str(x) for x in bsize])
|
515 |
+
ret = "[" + ret + "]"
|
516 |
+
return ret
|
517 |
+
|
518 |
+
ret = name
|
519 |
+
if blob_sizes is not None and name in blob_sizes:
|
520 |
+
ret += "\n" + _list_to_str(blob_sizes[name])
|
521 |
+
if blob_ranges is not None and name in blob_ranges:
|
522 |
+
ret += "\n" + _list_to_str(blob_ranges[name])
|
523 |
+
|
524 |
+
return ret
|
525 |
+
|
526 |
+
|
527 |
+
# graph_name must not contain the word 'graph'
|
528 |
+
def save_graph(net, file_name, graph_name="net", op_only=True, blob_sizes=None, blob_ranges=None):
|
529 |
+
blob_rename_f = functools.partial(_rename_blob, blob_sizes=blob_sizes, blob_ranges=blob_ranges)
|
530 |
+
return save_graph_base(net, file_name, graph_name, op_only, blob_rename_f)
|
531 |
+
|
532 |
+
|
533 |
+
def save_graph_base(net, file_name, graph_name="net", op_only=True, blob_rename_func=None):
|
534 |
+
graph = None
|
535 |
+
ops = net.op
|
536 |
+
if blob_rename_func is not None:
|
537 |
+
ops = _modify_blob_names(ops, blob_rename_func)
|
538 |
+
if not op_only:
|
539 |
+
graph = net_drawer.GetPydotGraph(ops, graph_name, rankdir="TB")
|
540 |
+
else:
|
541 |
+
graph = net_drawer.GetPydotGraphMinimal(
|
542 |
+
ops, graph_name, rankdir="TB", minimal_dependency=True
|
543 |
+
)
|
544 |
+
|
545 |
+
try:
|
546 |
+
par_dir = os.path.dirname(file_name)
|
547 |
+
if not os.path.exists(par_dir):
|
548 |
+
os.makedirs(par_dir)
|
549 |
+
|
550 |
+
format = os.path.splitext(os.path.basename(file_name))[-1]
|
551 |
+
if format == ".png":
|
552 |
+
graph.write_png(file_name)
|
553 |
+
elif format == ".pdf":
|
554 |
+
graph.write_pdf(file_name)
|
555 |
+
elif format == ".svg":
|
556 |
+
graph.write_svg(file_name)
|
557 |
+
else:
|
558 |
+
print("Incorrect format {}".format(format))
|
559 |
+
except Exception as e:
|
560 |
+
print("Error when writing graph to image {}".format(e))
|
561 |
+
|
562 |
+
return graph
|
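A usage sketch for the visualization helpers above; `predict_net` and the output path are hypothetical, and rendering requires pydot/graphviz:

# Note: per the comment above, graph_name must not contain the word 'graph'.
save_graph(predict_net, "/tmp/predict_net.svg", graph_name="net", op_only=True)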
563 |
+
|
564 |
+
|
565 |
+
# ==== torch/utils_toffee/aten_to_caffe2.py ====================================
|
566 |
+
|
567 |
+
|
568 |
+
def group_norm_replace_aten_with_caffe2(predict_net: caffe2_pb2.NetDef):
|
569 |
+
"""
|
570 |
+
For an ONNX-exported model, GroupNorm will be represented as an ATen op;
|
571 |
+
this function is a drop-in replacement that converts the ATen op to GroupNorm.
|
572 |
+
"""
|
573 |
+
count = 0
|
574 |
+
for op in predict_net.op:
|
575 |
+
if op.type == "ATen":
|
576 |
+
op_name = get_pb_arg_vals(op, "operator", None) # return byte in py3
|
577 |
+
if op_name and op_name.decode() == "group_norm":
|
578 |
+
op.arg.remove(get_pb_arg(op, "operator"))
|
579 |
+
|
580 |
+
if get_pb_arg_vali(op, "cudnn_enabled", None):
|
581 |
+
op.arg.remove(get_pb_arg(op, "cudnn_enabled"))
|
582 |
+
|
583 |
+
num_groups = get_pb_arg_vali(op, "num_groups", None)
|
584 |
+
if num_groups is not None:
|
585 |
+
op.arg.remove(get_pb_arg(op, "num_groups"))
|
586 |
+
check_set_pb_arg(op, "group", "i", num_groups)
|
587 |
+
|
588 |
+
op.type = "GroupNorm"
|
589 |
+
count += 1
|
590 |
+
if count > 1:
|
591 |
+
logger.info("Replaced {} ATen operator to GroupNormOp".format(count))
|
592 |
+
|
593 |
+
|
594 |
+
# ==== torch/utils_toffee/alias.py =============================================
|
595 |
+
|
596 |
+
|
597 |
+
def alias(x, name, is_backward=False):
|
598 |
+
if not torch.onnx.is_in_onnx_export():
|
599 |
+
return x
|
600 |
+
assert isinstance(x, torch.Tensor)
|
601 |
+
return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)
|
602 |
+
|
603 |
+
|
604 |
+
def fuse_alias_placeholder(predict_net, init_net):
|
605 |
+
"""Remove AliasWithName placeholder and rename the input/output of it"""
|
606 |
+
# First we finish all the re-naming
|
607 |
+
for i, op in enumerate(predict_net.op):
|
608 |
+
if op.type == "AliasWithName":
|
609 |
+
assert len(op.input) == 1
|
610 |
+
assert len(op.output) == 1
|
611 |
+
name = get_pb_arg_vals(op, "name", None).decode()
|
612 |
+
is_backward = bool(get_pb_arg_vali(op, "is_backward", 0))
|
613 |
+
rename_op_input(predict_net, init_net, i, 0, name, from_producer=is_backward)
|
614 |
+
rename_op_output(predict_net, i, 0, name)
|
615 |
+
|
616 |
+
# Remove AliasWithName, should be very safe since it's a non-op
|
617 |
+
new_ops = []
|
618 |
+
for op in predict_net.op:
|
619 |
+
if op.type != "AliasWithName":
|
620 |
+
new_ops.append(op)
|
621 |
+
else:
|
622 |
+
# safety check
|
623 |
+
assert op.input == op.output
|
624 |
+
assert op.input[0] == op.arg[0].s.decode()
|
625 |
+
del predict_net.op[:]
|
626 |
+
predict_net.op.extend(new_ops)
|
627 |
+
|
628 |
+
|
629 |
+
# ==== torch/utils_caffe2/graph_transform.py ===================================
|
630 |
+
|
631 |
+
|
632 |
+
class IllegalGraphTransformError(ValueError):
|
633 |
+
"""When a graph transform function call can't be executed."""
|
634 |
+
|
635 |
+
|
636 |
+
def _rename_versioned_blob_in_proto(
|
637 |
+
proto: caffe2_pb2.NetDef,
|
638 |
+
old_name: str,
|
639 |
+
new_name: str,
|
640 |
+
version: int,
|
641 |
+
ssa: List[Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]],
|
642 |
+
start_versions: Dict[str, int],
|
643 |
+
end_versions: Dict[str, int],
|
644 |
+
):
|
645 |
+
"""In given proto, rename all blobs with matched version"""
|
646 |
+
# Operater list
|
647 |
+
for op, i_th_ssa in zip(proto.op, ssa):
|
648 |
+
versioned_inputs, versioned_outputs = i_th_ssa
|
649 |
+
for i in range(len(op.input)):
|
650 |
+
if versioned_inputs[i] == (old_name, version):
|
651 |
+
op.input[i] = new_name
|
652 |
+
for i in range(len(op.output)):
|
653 |
+
if versioned_outputs[i] == (old_name, version):
|
654 |
+
op.output[i] = new_name
|
655 |
+
# external_input
|
656 |
+
if start_versions.get(old_name, 0) == version:
|
657 |
+
for i in range(len(proto.external_input)):
|
658 |
+
if proto.external_input[i] == old_name:
|
659 |
+
proto.external_input[i] = new_name
|
660 |
+
# external_output
|
661 |
+
if end_versions.get(old_name, 0) == version:
|
662 |
+
for i in range(len(proto.external_output)):
|
663 |
+
if proto.external_output[i] == old_name:
|
664 |
+
proto.external_output[i] = new_name
|
665 |
+
|
666 |
+
|
667 |
+
def rename_op_input(
|
668 |
+
predict_net: caffe2_pb2.NetDef,
|
669 |
+
init_net: caffe2_pb2.NetDef,
|
670 |
+
op_id: int,
|
671 |
+
input_id: int,
|
672 |
+
new_name: str,
|
673 |
+
from_producer: bool = False,
|
674 |
+
):
|
675 |
+
"""
|
676 |
+
Rename the op_id-th operator in predict_net, changing its input_id-th input's
|
677 |
+
name to the new_name. It also automatically re-routes and changes
|
678 |
+
external_input and init_net if necessary.
|
679 |
+
- It requires the input is only consumed by this op.
|
680 |
+
- This function modifies predict_net and init_net in-place.
|
681 |
+
- When from_producer is enabled, this also updates other operators that consume
|
682 |
+
the same input. Be cautious because this may trigger unintended behavior.
|
683 |
+
"""
|
684 |
+
assert isinstance(predict_net, caffe2_pb2.NetDef)
|
685 |
+
assert isinstance(init_net, caffe2_pb2.NetDef)
|
686 |
+
|
687 |
+
init_net_ssa, init_net_versions = core.get_ssa(init_net)
|
688 |
+
predict_net_ssa, predict_net_versions = core.get_ssa(
|
689 |
+
predict_net, copy.deepcopy(init_net_versions)
|
690 |
+
)
|
691 |
+
|
692 |
+
versioned_inputs, versioned_outputs = predict_net_ssa[op_id]
|
693 |
+
old_name, version = versioned_inputs[input_id]
|
694 |
+
|
695 |
+
if from_producer:
|
696 |
+
producer_map = get_producer_map(predict_net_ssa)
|
697 |
+
if not (old_name, version) in producer_map:
|
698 |
+
raise NotImplementedError(
|
699 |
+
"Can't find producer, the input {} is probably from"
|
700 |
+
" init_net, this is not supported yet.".format(old_name)
|
701 |
+
)
|
702 |
+
producer = producer_map[(old_name, version)]
|
703 |
+
rename_op_output(predict_net, producer[0], producer[1], new_name)
|
704 |
+
return
|
705 |
+
|
706 |
+
def contain_targets(op_ssa):
|
707 |
+
return (old_name, version) in op_ssa[0]
|
708 |
+
|
709 |
+
is_consumer = [contain_targets(op_ssa) for op_ssa in predict_net_ssa]
|
710 |
+
if sum(is_consumer) > 1:
|
711 |
+
raise IllegalGraphTransformError(
|
712 |
+
(
|
713 |
+
"Input '{}' of operator(#{}) are consumed by other ops, please use"
|
714 |
+
+ " rename_op_output on the producer instead. Offending op: \n{}"
|
715 |
+
).format(old_name, op_id, predict_net.op[op_id])
|
716 |
+
)
|
717 |
+
|
718 |
+
# update init_net
|
719 |
+
_rename_versioned_blob_in_proto(
|
720 |
+
init_net, old_name, new_name, version, init_net_ssa, {}, init_net_versions
|
721 |
+
)
|
722 |
+
# update predict_net
|
723 |
+
_rename_versioned_blob_in_proto(
|
724 |
+
predict_net,
|
725 |
+
old_name,
|
726 |
+
new_name,
|
727 |
+
version,
|
728 |
+
predict_net_ssa,
|
729 |
+
init_net_versions,
|
730 |
+
predict_net_versions,
|
731 |
+
)
|
732 |
+
|
733 |
+
|
734 |
+
def rename_op_output(predict_net: caffe2_pb2.NetDef, op_id: int, output_id: int, new_name: str):
|
735 |
+
"""
|
736 |
+
Rename the op_id-th operator in predict_net, changing its output_id-th output's
|
737 |
+
name to the new_name. It also automatically re-routes and changes
|
738 |
+
external_output if necessary.
|
739 |
+
- It allows multiple consumers of its output.
|
740 |
+
- This function modifies predict_net in-place, doesn't need init_net.
|
741 |
+
"""
|
742 |
+
assert isinstance(predict_net, caffe2_pb2.NetDef)
|
743 |
+
|
744 |
+
ssa, blob_versions = core.get_ssa(predict_net)
|
745 |
+
|
746 |
+
versioned_inputs, versioned_outputs = ssa[op_id]
|
747 |
+
old_name, version = versioned_outputs[output_id]
|
748 |
+
|
749 |
+
# update predict_net
|
750 |
+
_rename_versioned_blob_in_proto(
|
751 |
+
predict_net, old_name, new_name, version, ssa, {}, blob_versions
|
752 |
+
)
|
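A minimal sketch of the renaming helpers above; the op index and the new blob name are hypothetical:

# Give the first operator's first output a stable, human-readable name.
rename_op_output(predict_net, op_id=0, output_id=0, new_name="backbone_features")
# rename_op_input additionally needs init_net, and requires either that the input
# has a single consumer or that from_producer=True so the producer is renamed instead.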
753 |
+
|
754 |
+
|
755 |
+
def get_sub_graph_external_input_output(
|
756 |
+
predict_net: caffe2_pb2.NetDef, sub_graph_op_indices: List[int]
|
757 |
+
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
|
758 |
+
"""
|
759 |
+
Return the list of external input/output of sub-graph,
|
760 |
+
each element is tuple of the name and corresponding version in predict_net.
|
761 |
+
|
762 |
+
external input/output is defined the same way as in a caffe2 NetDef.
|
763 |
+
"""
|
764 |
+
ssa, versions = core.get_ssa(predict_net)
|
765 |
+
|
766 |
+
all_inputs = []
|
767 |
+
all_outputs = []
|
768 |
+
for op_id in sub_graph_op_indices:
|
769 |
+
all_inputs += [inp for inp in ssa[op_id][0] if inp not in all_inputs]
|
770 |
+
all_outputs += list(ssa[op_id][1]) # ssa output won't repeat
|
771 |
+
|
772 |
+
# for versioned blobs, external inputs are just those blob in all_inputs
|
773 |
+
# but not in all_outputs
|
774 |
+
ext_inputs = [inp for inp in all_inputs if inp not in all_outputs]
|
775 |
+
|
776 |
+
# external outputs are essentially outputs of this subgraph that are used
|
777 |
+
# outside of this sub-graph (including predict_net.external_output)
|
778 |
+
all_other_inputs = sum(
|
779 |
+
(ssa[i][0] for i in range(len(ssa)) if i not in sub_graph_op_indices),
|
780 |
+
[(outp, versions[outp]) for outp in predict_net.external_output],
|
781 |
+
)
|
782 |
+
ext_outputs = [outp for outp in all_outputs if outp in set(all_other_inputs)]
|
783 |
+
|
784 |
+
return ext_inputs, ext_outputs
|
785 |
+
|
786 |
+
|
787 |
+
class DiGraph:
|
788 |
+
"""A DAG representation of caffe2 graph, each vertice is a versioned blob."""
|
789 |
+
|
790 |
+
def __init__(self):
|
791 |
+
self.vertices = set()
|
792 |
+
self.graph = collections.defaultdict(list)
|
793 |
+
|
794 |
+
def add_edge(self, u, v):
|
795 |
+
self.graph[u].append(v)
|
796 |
+
self.vertices.add(u)
|
797 |
+
self.vertices.add(v)
|
798 |
+
|
799 |
+
# grab from https://www.geeksforgeeks.org/find-paths-given-source-destination/
|
800 |
+
def get_all_paths(self, s, d):
|
801 |
+
visited = {k: False for k in self.vertices}
|
802 |
+
path = []
|
803 |
+
all_paths = []
|
804 |
+
|
805 |
+
def _get_all_paths_util(graph, u, d, visited, path):
|
806 |
+
visited[u] = True
|
807 |
+
path.append(u)
|
808 |
+
if u == d:
|
809 |
+
all_paths.append(copy.deepcopy(path))
|
810 |
+
else:
|
811 |
+
for i in graph[u]:
|
812 |
+
if not visited[i]:
|
813 |
+
_get_all_paths_util(graph, i, d, visited, path)
|
814 |
+
path.pop()
|
815 |
+
visited[u] = False
|
816 |
+
|
817 |
+
_get_all_paths_util(self.graph, s, d, visited, path)
|
818 |
+
return all_paths
|
819 |
+
|
820 |
+
@staticmethod
|
821 |
+
def from_ssa(ssa):
|
822 |
+
graph = DiGraph()
|
823 |
+
for op_id in range(len(ssa)):
|
824 |
+
for inp in ssa[op_id][0]:
|
825 |
+
for outp in ssa[op_id][1]:
|
826 |
+
graph.add_edge(inp, outp)
|
827 |
+
return graph
|
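A minimal sketch of the DAG helper above on a hypothetical three-blob SSA chain:

toy_ssa = [
    ([("a", 0)], [("b", 0)]),
    ([("b", 0)], [("c", 0)]),
]
dag = DiGraph.from_ssa(toy_ssa)
dag.get_all_paths(("a", 0), ("c", 0))  # [[("a", 0), ("b", 0), ("c", 0)]]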
828 |
+
|
829 |
+
|
830 |
+
def _get_dependency_chain(ssa, versioned_target, versioned_source):
|
831 |
+
"""
|
832 |
+
Return the index list of relevant operators to produce the target blob from the source blob;
|
833 |
+
if there's no dependency, return empty list.
|
834 |
+
"""
|
835 |
+
|
836 |
+
# finding all paths between nodes can be O(N!), thus we can only search
|
837 |
+
# in the subgraph using the op starting from the first consumer of source blob
|
838 |
+
# to the producer of the target blob.
|
839 |
+
consumer_map = get_consumer_map(ssa)
|
840 |
+
producer_map = get_producer_map(ssa)
|
841 |
+
start_op = min(x[0] for x in consumer_map[versioned_source]) - 15
|
842 |
+
end_op = (
|
843 |
+
producer_map[versioned_target][0] + 15 if versioned_target in producer_map else start_op
|
844 |
+
)
|
845 |
+
sub_graph_ssa = ssa[start_op : end_op + 1]
|
846 |
+
if len(sub_graph_ssa) > 30:
|
847 |
+
logger.warning(
|
848 |
+
"Subgraph bebetween {} and {} is large (from op#{} to op#{}), it"
|
849 |
+
" might take non-trival time to find all paths between them.".format(
|
850 |
+
versioned_source, versioned_target, start_op, end_op
|
851 |
+
)
|
852 |
+
)
|
853 |
+
|
854 |
+
dag = DiGraph.from_ssa(sub_graph_ssa)
|
855 |
+
paths = dag.get_all_paths(versioned_source, versioned_target) # include two ends
|
856 |
+
ops_in_paths = [[producer_map[blob][0] for blob in path[1:]] for path in paths]
|
857 |
+
return sorted(set().union(*[set(ops) for ops in ops_in_paths]))
|
858 |
+
|
859 |
+
|
860 |
+
def identify_reshape_sub_graph(predict_net: caffe2_pb2.NetDef) -> List[List[int]]:
|
861 |
+
"""
|
862 |
+
Identify the reshape sub-graph in a protobuf.
|
863 |
+
The reshape sub-graph is defined as matching the following pattern:
|
864 |
+
|
865 |
+
(input_blob) -> Op_1 -> ... -> Op_N -> (new_shape) -─┐
|
866 |
+
└-------------------------------------------> Reshape -> (output_blob)
|
867 |
+
|
868 |
+
Return:
|
869 |
+
List of sub-graphs, each sub-graph is represented as a list of indices
|
870 |
+
of the relevant ops, [Op_1, Op_2, ..., Op_N, Reshape]
|
871 |
+
"""
|
872 |
+
|
873 |
+
ssa, _ = core.get_ssa(predict_net)
|
874 |
+
|
875 |
+
ret = []
|
876 |
+
for i, op in enumerate(predict_net.op):
|
877 |
+
if op.type == "Reshape":
|
878 |
+
assert len(op.input) == 2
|
879 |
+
input_ssa = ssa[i][0]
|
880 |
+
data_source = input_ssa[0]
|
881 |
+
shape_source = input_ssa[1]
|
882 |
+
op_indices = _get_dependency_chain(ssa, shape_source, data_source)
|
883 |
+
ret.append(op_indices + [i])
|
884 |
+
return ret
|
885 |
+
|
886 |
+
|
887 |
+
def remove_reshape_for_fc(predict_net, params):
|
888 |
+
"""
|
889 |
+
In PyTorch, nn.Linear has to take a 2D tensor; this often leads to reshaping
|
890 |
+
a 4D tensor to 2D by calling .view(). However, this (dynamic) reshaping
|
891 |
+
doesn't work well with ONNX and Int8 tools, and causes extra
|
892 |
+
ops (eg. ExpandDims) that might not be available on mobile.
|
893 |
+
Luckily Caffe2 supports 4D tensor for FC, so we can remove those reshape
|
894 |
+
after exporting ONNX model.
|
895 |
+
"""
|
896 |
+
from caffe2.python import core
|
897 |
+
|
898 |
+
# find all reshape sub-graphs that can be removed, which is now all Reshape
|
899 |
+
# sub-graphs whose output is only consumed by FC.
|
900 |
+
# TODO: to make it safer, we may need the actual value to better determine
|
901 |
+
# if a Reshape before FC is removable.
|
902 |
+
reshape_sub_graphs = identify_reshape_sub_graph(predict_net)
|
903 |
+
sub_graphs_to_remove = []
|
904 |
+
for reshape_sub_graph in reshape_sub_graphs:
|
905 |
+
reshape_op_id = reshape_sub_graph[-1]
|
906 |
+
assert predict_net.op[reshape_op_id].type == "Reshape"
|
907 |
+
ssa, _ = core.get_ssa(predict_net)
|
908 |
+
reshape_output = ssa[reshape_op_id][1][0]
|
909 |
+
consumers = [i for i in range(len(ssa)) if reshape_output in ssa[i][0]]
|
910 |
+
if all(predict_net.op[consumer].type == "FC" for consumer in consumers):
|
911 |
+
# safety check if the sub-graph is isolated, for this reshape sub-graph,
|
912 |
+
# it means it has one non-param external input and one external output.
|
913 |
+
ext_inputs, ext_outputs = get_sub_graph_external_input_output(
|
914 |
+
predict_net, reshape_sub_graph
|
915 |
+
)
|
916 |
+
non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
|
917 |
+
if len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1:
|
918 |
+
sub_graphs_to_remove.append(reshape_sub_graph)
|
919 |
+
|
920 |
+
# perform removing subgraph by:
|
921 |
+
# 1: rename the Reshape's output to its input, then the graph can be
|
922 |
+
# seen as an in-place identity, meaning its external input/output are the same.
|
923 |
+
# 2: simply remove those ops.
|
924 |
+
remove_op_ids = []
|
925 |
+
params_to_remove = []
|
926 |
+
for sub_graph in sub_graphs_to_remove:
|
927 |
+
logger.info(
|
928 |
+
"Remove Reshape sub-graph:\n{}".format(
|
929 |
+
"".join(["(#{:>4})\n{}".format(i, predict_net.op[i]) for i in sub_graph])
|
930 |
+
)
|
931 |
+
)
|
932 |
+
reshape_op_id = sub_graph[-1]
|
933 |
+
new_reshap_output = predict_net.op[reshape_op_id].input[0]
|
934 |
+
rename_op_output(predict_net, reshape_op_id, 0, new_reshap_output)
|
935 |
+
ext_inputs, ext_outputs = get_sub_graph_external_input_output(predict_net, sub_graph)
|
936 |
+
non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
|
937 |
+
params_ext_inputs = [inp for inp in ext_inputs if inp[1] == 0]
|
938 |
+
assert len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1
|
939 |
+
assert ext_outputs[0][0] == non_params_ext_inputs[0][0]
|
940 |
+
assert ext_outputs[0][1] == non_params_ext_inputs[0][1] + 1
|
941 |
+
remove_op_ids.extend(sub_graph)
|
942 |
+
params_to_remove.extend(params_ext_inputs)
|
943 |
+
|
944 |
+
predict_net = copy.deepcopy(predict_net)
|
945 |
+
new_ops = [op for i, op in enumerate(predict_net.op) if i not in remove_op_ids]
|
946 |
+
del predict_net.op[:]
|
947 |
+
predict_net.op.extend(new_ops)
|
948 |
+
for versioned_params in params_to_remove:
|
949 |
+
name = versioned_params[0]
|
950 |
+
logger.info("Remove params: {} from init_net and predict_net.external_input".format(name))
|
951 |
+
del params[name]
|
952 |
+
predict_net.external_input.remove(name)
|
953 |
+
|
954 |
+
return predict_net, params
|
955 |
+
|
956 |
+
|
957 |
+
def fuse_copy_between_cpu_and_gpu(predict_net: caffe2_pb2.NetDef):
|
958 |
+
"""
|
959 |
+
In-place fuse extra copy ops between cpu/gpu for the following case:
|
960 |
+
a -CopyAToB-> b -CopyBToA> c1 -NextOp1-> d1
|
961 |
+
-CopyBToA> c2 -NextOp2-> d2
|
962 |
+
The fused network will look like:
|
963 |
+
a -NextOp1-> d1
|
964 |
+
-NextOp2-> d2
|
965 |
+
"""
|
966 |
+
|
967 |
+
_COPY_OPS = ["CopyCPUToGPU", "CopyGPUToCPU"]
|
968 |
+
|
969 |
+
def _fuse_once(predict_net):
|
970 |
+
ssa, blob_versions = core.get_ssa(predict_net)
|
971 |
+
consumer_map = get_consumer_map(ssa)
|
972 |
+
versioned_external_output = [
|
973 |
+
(name, blob_versions[name]) for name in predict_net.external_output
|
974 |
+
]
|
975 |
+
|
976 |
+
for op_id, op in enumerate(predict_net.op):
|
977 |
+
if op.type in _COPY_OPS:
|
978 |
+
fw_copy_versioned_output = ssa[op_id][1][0]
|
979 |
+
consumer_ids = [x[0] for x in consumer_map[fw_copy_versioned_output]]
|
980 |
+
reverse_op_type = _COPY_OPS[1 - _COPY_OPS.index(op.type)]
|
981 |
+
|
982 |
+
is_fusable = (
|
983 |
+
len(consumer_ids) > 0
|
984 |
+
and fw_copy_versioned_output not in versioned_external_output
|
985 |
+
and all(
|
986 |
+
predict_net.op[_op_id].type == reverse_op_type
|
987 |
+
and ssa[_op_id][1][0] not in versioned_external_output
|
988 |
+
for _op_id in consumer_ids
|
989 |
+
)
|
990 |
+
)
|
991 |
+
|
992 |
+
if is_fusable:
|
993 |
+
for rv_copy_op_id in consumer_ids:
|
994 |
+
# make each NextOp use "a" directly and remove the Copy ops
|
995 |
+
rs_copy_versioned_output = ssa[rv_copy_op_id][1][0]
|
996 |
+
next_op_id, inp_id = consumer_map[rs_copy_versioned_output][0]
|
997 |
+
predict_net.op[next_op_id].input[inp_id] = op.input[0]
|
998 |
+
# remove CopyOps
|
999 |
+
new_ops = [
|
1000 |
+
op
|
1001 |
+
for i, op in enumerate(predict_net.op)
|
1002 |
+
if i != op_id and i not in consumer_ids
|
1003 |
+
]
|
1004 |
+
del predict_net.op[:]
|
1005 |
+
predict_net.op.extend(new_ops)
|
1006 |
+
return True
|
1007 |
+
|
1008 |
+
return False
|
1009 |
+
|
1010 |
+
# _fuse_once returns False if nothing can be fused
|
1011 |
+
while _fuse_once(predict_net):
|
1012 |
+
pass
|
1013 |
+
|
1014 |
+
|
1015 |
+
def remove_dead_end_ops(net_def: caffe2_pb2.NetDef):
|
1016 |
+
"""remove ops if its output is not used or not in external_output"""
|
1017 |
+
ssa, versions = core.get_ssa(net_def)
|
1018 |
+
versioned_external_output = [(name, versions[name]) for name in net_def.external_output]
|
1019 |
+
consumer_map = get_consumer_map(ssa)
|
1020 |
+
removed_op_ids = set()
|
1021 |
+
|
1022 |
+
def _is_dead_end(versioned_blob):
|
1023 |
+
return not (
|
1024 |
+
versioned_blob in versioned_external_output
|
1025 |
+
or (
|
1026 |
+
len(consumer_map[versioned_blob]) > 0
|
1027 |
+
and all(x[0] not in removed_op_ids for x in consumer_map[versioned_blob])
|
1028 |
+
)
|
1029 |
+
)
|
1030 |
+
|
1031 |
+
for i, ssa_i in reversed(list(enumerate(ssa))):
|
1032 |
+
versioned_outputs = ssa_i[1]
|
1033 |
+
if all(_is_dead_end(outp) for outp in versioned_outputs):
|
1034 |
+
removed_op_ids.add(i)
|
1035 |
+
|
1036 |
+
# simply removing those dead-end ops should have no effect on external_output
|
1037 |
+
new_ops = [op for i, op in enumerate(net_def.op) if i not in removed_op_ids]
|
1038 |
+
del net_def.op[:]
|
1039 |
+
net_def.op.extend(new_ops)
|
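A rough sketch of chaining the in-place graph cleanups defined in this file on a hypothetical exported pair of nets; the exact order and whether each step is needed depend on the model:

fuse_alias_placeholder(predict_net, init_net)
predict_net, params = remove_reshape_for_fc(predict_net, params)
fuse_copy_between_cpu_and_gpu(predict_net)
remove_dead_end_ops(predict_net)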
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/torchscript.py
ADDED
@@ -0,0 +1,132 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from annotator.oneformer.detectron2.utils.file_io import PathManager
|
7 |
+
|
8 |
+
from .torchscript_patch import freeze_training_mode, patch_instances
|
9 |
+
|
10 |
+
__all__ = ["scripting_with_instances", "dump_torchscript_IR"]
|
11 |
+
|
12 |
+
|
13 |
+
def scripting_with_instances(model, fields):
|
14 |
+
"""
|
15 |
+
Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since
|
16 |
+
attributes of :class:`Instances` are "dynamically" added in eager mode, it is difficult
|
17 |
+
for scripting to support it out of the box. This function is made to support scripting
|
18 |
+
a model that uses :class:`Instances`. It does the following:
|
19 |
+
|
20 |
+
1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``,
|
21 |
+
but with all attributes being "static".
|
22 |
+
The attributes need to be statically declared in the ``fields`` argument.
|
23 |
+
2. Register ``new_Instances``, and force scripting compiler to
|
24 |
+
use it when trying to compile ``Instances``.
|
25 |
+
|
26 |
+
After this function returns, the process is reverted. Users should be able to script another model
|
27 |
+
using different fields.
|
28 |
+
|
29 |
+
Example:
|
30 |
+
Assume that ``Instances`` in the model consist of two attributes named
|
31 |
+
``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
|
32 |
+
:class:`Tensor` respectively during inference. You can call this function like:
|
33 |
+
::
|
34 |
+
fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
|
35 |
+
torchscript_model = scripting_with_instances(model, fields)
|
36 |
+
|
37 |
+
Note:
|
38 |
+
It only supports models in evaluation mode.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
model (nn.Module): The input model to be exported by scripting.
|
42 |
+
fields (Dict[str, type]): Attribute names and corresponding type that
|
43 |
+
``Instances`` will use in the model. Note that all attributes used in ``Instances``
|
44 |
+
need to be added, regardless of whether they are inputs/outputs of the model.
|
45 |
+
Data type not defined in detectron2 is not supported for now.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
torch.jit.ScriptModule: the model in torchscript format
|
49 |
+
"""
|
50 |
+
assert (
|
51 |
+
not model.training
|
52 |
+
), "Currently we only support exporting models in evaluation mode to torchscript"
|
53 |
+
|
54 |
+
with freeze_training_mode(model), patch_instances(fields):
|
55 |
+
scripted_model = torch.jit.script(model)
|
56 |
+
return scripted_model
|
57 |
+
|
58 |
+
|
59 |
+
# alias for old name
|
60 |
+
export_torchscript_with_instances = scripting_with_instances
|
61 |
+
|
62 |
+
|
63 |
+
def dump_torchscript_IR(model, dir):
|
64 |
+
"""
|
65 |
+
Dump IR of a TracedModule/ScriptModule/Function in various formats (code, graph,
|
66 |
+
inlined graph). Useful for debugging.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
|
70 |
+
dir (str): output directory to dump files.
|
71 |
+
"""
|
72 |
+
dir = os.path.expanduser(dir)
|
73 |
+
PathManager.mkdirs(dir)
|
74 |
+
|
75 |
+
def _get_script_mod(mod):
|
76 |
+
if isinstance(mod, torch.jit.TracedModule):
|
77 |
+
return mod._actual_script_module
|
78 |
+
return mod
|
79 |
+
|
80 |
+
# Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code
|
81 |
+
with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:
|
82 |
+
|
83 |
+
def get_code(mod):
|
84 |
+
# Try a few ways to get code using private attributes.
|
85 |
+
try:
|
86 |
+
# This contains more information than just `mod.code`
|
87 |
+
return _get_script_mod(mod)._c.code
|
88 |
+
except AttributeError:
|
89 |
+
pass
|
90 |
+
try:
|
91 |
+
return mod.code
|
92 |
+
except AttributeError:
|
93 |
+
return None
|
94 |
+
|
95 |
+
def dump_code(prefix, mod):
|
96 |
+
code = get_code(mod)
|
97 |
+
name = prefix or "root model"
|
98 |
+
if code is None:
|
99 |
+
f.write(f"Could not found code for {name} (type={mod.original_name})\n")
|
100 |
+
f.write("\n")
|
101 |
+
else:
|
102 |
+
f.write(f"\nCode for {name}, type={mod.original_name}:\n")
|
103 |
+
f.write(code)
|
104 |
+
f.write("\n")
|
105 |
+
f.write("-" * 80)
|
106 |
+
|
107 |
+
for name, m in mod.named_children():
|
108 |
+
dump_code(prefix + "." + name, m)
|
109 |
+
|
110 |
+
if isinstance(model, torch.jit.ScriptFunction):
|
111 |
+
f.write(get_code(model))
|
112 |
+
else:
|
113 |
+
dump_code("", model)
|
114 |
+
|
115 |
+
def _get_graph(model):
|
116 |
+
try:
|
117 |
+
# Recursively dump IR of all modules
|
118 |
+
return _get_script_mod(model)._c.dump_to_str(True, False, False)
|
119 |
+
except AttributeError:
|
120 |
+
return model.graph.str()
|
121 |
+
|
122 |
+
with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
|
123 |
+
f.write(_get_graph(model))
|
124 |
+
|
125 |
+
# Dump IR of the entire graph (all submodules inlined)
|
126 |
+
with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
|
127 |
+
f.write(str(model.inlined_graph))
|
128 |
+
|
129 |
+
if not isinstance(model, torch.jit.ScriptFunction):
|
130 |
+
# Dump the model structure in pytorch style
|
131 |
+
with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
|
132 |
+
f.write(str(model))
|
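A combined usage sketch for the two helpers in this file; the model, its Instances fields, and the output directory are hypothetical:

from annotator.oneformer.detectron2.structures import Boxes

fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
ts_model = scripting_with_instances(model.eval(), fields)
dump_torchscript_IR(ts_model, "/tmp/ts_debug")  # writes model_ts_code.txt, model_ts_IR.txt, ...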
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/export/torchscript_patch.py
ADDED
@@ -0,0 +1,406 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import tempfile
|
6 |
+
from contextlib import ExitStack, contextmanager
|
7 |
+
from copy import deepcopy
|
8 |
+
from unittest import mock
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
# need some explicit imports due to https://github.com/pytorch/pytorch/issues/38964
|
13 |
+
import annotator.oneformer.detectron2 # noqa F401
|
14 |
+
from annotator.oneformer.detectron2.structures import Boxes, Instances
|
15 |
+
from annotator.oneformer.detectron2.utils.env import _import_file
|
16 |
+
|
17 |
+
_counter = 0
|
18 |
+
|
19 |
+
|
20 |
+
def _clear_jit_cache():
|
21 |
+
from torch.jit._recursive import concrete_type_store
|
22 |
+
from torch.jit._state import _jit_caching_layer
|
23 |
+
|
24 |
+
concrete_type_store.type_store.clear() # for modules
|
25 |
+
_jit_caching_layer.clear() # for free functions
|
26 |
+
|
27 |
+
|
28 |
+
def _add_instances_conversion_methods(newInstances):
|
29 |
+
"""
|
30 |
+
Add from_instances methods to the scripted Instances class.
|
31 |
+
"""
|
32 |
+
cls_name = newInstances.__name__
|
33 |
+
|
34 |
+
@torch.jit.unused
|
35 |
+
def from_instances(instances: Instances):
|
36 |
+
"""
|
37 |
+
Create scripted Instances from original Instances
|
38 |
+
"""
|
39 |
+
fields = instances.get_fields()
|
40 |
+
image_size = instances.image_size
|
41 |
+
ret = newInstances(image_size)
|
42 |
+
for name, val in fields.items():
|
43 |
+
assert hasattr(ret, f"_{name}"), f"No attribute named {name} in {cls_name}"
|
44 |
+
setattr(ret, name, deepcopy(val))
|
45 |
+
return ret
|
46 |
+
|
47 |
+
newInstances.from_instances = from_instances
|
48 |
+
|
49 |
+
|
50 |
+
@contextmanager
|
51 |
+
def patch_instances(fields):
|
52 |
+
"""
|
53 |
+
A contextmanager, under which the Instances class in detectron2 is replaced
|
54 |
+
by a statically-typed scriptable class, defined by `fields`.
|
55 |
+
See more in `scripting_with_instances`.
|
56 |
+
"""
|
57 |
+
|
58 |
+
with tempfile.TemporaryDirectory(prefix="detectron2") as dir, tempfile.NamedTemporaryFile(
|
59 |
+
mode="w", encoding="utf-8", suffix=".py", dir=dir, delete=False
|
60 |
+
) as f:
|
61 |
+
try:
|
62 |
+
# Objects that use Instances should not reuse previously-compiled
|
63 |
+
# results in cache, because `Instances` could be a new class each time.
|
64 |
+
_clear_jit_cache()
|
65 |
+
|
66 |
+
cls_name, s = _gen_instance_module(fields)
|
67 |
+
f.write(s)
|
68 |
+
f.flush()
|
69 |
+
f.close()
|
70 |
+
|
71 |
+
module = _import(f.name)
|
72 |
+
new_instances = getattr(module, cls_name)
|
73 |
+
_ = torch.jit.script(new_instances)
|
74 |
+
# let torchscript think Instances was scripted already
|
75 |
+
Instances.__torch_script_class__ = True
|
76 |
+
# let torchscript find new_instances when looking for the jit type of Instances
|
77 |
+
Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)
|
78 |
+
|
79 |
+
_add_instances_conversion_methods(new_instances)
|
80 |
+
yield new_instances
|
81 |
+
finally:
|
82 |
+
try:
|
83 |
+
del Instances.__torch_script_class__
|
84 |
+
del Instances._jit_override_qualname
|
85 |
+
except AttributeError:
|
86 |
+
pass
|
87 |
+
sys.modules.pop(module.__name__)
|
88 |
+
|
89 |
+
|
90 |
+
def _gen_instance_class(fields):
|
91 |
+
"""
|
92 |
+
Args:
|
93 |
+
fields (dict[name: type])
|
94 |
+
"""
|
95 |
+
|
96 |
+
class _FieldType:
|
97 |
+
def __init__(self, name, type_):
|
98 |
+
assert isinstance(name, str), f"Field name must be str, got {name}"
|
99 |
+
self.name = name
|
100 |
+
self.type_ = type_
|
101 |
+
self.annotation = f"{type_.__module__}.{type_.__name__}"
|
102 |
+
|
103 |
+
fields = [_FieldType(k, v) for k, v in fields.items()]
|
104 |
+
|
105 |
+
def indent(level, s):
|
106 |
+
return " " * 4 * level + s
|
107 |
+
|
108 |
+
lines = []
|
109 |
+
|
110 |
+
global _counter
|
111 |
+
_counter += 1
|
112 |
+
|
113 |
+
cls_name = "ScriptedInstances{}".format(_counter)
|
114 |
+
|
115 |
+
field_names = tuple(x.name for x in fields)
|
116 |
+
extra_args = ", ".join([f"{f.name}: Optional[{f.annotation}] = None" for f in fields])
|
117 |
+
lines.append(
|
118 |
+
f"""
|
119 |
+
class {cls_name}:
|
120 |
+
def __init__(self, image_size: Tuple[int, int], {extra_args}):
|
121 |
+
self.image_size = image_size
|
122 |
+
self._field_names = {field_names}
|
123 |
+
"""
|
124 |
+
)
|
125 |
+
|
126 |
+
for f in fields:
|
127 |
+
lines.append(
|
128 |
+
indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], {f.name})")
|
129 |
+
)
|
130 |
+
|
131 |
+
for f in fields:
|
132 |
+
lines.append(
|
133 |
+
f"""
|
134 |
+
@property
|
135 |
+
def {f.name}(self) -> {f.annotation}:
|
136 |
+
# has to use a local for type refinement
|
137 |
+
# https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
|
138 |
+
t = self._{f.name}
|
139 |
+
assert t is not None, "{f.name} is None and cannot be accessed!"
|
140 |
+
return t
|
141 |
+
|
142 |
+
@{f.name}.setter
|
143 |
+
def {f.name}(self, value: {f.annotation}) -> None:
|
144 |
+
self._{f.name} = value
|
145 |
+
"""
|
146 |
+
)
|
147 |
+
|
148 |
+
# support method `__len__`
|
149 |
+
lines.append(
|
150 |
+
"""
|
151 |
+
def __len__(self) -> int:
|
152 |
+
"""
|
153 |
+
)
|
154 |
+
for f in fields:
|
155 |
+
lines.append(
|
156 |
+
f"""
|
157 |
+
t = self._{f.name}
|
158 |
+
if t is not None:
|
159 |
+
return len(t)
|
160 |
+
"""
|
161 |
+
)
|
162 |
+
lines.append(
|
163 |
+
"""
|
164 |
+
raise NotImplementedError("Empty Instances does not support __len__!")
|
165 |
+
"""
|
166 |
+
)
|
167 |
+
|
168 |
+
# support method `has`
|
169 |
+
lines.append(
|
170 |
+
"""
|
171 |
+
def has(self, name: str) -> bool:
|
172 |
+
"""
|
173 |
+
)
|
174 |
+
for f in fields:
|
175 |
+
lines.append(
|
176 |
+
f"""
|
177 |
+
if name == "{f.name}":
|
178 |
+
return self._{f.name} is not None
|
179 |
+
"""
|
180 |
+
)
|
181 |
+
lines.append(
|
182 |
+
"""
|
183 |
+
return False
|
184 |
+
"""
|
185 |
+
)
|
186 |
+
|
187 |
+
# support method `to`
|
188 |
+
none_args = ", None" * len(fields)
|
189 |
+
lines.append(
|
190 |
+
f"""
|
191 |
+
def to(self, device: torch.device) -> "{cls_name}":
|
192 |
+
ret = {cls_name}(self.image_size{none_args})
|
193 |
+
"""
|
194 |
+
)
|
195 |
+
for f in fields:
|
196 |
+
if hasattr(f.type_, "to"):
|
197 |
+
lines.append(
|
198 |
+
f"""
|
199 |
+
t = self._{f.name}
|
200 |
+
if t is not None:
|
201 |
+
ret._{f.name} = t.to(device)
|
202 |
+
"""
|
203 |
+
)
|
204 |
+
else:
|
205 |
+
# For now, ignore fields that cannot be moved to devices.
|
206 |
+
# Maybe can support other tensor-like classes (e.g. __torch_function__)
|
207 |
+
pass
|
208 |
+
lines.append(
|
209 |
+
"""
|
210 |
+
return ret
|
211 |
+
"""
|
212 |
+
)
|
213 |
+
|
214 |
+
# support method `getitem`
|
215 |
+
none_args = ", None" * len(fields)
|
216 |
+
lines.append(
|
217 |
+
f"""
|
218 |
+
def __getitem__(self, item) -> "{cls_name}":
|
219 |
+
ret = {cls_name}(self.image_size{none_args})
|
220 |
+
"""
|
221 |
+
)
|
222 |
+
for f in fields:
|
223 |
+
lines.append(
|
224 |
+
f"""
|
225 |
+
t = self._{f.name}
|
226 |
+
if t is not None:
|
227 |
+
ret._{f.name} = t[item]
|
228 |
+
"""
|
229 |
+
)
|
230 |
+
lines.append(
|
231 |
+
"""
|
232 |
+
return ret
|
233 |
+
"""
|
234 |
+
)
|
235 |
+
|
236 |
+
# support method `cat`
|
237 |
+
# this version does not contain checks that all instances have same size and fields
|
238 |
+
none_args = ", None" * len(fields)
|
239 |
+
lines.append(
|
240 |
+
f"""
|
241 |
+
def cat(self, instances: List["{cls_name}"]) -> "{cls_name}":
|
242 |
+
ret = {cls_name}(self.image_size{none_args})
|
243 |
+
"""
|
244 |
+
)
|
245 |
+
for f in fields:
|
246 |
+
lines.append(
|
247 |
+
f"""
|
248 |
+
t = self._{f.name}
|
249 |
+
if t is not None:
|
250 |
+
values: List[{f.annotation}] = [x.{f.name} for x in instances]
|
251 |
+
if torch.jit.isinstance(t, torch.Tensor):
|
252 |
+
ret._{f.name} = torch.cat(values, dim=0)
|
253 |
+
else:
|
254 |
+
ret._{f.name} = t.cat(values)
|
255 |
+
"""
|
256 |
+
)
|
257 |
+
lines.append(
|
258 |
+
"""
|
259 |
+
return ret"""
|
260 |
+
)
|
261 |
+
|
262 |
+
# support method `get_fields()`
|
263 |
+
lines.append(
|
264 |
+
"""
|
265 |
+
def get_fields(self) -> Dict[str, Tensor]:
|
266 |
+
ret = {}
|
267 |
+
"""
|
268 |
+
)
|
269 |
+
for f in fields:
|
270 |
+
if f.type_ == Boxes:
|
271 |
+
stmt = "t.tensor"
|
272 |
+
elif f.type_ == torch.Tensor:
|
273 |
+
stmt = "t"
|
274 |
+
else:
|
275 |
+
stmt = f'assert False, "unsupported type {str(f.type_)}"'
|
276 |
+
lines.append(
|
277 |
+
f"""
|
278 |
+
t = self._{f.name}
|
279 |
+
if t is not None:
|
280 |
+
ret["{f.name}"] = {stmt}
|
281 |
+
"""
|
282 |
+
)
|
283 |
+
lines.append(
|
284 |
+
"""
|
285 |
+
return ret"""
|
286 |
+
)
|
287 |
+
return cls_name, os.linesep.join(lines)
|
288 |
+
|
289 |
+
|
290 |
+
def _gen_instance_module(fields):
|
291 |
+
# TODO: find a more automatic way to enable import of other classes
|
292 |
+
s = """
|
293 |
+
from copy import deepcopy
|
294 |
+
import torch
|
295 |
+
from torch import Tensor
|
296 |
+
import typing
|
297 |
+
from typing import *
|
298 |
+
|
299 |
+
import annotator.oneformer.detectron2
|
300 |
+
from annotator.oneformer.detectron2.structures import Boxes, Instances
|
301 |
+
|
302 |
+
"""
|
303 |
+
|
304 |
+
cls_name, cls_def = _gen_instance_class(fields)
|
305 |
+
s += cls_def
|
306 |
+
return cls_name, s
|
307 |
+
|
308 |
+
|
309 |
+
def _import(path):
|
310 |
+
return _import_file(
|
311 |
+
"{}{}".format(sys.modules[__name__].__name__, _counter), path, make_importable=True
|
312 |
+
)
|
313 |
+
|
314 |
+
|
315 |
+
@contextmanager
|
316 |
+
def patch_builtin_len(modules=()):
|
317 |
+
"""
|
318 |
+
Patch the builtin len() function of a few detectron2 modules
|
319 |
+
to use __len__ instead, because __len__ does not convert values to
|
320 |
+
integers and therefore is friendly to tracing.
|
321 |
+
|
322 |
+
Args:
|
323 |
+
modules (list[str]): names of extra modules to patch len(), in
|
324 |
+
addition to those in detectron2.
|
325 |
+
"""
|
326 |
+
|
327 |
+
def _new_len(obj):
|
328 |
+
return obj.__len__()
|
329 |
+
|
330 |
+
with ExitStack() as stack:
|
331 |
+
MODULES = [
|
332 |
+
"detectron2.modeling.roi_heads.fast_rcnn",
|
333 |
+
"detectron2.modeling.roi_heads.mask_head",
|
334 |
+
"detectron2.modeling.roi_heads.keypoint_head",
|
335 |
+
] + list(modules)
|
336 |
+
ctxs = [stack.enter_context(mock.patch(mod + ".len")) for mod in MODULES]
|
337 |
+
for m in ctxs:
|
338 |
+
m.side_effect = _new_len
|
339 |
+
yield
|
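A minimal tracing sketch using the context manager above; `model` and `example_inputs` are hypothetical:

with patch_builtin_len():
    traced = torch.jit.trace(model, example_inputs)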
340 |
+
|
341 |
+
|
342 |
+
def patch_nonscriptable_classes():
|
343 |
+
"""
|
344 |
+
Apply patches on a few nonscriptable detectron2 classes.
|
345 |
+
Should not have side-effects on eager usage.
|
346 |
+
"""
|
347 |
+
# __prepare_scriptable__ can also be added to models for easier maintenance.
|
348 |
+
# But it complicates the clean model code.
|
349 |
+
|
350 |
+
from annotator.oneformer.detectron2.modeling.backbone import ResNet, FPN
|
351 |
+
|
352 |
+
# Due to https://github.com/pytorch/pytorch/issues/36061,
|
353 |
+
# we change backbone to use ModuleList for scripting.
|
354 |
+
# (note: this changes param names in state_dict)
|
355 |
+
|
356 |
+
def prepare_resnet(self):
|
357 |
+
ret = deepcopy(self)
|
358 |
+
ret.stages = nn.ModuleList(ret.stages)
|
359 |
+
for k in self.stage_names:
|
360 |
+
delattr(ret, k)
|
361 |
+
return ret
|
362 |
+
|
363 |
+
ResNet.__prepare_scriptable__ = prepare_resnet
|
364 |
+
|
365 |
+
def prepare_fpn(self):
|
366 |
+
ret = deepcopy(self)
|
367 |
+
ret.lateral_convs = nn.ModuleList(ret.lateral_convs)
|
368 |
+
ret.output_convs = nn.ModuleList(ret.output_convs)
|
369 |
+
for name, _ in self.named_children():
|
370 |
+
if name.startswith("fpn_"):
|
371 |
+
delattr(ret, name)
|
372 |
+
return ret
|
373 |
+
|
374 |
+
FPN.__prepare_scriptable__ = prepare_fpn
|
375 |
+
|
376 |
+
# Annotate some attributes to be constants for the purpose of scripting,
|
377 |
+
# even though they are not constants in eager mode.
|
378 |
+
from annotator.oneformer.detectron2.modeling.roi_heads import StandardROIHeads
|
379 |
+
|
380 |
+
if hasattr(StandardROIHeads, "__annotations__"):
|
381 |
+
# copy first to avoid editing annotations of base class
|
382 |
+
StandardROIHeads.__annotations__ = deepcopy(StandardROIHeads.__annotations__)
|
383 |
+
StandardROIHeads.__annotations__["mask_on"] = torch.jit.Final[bool]
|
384 |
+
StandardROIHeads.__annotations__["keypoint_on"] = torch.jit.Final[bool]
|
385 |
+
|
386 |
+
|
387 |
+
# These patches are not supposed to have side-effects.
|
388 |
+
patch_nonscriptable_classes()
|
389 |
+
|
390 |
+
|
391 |
+
@contextmanager
|
392 |
+
def freeze_training_mode(model):
|
393 |
+
"""
|
394 |
+
A context manager that annotates the "training" attribute of every submodule
|
395 |
+
to constant, so that the training codepath in these modules can be
|
396 |
+
meta-compiled away. Upon exiting, the annotations are reverted.
|
397 |
+
"""
|
398 |
+
classes = {type(x) for x in model.modules()}
|
399 |
+
# __constants__ is the old way to annotate constants and not compatible
|
400 |
+
# with __annotations__ .
|
401 |
+
classes = {x for x in classes if not hasattr(x, "__constants__")}
|
402 |
+
for cls in classes:
|
403 |
+
cls.__annotations__["training"] = torch.jit.Final[bool]
|
404 |
+
yield
|
405 |
+
for cls in classes:
|
406 |
+
cls.__annotations__["training"] = bool
|
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/__init__.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from .batch_norm import FrozenBatchNorm2d, get_norm, NaiveSyncBatchNorm, CycleBatchNormList
|
3 |
+
from .deform_conv import DeformConv, ModulatedDeformConv
|
4 |
+
from .mask_ops import paste_masks_in_image
|
5 |
+
from .nms import batched_nms, batched_nms_rotated, nms, nms_rotated
|
6 |
+
from .roi_align import ROIAlign, roi_align
|
7 |
+
from .roi_align_rotated import ROIAlignRotated, roi_align_rotated
|
8 |
+
from .shape_spec import ShapeSpec
|
9 |
+
from .wrappers import (
|
10 |
+
BatchNorm2d,
|
11 |
+
Conv2d,
|
12 |
+
ConvTranspose2d,
|
13 |
+
cat,
|
14 |
+
interpolate,
|
15 |
+
Linear,
|
16 |
+
nonzero_tuple,
|
17 |
+
cross_entropy,
|
18 |
+
empty_input_loss_func_wrapper,
|
19 |
+
shapes_to_tensor,
|
20 |
+
move_device_like,
|
21 |
+
)
|
22 |
+
from .blocks import CNNBlockBase, DepthwiseSeparableConv2d
|
23 |
+
from .aspp import ASPP
|
24 |
+
from .losses import ciou_loss, diou_loss
|
25 |
+
|
26 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/aspp.py
ADDED
@@ -0,0 +1,144 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
from copy import deepcopy
|
4 |
+
import fvcore.nn.weight_init as weight_init
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from .batch_norm import get_norm
|
10 |
+
from .blocks import DepthwiseSeparableConv2d
|
11 |
+
from .wrappers import Conv2d
|
12 |
+
|
13 |
+
|
14 |
+
class ASPP(nn.Module):
|
15 |
+
"""
|
16 |
+
Atrous Spatial Pyramid Pooling (ASPP).
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
in_channels,
|
22 |
+
out_channels,
|
23 |
+
dilations,
|
24 |
+
*,
|
25 |
+
norm,
|
26 |
+
activation,
|
27 |
+
pool_kernel_size=None,
|
28 |
+
dropout: float = 0.0,
|
29 |
+
use_depthwise_separable_conv=False,
|
30 |
+
):
|
31 |
+
"""
|
32 |
+
Args:
|
33 |
+
in_channels (int): number of input channels for ASPP.
|
34 |
+
out_channels (int): number of output channels.
|
35 |
+
dilations (list): a list of 3 dilations in ASPP.
|
36 |
+
norm (str or callable): normalization for all conv layers.
|
37 |
+
See :func:`layers.get_norm` for supported format. norm is
|
38 |
+
applied to all conv layers except the conv following
|
39 |
+
global average pooling.
|
40 |
+
activation (callable): activation function.
|
41 |
+
pool_kernel_size (tuple, list): the average pooling size (kh, kw)
|
42 |
+
for image pooling layer in ASPP. If set to None, it always
|
43 |
+
performs global average pooling. If not None, it must be
|
44 |
+
divisible by the shape of inputs in forward(). It is recommended
|
45 |
+
to use a fixed input feature size in training, and set this
|
46 |
+
option to match this size, so that it performs global average
|
47 |
+
pooling in training, and the size of the pooling window stays
|
48 |
+
consistent in inference.
|
49 |
+
dropout (float): apply dropout on the output of ASPP. It is used in
|
50 |
+
the official DeepLab implementation with a rate of 0.1:
|
51 |
+
https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532 # noqa
|
52 |
+
use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d
|
53 |
+
for 3x3 convs in ASPP, proposed in :paper:`DeepLabV3+`.
|
54 |
+
"""
|
55 |
+
super(ASPP, self).__init__()
|
56 |
+
assert len(dilations) == 3, "ASPP expects 3 dilations, got {}".format(len(dilations))
|
57 |
+
self.pool_kernel_size = pool_kernel_size
|
58 |
+
self.dropout = dropout
|
59 |
+
use_bias = norm == ""
|
60 |
+
self.convs = nn.ModuleList()
|
61 |
+
# conv 1x1
|
62 |
+
self.convs.append(
|
63 |
+
Conv2d(
|
64 |
+
in_channels,
|
65 |
+
out_channels,
|
66 |
+
kernel_size=1,
|
67 |
+
bias=use_bias,
|
68 |
+
norm=get_norm(norm, out_channels),
|
69 |
+
activation=deepcopy(activation),
|
70 |
+
)
|
71 |
+
)
|
72 |
+
weight_init.c2_xavier_fill(self.convs[-1])
|
73 |
+
# atrous convs
|
74 |
+
for dilation in dilations:
|
75 |
+
if use_depthwise_separable_conv:
|
76 |
+
self.convs.append(
|
77 |
+
DepthwiseSeparableConv2d(
|
78 |
+
in_channels,
|
79 |
+
out_channels,
|
80 |
+
kernel_size=3,
|
81 |
+
padding=dilation,
|
82 |
+
dilation=dilation,
|
83 |
+
norm1=norm,
|
84 |
+
activation1=deepcopy(activation),
|
85 |
+
norm2=norm,
|
86 |
+
activation2=deepcopy(activation),
|
87 |
+
)
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
self.convs.append(
|
91 |
+
Conv2d(
|
92 |
+
in_channels,
|
93 |
+
out_channels,
|
94 |
+
kernel_size=3,
|
95 |
+
padding=dilation,
|
96 |
+
dilation=dilation,
|
97 |
+
bias=use_bias,
|
98 |
+
norm=get_norm(norm, out_channels),
|
99 |
+
activation=deepcopy(activation),
|
100 |
+
)
|
101 |
+
)
|
102 |
+
weight_init.c2_xavier_fill(self.convs[-1])
|
103 |
+
# image pooling
|
104 |
+
# We do not add BatchNorm because the spatial resolution is 1x1,
|
105 |
+
# the original TF implementation has BatchNorm.
|
106 |
+
if pool_kernel_size is None:
|
107 |
+
image_pooling = nn.Sequential(
|
108 |
+
nn.AdaptiveAvgPool2d(1),
|
109 |
+
Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
|
110 |
+
)
|
111 |
+
else:
|
112 |
+
image_pooling = nn.Sequential(
|
113 |
+
nn.AvgPool2d(kernel_size=pool_kernel_size, stride=1),
|
114 |
+
Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
|
115 |
+
)
|
116 |
+
weight_init.c2_xavier_fill(image_pooling[1])
|
117 |
+
self.convs.append(image_pooling)
|
118 |
+
|
119 |
+
self.project = Conv2d(
|
120 |
+
5 * out_channels,
|
121 |
+
out_channels,
|
122 |
+
kernel_size=1,
|
123 |
+
bias=use_bias,
|
124 |
+
norm=get_norm(norm, out_channels),
|
125 |
+
activation=deepcopy(activation),
|
126 |
+
)
|
127 |
+
weight_init.c2_xavier_fill(self.project)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
size = x.shape[-2:]
|
131 |
+
if self.pool_kernel_size is not None:
|
132 |
+
if size[0] % self.pool_kernel_size[0] or size[1] % self.pool_kernel_size[1]:
|
133 |
+
raise ValueError(
|
134 |
+
"`pool_kernel_size` must be divisible by the shape of inputs. "
|
135 |
+
"Input size: {} `pool_kernel_size`: {}".format(size, self.pool_kernel_size)
|
136 |
+
)
|
137 |
+
res = []
|
138 |
+
for conv in self.convs:
|
139 |
+
res.append(conv(x))
|
140 |
+
res[-1] = F.interpolate(res[-1], size=size, mode="bilinear", align_corners=False)
|
141 |
+
res = torch.cat(res, dim=1)
|
142 |
+
res = self.project(res)
|
143 |
+
res = F.dropout(res, self.dropout, training=self.training) if self.dropout > 0 else res
|
144 |
+
return res
|
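A construction sketch for the ASPP module above; channel sizes, dilations, and the input shape are hypothetical:

import torch
from torch import nn

aspp = ASPP(in_channels=256, out_channels=128, dilations=[6, 12, 18],
            norm="BN", activation=nn.ReLU())
out = aspp(torch.randn(1, 256, 32, 32))  # -> shape (1, 128, 32, 32)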
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/batch_norm.py
ADDED
@@ -0,0 +1,300 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
import torch.distributed as dist
from fvcore.nn.distributed import differentiable_all_reduce
from torch import nn
from torch.nn import functional as F

from annotator.oneformer.detectron2.utils import comm, env

from .wrappers import BatchNorm2d


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silence the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res


def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "naiveSyncBN": NaiveSyncBatchNorm,
            # expose stats_mode N as an option to caller, required for zero-len inputs
            "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"),
            "LN": lambda channels: LayerNorm(channels),
        }[norm]
    return norm(out_channels)


class NaiveSyncBatchNorm(BatchNorm2d):
    """
    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
    when the batch size on each worker is different.
    (e.g., when scale augmentation is used, or when it is applied to mask head).

    This is a slower but correct alternative to `nn.SyncBatchNorm`.

    Note:
        There isn't a single definition of Sync BatchNorm.

        When ``stats_mode==""``, this module computes overall statistics by using
        statistics of each worker with equal weight. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (N, H, W). This mode does not support inputs with zero batch size.

        When ``stats_mode=="N"``, this module computes overall statistics by weighting
        the statistics of each worker by their ``N``. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (H, W). It is slower than ``stats_mode==""``.

        Even though the result of this module may not be the true statistics of all samples,
        it may still be reasonable because it might be preferable to assign equal weights
        to all workers, regardless of their (H, W) dimension, instead of putting larger weight
        on larger images. From preliminary experiments, little difference is found between such
        a simplified implementation and an accurate computation of overall mean & variance.
    """

    def __init__(self, *args, stats_mode="", **kwargs):
        super().__init__(*args, **kwargs)
        assert stats_mode in ["", "N"]
        self._stats_mode = stats_mode

    def forward(self, input):
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        half_input = input.dtype == torch.float16
        if half_input:
            # fp16 does not have good enough numerics for the reduction here
            input = input.float()
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        if self._stats_mode == "":
            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            vec = torch.cat([mean, meansqr], dim=0)
            vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(vec, C)
            momentum = self.momentum
        else:
            if B == 0:
                vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
                vec = vec + input.sum()  # make sure there is gradient w.r.t input
            else:
                vec = torch.cat(
                    [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0
                )
            vec = differentiable_all_reduce(vec * B)

            total_batch = vec[-1].detach()
            momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
            mean, meansqr, _ = torch.split(vec / total_batch.clamp(min=1), C)  # avoid div-by-zero

        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)

        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)
        ret = input * scale + bias
        if half_input:
            ret = ret.half()
        return ret


class CycleBatchNormList(nn.ModuleList):
    """
    Implement domain-specific BatchNorm by cycling.

    When a BatchNorm layer is used for multiple input domains or input
    features, it might need to maintain a separate test-time statistics
    for each domain. See Sec 5.2 in :paper:`rethinking-batchnorm`.

    This module implements it by using N separate BN layers
    and it cycles through them every time a forward() is called.

    NOTE: The caller of this module MUST guarantee to always call
    this module by multiple of N times. Otherwise its test-time statistics
    will be incorrect.
    """

    def __init__(self, length: int, bn_class=nn.BatchNorm2d, **kwargs):
        """
        Args:
            length: number of BatchNorm layers to cycle.
            bn_class: the BatchNorm class to use
            kwargs: arguments of the BatchNorm class, such as num_features.
        """
        self._affine = kwargs.pop("affine", True)
        super().__init__([bn_class(**kwargs, affine=False) for k in range(length)])
        if self._affine:
            # shared affine, domain-specific BN
            channels = self[0].num_features
            self.weight = nn.Parameter(torch.ones(channels))
            self.bias = nn.Parameter(torch.zeros(channels))
        self._pos = 0

    def forward(self, x):
        ret = self[self._pos](x)
        self._pos = (self._pos + 1) % len(self)

        if self._affine:
            w = self.weight.reshape(1, -1, 1, 1)
            b = self.bias.reshape(1, -1, 1, 1)
            return ret * w + b
        else:
            return ret

    def extra_repr(self):
        return f"affine={self._affine}"


class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
    variance normalization over the channel dimension for inputs that have shape
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
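To make the pieces above concrete, here is a short illustrative sketch of `get_norm()` and `FrozenBatchNorm2d.convert_frozen_batchnorm()`; the import path assumes the vendored package layout used by this extension, and the small model is a made-up example:

```python
# Illustrative only: map norm names to layers, then freeze the BN of a toy backbone.
import torch
from torch import nn
from annotator.oneformer.detectron2.layers.batch_norm import FrozenBatchNorm2d, get_norm

gn = get_norm("GN", 64)          # nn.GroupNorm(32, 64)
fbn = get_norm("FrozenBN", 64)   # FrozenBatchNorm2d(num_features=64, eps=1e-05)

# convert_frozen_batchnorm recursively swaps every BatchNorm2d/SyncBatchNorm
# for a FrozenBatchNorm2d whose statistics stay fixed during fine-tuning.
backbone = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64), nn.ReLU())
backbone = FrozenBatchNorm2d.convert_frozen_batchnorm(backbone)
print(backbone[1])               # FrozenBatchNorm2d(num_features=64, eps=1e-05)
```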
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/blocks.py
ADDED
@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import fvcore.nn.weight_init as weight_init
from torch import nn

from .batch_norm import FrozenBatchNorm2d, get_norm
from .wrappers import Conv2d


"""
CNN building blocks.
"""


class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels and a stride.
    The input and output of the `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable.
        This method sets all parameters to `requires_grad=False`,
        and converts all BatchNorm layers to FrozenBatchNorm.

        Returns:
            the block itself
        """
        for p in self.parameters():
            p.requires_grad = False
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self


class DepthwiseSeparableConv2d(nn.Module):
    """
    A kxk depthwise convolution + a 1x1 convolution.

    In :paper:`xception`, norm & activation are applied on the second conv.
    :paper:`mobilenet` uses norm & activation on both convs.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        *,
        norm1=None,
        activation1=None,
        norm2=None,
        activation2=None,
    ):
        """
        Args:
            norm1, norm2 (str or callable): normalization for the two conv layers.
            activation1, activation2 (callable(Tensor) -> Tensor): activation
                function for the two conv layers.
        """
        super().__init__()
        self.depthwise = Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=not norm1,
            norm=get_norm(norm1, in_channels),
            activation=activation1,
        )
        self.pointwise = Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=not norm2,
            norm=get_norm(norm2, out_channels),
            activation=activation2,
        )

        # default initialization
        weight_init.c2_msra_fill(self.depthwise)
        weight_init.c2_msra_fill(self.pointwise)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))
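As a rough illustration of why `DepthwiseSeparableConv2d` is used, the sketch below compares its parameter count with a dense 3x3 convolution; the channel sizes are assumptions for demonstration, not values from this commit:

```python
# Illustrative comparison: dense 3x3 conv vs. the depthwise-separable block above.
import torch
from torch import nn
from annotator.oneformer.detectron2.layers.blocks import DepthwiseSeparableConv2d

dense = nn.Conv2d(128, 256, kernel_size=3, padding=1)
separable = DepthwiseSeparableConv2d(128, 256, kernel_size=3, padding=1, norm1="BN", norm2="BN")

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(dense))      # 295168 (128*256*9 weights + 256 biases)
print(count(separable))  # 34688  (128*9 depthwise + 128*256 pointwise + BN affine params)

x = torch.randn(1, 128, 56, 56)
print(separable(x).shape)  # torch.Size([1, 256, 56, 56])
```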
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/README.md
ADDED
@@ -0,0 +1,7 @@
To add a new Op:

1. Create a new directory
2. Implement new ops there
3. Declare its Python interface in `vision.cpp`.
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h
ADDED
@@ -0,0 +1,115 @@
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <torch/types.h>

namespace detectron2 {

at::Tensor ROIAlignRotated_forward_cpu(
    const at::Tensor& input,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio);

at::Tensor ROIAlignRotated_backward_cpu(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int batch_size,
    const int channels,
    const int height,
    const int width,
    const int sampling_ratio);

#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor ROIAlignRotated_forward_cuda(
    const at::Tensor& input,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio);

at::Tensor ROIAlignRotated_backward_cuda(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int batch_size,
    const int channels,
    const int height,
    const int width,
    const int sampling_ratio);
#endif

// Interface for Python
inline at::Tensor ROIAlignRotated_forward(
    const at::Tensor& input,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio) {
  if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
    return ROIAlignRotated_forward_cuda(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio);
#else
    AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
  }
  return ROIAlignRotated_forward_cpu(
      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}

inline at::Tensor ROIAlignRotated_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t batch_size,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t sampling_ratio) {
  if (grad.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
    return ROIAlignRotated_backward_cuda(
        grad,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        batch_size,
        channels,
        height,
        width,
        sampling_ratio);
#else
    AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
  }
  return ROIAlignRotated_backward_cpu(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio);
}

} // namespace detectron2
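The CPU and CUDA kernels in the files below sample each pooled bin on a grid of points that are rotated about the box center before bilinear interpolation. A minimal Python rendering of that coordinate transform (values are illustrative assumptions):

```python
# Sketch of the `y = yy*cos - xx*sin + center_h` / `x = yy*sin + xx*cos + center_w`
# step used by the ROIAlignRotated kernels below; purely illustrative.
import math

def rotate_point(yy, xx, theta_deg, center_y, center_x):
    """Map a point given relative to the box center into feature-map coordinates."""
    theta = math.radians(theta_deg)
    y = yy * math.cos(theta) - xx * math.sin(theta) + center_y
    x = yy * math.sin(theta) + xx * math.cos(theta) + center_x
    return y, x

# A point half a box-width to the right of the center of a box rotated by 90
# degrees ends up offset vertically instead of horizontally:
print(rotate_point(0.0, 5.0, 90.0, 20.0, 30.0))  # approximately (15.0, 30.0)
```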
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp
ADDED
@@ -0,0 +1,522 @@
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#include <ATen/TensorUtils.h>
|
3 |
+
#include "ROIAlignRotated.h"
|
4 |
+
|
5 |
+
// Note: this implementation originates from the Caffe2 ROIAlignRotated Op
|
6 |
+
// and PyTorch ROIAlign (non-rotated) Op implementations.
|
7 |
+
// The key difference between this implementation and those ones is
|
8 |
+
// we don't do "legacy offset" in this version, as there aren't many previous
|
9 |
+
// works, if any, using the "legacy" ROIAlignRotated Op.
|
10 |
+
// This would make the interface a bit cleaner.
|
11 |
+
|
12 |
+
namespace detectron2 {
|
13 |
+
|
14 |
+
namespace {
|
15 |
+
template <typename T>
|
16 |
+
struct PreCalc {
|
17 |
+
int pos1;
|
18 |
+
int pos2;
|
19 |
+
int pos3;
|
20 |
+
int pos4;
|
21 |
+
T w1;
|
22 |
+
T w2;
|
23 |
+
T w3;
|
24 |
+
T w4;
|
25 |
+
};
|
26 |
+
|
27 |
+
template <typename T>
|
28 |
+
void pre_calc_for_bilinear_interpolate(
|
29 |
+
const int height,
|
30 |
+
const int width,
|
31 |
+
const int pooled_height,
|
32 |
+
const int pooled_width,
|
33 |
+
const int iy_upper,
|
34 |
+
const int ix_upper,
|
35 |
+
T roi_start_h,
|
36 |
+
T roi_start_w,
|
37 |
+
T bin_size_h,
|
38 |
+
T bin_size_w,
|
39 |
+
int roi_bin_grid_h,
|
40 |
+
int roi_bin_grid_w,
|
41 |
+
T roi_center_h,
|
42 |
+
T roi_center_w,
|
43 |
+
T cos_theta,
|
44 |
+
T sin_theta,
|
45 |
+
std::vector<PreCalc<T>>& pre_calc) {
|
46 |
+
int pre_calc_index = 0;
|
47 |
+
for (int ph = 0; ph < pooled_height; ph++) {
|
48 |
+
for (int pw = 0; pw < pooled_width; pw++) {
|
49 |
+
for (int iy = 0; iy < iy_upper; iy++) {
|
50 |
+
const T yy = roi_start_h + ph * bin_size_h +
|
51 |
+
static_cast<T>(iy + .5f) * bin_size_h /
|
52 |
+
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
53 |
+
for (int ix = 0; ix < ix_upper; ix++) {
|
54 |
+
const T xx = roi_start_w + pw * bin_size_w +
|
55 |
+
static_cast<T>(ix + .5f) * bin_size_w /
|
56 |
+
static_cast<T>(roi_bin_grid_w);
|
57 |
+
|
58 |
+
// Rotate by theta around the center and translate
|
59 |
+
// In image space, (y, x) is the order for Right Handed System,
|
60 |
+
// and this is essentially multiplying the point by a rotation matrix
|
61 |
+
// to rotate it counterclockwise through angle theta.
|
62 |
+
T y = yy * cos_theta - xx * sin_theta + roi_center_h;
|
63 |
+
T x = yy * sin_theta + xx * cos_theta + roi_center_w;
|
64 |
+
// deal with: inverse elements are out of feature map boundary
|
65 |
+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
66 |
+
// empty
|
67 |
+
PreCalc<T> pc;
|
68 |
+
pc.pos1 = 0;
|
69 |
+
pc.pos2 = 0;
|
70 |
+
pc.pos3 = 0;
|
71 |
+
pc.pos4 = 0;
|
72 |
+
pc.w1 = 0;
|
73 |
+
pc.w2 = 0;
|
74 |
+
pc.w3 = 0;
|
75 |
+
pc.w4 = 0;
|
76 |
+
pre_calc[pre_calc_index] = pc;
|
77 |
+
pre_calc_index += 1;
|
78 |
+
continue;
|
79 |
+
}
|
80 |
+
|
81 |
+
if (y < 0) {
|
82 |
+
y = 0;
|
83 |
+
}
|
84 |
+
if (x < 0) {
|
85 |
+
x = 0;
|
86 |
+
}
|
87 |
+
|
88 |
+
int y_low = (int)y;
|
89 |
+
int x_low = (int)x;
|
90 |
+
int y_high;
|
91 |
+
int x_high;
|
92 |
+
|
93 |
+
if (y_low >= height - 1) {
|
94 |
+
y_high = y_low = height - 1;
|
95 |
+
y = (T)y_low;
|
96 |
+
} else {
|
97 |
+
y_high = y_low + 1;
|
98 |
+
}
|
99 |
+
|
100 |
+
if (x_low >= width - 1) {
|
101 |
+
x_high = x_low = width - 1;
|
102 |
+
x = (T)x_low;
|
103 |
+
} else {
|
104 |
+
x_high = x_low + 1;
|
105 |
+
}
|
106 |
+
|
107 |
+
T ly = y - y_low;
|
108 |
+
T lx = x - x_low;
|
109 |
+
T hy = 1. - ly, hx = 1. - lx;
|
110 |
+
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
111 |
+
|
112 |
+
// save weights and indices
|
113 |
+
PreCalc<T> pc;
|
114 |
+
pc.pos1 = y_low * width + x_low;
|
115 |
+
pc.pos2 = y_low * width + x_high;
|
116 |
+
pc.pos3 = y_high * width + x_low;
|
117 |
+
pc.pos4 = y_high * width + x_high;
|
118 |
+
pc.w1 = w1;
|
119 |
+
pc.w2 = w2;
|
120 |
+
pc.w3 = w3;
|
121 |
+
pc.w4 = w4;
|
122 |
+
pre_calc[pre_calc_index] = pc;
|
123 |
+
|
124 |
+
pre_calc_index += 1;
|
125 |
+
}
|
126 |
+
}
|
127 |
+
}
|
128 |
+
}
|
129 |
+
}
|
130 |
+
|
131 |
+
template <typename T>
|
132 |
+
void bilinear_interpolate_gradient(
|
133 |
+
const int height,
|
134 |
+
const int width,
|
135 |
+
T y,
|
136 |
+
T x,
|
137 |
+
T& w1,
|
138 |
+
T& w2,
|
139 |
+
T& w3,
|
140 |
+
T& w4,
|
141 |
+
int& x_low,
|
142 |
+
int& x_high,
|
143 |
+
int& y_low,
|
144 |
+
int& y_high) {
|
145 |
+
// deal with cases that inverse elements are out of feature map boundary
|
146 |
+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
147 |
+
// empty
|
148 |
+
w1 = w2 = w3 = w4 = 0.;
|
149 |
+
x_low = x_high = y_low = y_high = -1;
|
150 |
+
return;
|
151 |
+
}
|
152 |
+
|
153 |
+
if (y < 0) {
|
154 |
+
y = 0;
|
155 |
+
}
|
156 |
+
|
157 |
+
if (x < 0) {
|
158 |
+
x = 0;
|
159 |
+
}
|
160 |
+
|
161 |
+
y_low = (int)y;
|
162 |
+
x_low = (int)x;
|
163 |
+
|
164 |
+
if (y_low >= height - 1) {
|
165 |
+
y_high = y_low = height - 1;
|
166 |
+
y = (T)y_low;
|
167 |
+
} else {
|
168 |
+
y_high = y_low + 1;
|
169 |
+
}
|
170 |
+
|
171 |
+
if (x_low >= width - 1) {
|
172 |
+
x_high = x_low = width - 1;
|
173 |
+
x = (T)x_low;
|
174 |
+
} else {
|
175 |
+
x_high = x_low + 1;
|
176 |
+
}
|
177 |
+
|
178 |
+
T ly = y - y_low;
|
179 |
+
T lx = x - x_low;
|
180 |
+
T hy = 1. - ly, hx = 1. - lx;
|
181 |
+
|
182 |
+
// reference in forward
|
183 |
+
// T v1 = input[y_low * width + x_low];
|
184 |
+
// T v2 = input[y_low * width + x_high];
|
185 |
+
// T v3 = input[y_high * width + x_low];
|
186 |
+
// T v4 = input[y_high * width + x_high];
|
187 |
+
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
188 |
+
|
189 |
+
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
190 |
+
|
191 |
+
return;
|
192 |
+
}
|
193 |
+
|
194 |
+
template <class T>
|
195 |
+
inline void add(T* address, const T& val) {
|
196 |
+
*address += val;
|
197 |
+
}
|
198 |
+
|
199 |
+
} // namespace
|
200 |
+
|
201 |
+
template <typename T>
|
202 |
+
void ROIAlignRotatedForward(
|
203 |
+
const int nthreads,
|
204 |
+
const T* input,
|
205 |
+
const T& spatial_scale,
|
206 |
+
const int channels,
|
207 |
+
const int height,
|
208 |
+
const int width,
|
209 |
+
const int pooled_height,
|
210 |
+
const int pooled_width,
|
211 |
+
const int sampling_ratio,
|
212 |
+
const T* rois,
|
213 |
+
T* output) {
|
214 |
+
int n_rois = nthreads / channels / pooled_width / pooled_height;
|
215 |
+
// (n, c, ph, pw) is an element in the pooled output
|
216 |
+
// can be parallelized using omp
|
217 |
+
// #pragma omp parallel for num_threads(32)
|
218 |
+
for (int n = 0; n < n_rois; n++) {
|
219 |
+
int index_n = n * channels * pooled_width * pooled_height;
|
220 |
+
|
221 |
+
const T* current_roi = rois + n * 6;
|
222 |
+
int roi_batch_ind = current_roi[0];
|
223 |
+
|
224 |
+
// Do not use rounding; this implementation detail is critical
|
225 |
+
// ROIAlignRotated supports align == true, i.e., continuous coordinate
|
226 |
+
// by default, thus the 0.5 offset
|
227 |
+
T offset = (T)0.5;
|
228 |
+
T roi_center_w = current_roi[1] * spatial_scale - offset;
|
229 |
+
T roi_center_h = current_roi[2] * spatial_scale - offset;
|
230 |
+
T roi_width = current_roi[3] * spatial_scale;
|
231 |
+
T roi_height = current_roi[4] * spatial_scale;
|
232 |
+
T theta = current_roi[5] * M_PI / 180.0;
|
233 |
+
T cos_theta = cos(theta);
|
234 |
+
T sin_theta = sin(theta);
|
235 |
+
|
236 |
+
AT_ASSERTM(
|
237 |
+
roi_width >= 0 && roi_height >= 0,
|
238 |
+
"ROIs in ROIAlignRotated do not have non-negative size!");
|
239 |
+
|
240 |
+
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
241 |
+
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
242 |
+
|
243 |
+
// We use roi_bin_grid to sample the grid and mimic integral
|
244 |
+
int roi_bin_grid_h = (sampling_ratio > 0)
|
245 |
+
? sampling_ratio
|
246 |
+
: ceil(roi_height / pooled_height); // e.g., = 2
|
247 |
+
int roi_bin_grid_w =
|
248 |
+
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
|
249 |
+
|
250 |
+
// We do average (integral) pooling inside a bin
|
251 |
+
const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
|
252 |
+
|
253 |
+
// we want to precalculate indices and weights shared by all channels,
|
254 |
+
// this is the key point of optimization
|
255 |
+
std::vector<PreCalc<T>> pre_calc(
|
256 |
+
roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
|
257 |
+
|
258 |
+
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
|
259 |
+
// Appropriate translation needs to be applied after.
|
260 |
+
T roi_start_h = -roi_height / 2.0;
|
261 |
+
T roi_start_w = -roi_width / 2.0;
|
262 |
+
|
263 |
+
pre_calc_for_bilinear_interpolate(
|
264 |
+
height,
|
265 |
+
width,
|
266 |
+
pooled_height,
|
267 |
+
pooled_width,
|
268 |
+
roi_bin_grid_h,
|
269 |
+
roi_bin_grid_w,
|
270 |
+
roi_start_h,
|
271 |
+
roi_start_w,
|
272 |
+
bin_size_h,
|
273 |
+
bin_size_w,
|
274 |
+
roi_bin_grid_h,
|
275 |
+
roi_bin_grid_w,
|
276 |
+
roi_center_h,
|
277 |
+
roi_center_w,
|
278 |
+
cos_theta,
|
279 |
+
sin_theta,
|
280 |
+
pre_calc);
|
281 |
+
|
282 |
+
for (int c = 0; c < channels; c++) {
|
283 |
+
int index_n_c = index_n + c * pooled_width * pooled_height;
|
284 |
+
const T* offset_input =
|
285 |
+
input + (roi_batch_ind * channels + c) * height * width;
|
286 |
+
int pre_calc_index = 0;
|
287 |
+
|
288 |
+
for (int ph = 0; ph < pooled_height; ph++) {
|
289 |
+
for (int pw = 0; pw < pooled_width; pw++) {
|
290 |
+
int index = index_n_c + ph * pooled_width + pw;
|
291 |
+
|
292 |
+
T output_val = 0.;
|
293 |
+
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
|
294 |
+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
|
295 |
+
PreCalc<T> pc = pre_calc[pre_calc_index];
|
296 |
+
output_val += pc.w1 * offset_input[pc.pos1] +
|
297 |
+
pc.w2 * offset_input[pc.pos2] +
|
298 |
+
pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
|
299 |
+
|
300 |
+
pre_calc_index += 1;
|
301 |
+
}
|
302 |
+
}
|
303 |
+
output_val /= count;
|
304 |
+
|
305 |
+
output[index] = output_val;
|
306 |
+
} // for pw
|
307 |
+
} // for ph
|
308 |
+
} // for c
|
309 |
+
} // for n
|
310 |
+
}
|
311 |
+
|
312 |
+
template <typename T>
|
313 |
+
void ROIAlignRotatedBackward(
|
314 |
+
const int nthreads,
|
315 |
+
// may not be contiguous. should index using n_stride, etc
|
316 |
+
const T* grad_output,
|
317 |
+
const T& spatial_scale,
|
318 |
+
const int channels,
|
319 |
+
const int height,
|
320 |
+
const int width,
|
321 |
+
const int pooled_height,
|
322 |
+
const int pooled_width,
|
323 |
+
const int sampling_ratio,
|
324 |
+
T* grad_input,
|
325 |
+
const T* rois,
|
326 |
+
const int n_stride,
|
327 |
+
const int c_stride,
|
328 |
+
const int h_stride,
|
329 |
+
const int w_stride) {
|
330 |
+
for (int index = 0; index < nthreads; index++) {
|
331 |
+
// (n, c, ph, pw) is an element in the pooled output
|
332 |
+
int pw = index % pooled_width;
|
333 |
+
int ph = (index / pooled_width) % pooled_height;
|
334 |
+
int c = (index / pooled_width / pooled_height) % channels;
|
335 |
+
int n = index / pooled_width / pooled_height / channels;
|
336 |
+
|
337 |
+
const T* current_roi = rois + n * 6;
|
338 |
+
int roi_batch_ind = current_roi[0];
|
339 |
+
|
340 |
+
// Do not use rounding; this implementation detail is critical
|
341 |
+
// ROIAlignRotated supports align == true, i.e., continuous coordinate
|
342 |
+
// by default, thus the 0.5 offset
|
343 |
+
T offset = (T)0.5;
|
344 |
+
T roi_center_w = current_roi[1] * spatial_scale - offset;
|
345 |
+
T roi_center_h = current_roi[2] * spatial_scale - offset;
|
346 |
+
T roi_width = current_roi[3] * spatial_scale;
|
347 |
+
T roi_height = current_roi[4] * spatial_scale;
|
348 |
+
T theta = current_roi[5] * M_PI / 180.0;
|
349 |
+
T cos_theta = cos(theta);
|
350 |
+
T sin_theta = sin(theta);
|
351 |
+
|
352 |
+
AT_ASSERTM(
|
353 |
+
roi_width >= 0 && roi_height >= 0,
|
354 |
+
"ROIs in ROIAlignRotated do not have non-negative size!");
|
355 |
+
|
356 |
+
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
357 |
+
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
358 |
+
|
359 |
+
T* offset_grad_input =
|
360 |
+
grad_input + ((roi_batch_ind * channels + c) * height * width);
|
361 |
+
|
362 |
+
int output_offset = n * n_stride + c * c_stride;
|
363 |
+
const T* offset_grad_output = grad_output + output_offset;
|
364 |
+
const T grad_output_this_bin =
|
365 |
+
offset_grad_output[ph * h_stride + pw * w_stride];
|
366 |
+
|
367 |
+
// We use roi_bin_grid to sample the grid and mimic integral
|
368 |
+
int roi_bin_grid_h = (sampling_ratio > 0)
|
369 |
+
? sampling_ratio
|
370 |
+
: ceil(roi_height / pooled_height); // e.g., = 2
|
371 |
+
int roi_bin_grid_w =
|
372 |
+
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
|
373 |
+
|
374 |
+
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
|
375 |
+
// Appropriate translation needs to be applied after.
|
376 |
+
T roi_start_h = -roi_height / 2.0;
|
377 |
+
T roi_start_w = -roi_width / 2.0;
|
378 |
+
|
379 |
+
// We do average (integral) pooling inside a bin
|
380 |
+
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
381 |
+
|
382 |
+
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
|
383 |
+
const T yy = roi_start_h + ph * bin_size_h +
|
384 |
+
static_cast<T>(iy + .5f) * bin_size_h /
|
385 |
+
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
386 |
+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
|
387 |
+
const T xx = roi_start_w + pw * bin_size_w +
|
388 |
+
static_cast<T>(ix + .5f) * bin_size_w /
|
389 |
+
static_cast<T>(roi_bin_grid_w);
|
390 |
+
|
391 |
+
// Rotate by theta around the center and translate
|
392 |
+
T y = yy * cos_theta - xx * sin_theta + roi_center_h;
|
393 |
+
T x = yy * sin_theta + xx * cos_theta + roi_center_w;
|
394 |
+
|
395 |
+
T w1, w2, w3, w4;
|
396 |
+
int x_low, x_high, y_low, y_high;
|
397 |
+
|
398 |
+
bilinear_interpolate_gradient(
|
399 |
+
height, width, y, x, w1, w2, w3, w4, x_low, x_high, y_low, y_high);
|
400 |
+
|
401 |
+
T g1 = grad_output_this_bin * w1 / count;
|
402 |
+
T g2 = grad_output_this_bin * w2 / count;
|
403 |
+
T g3 = grad_output_this_bin * w3 / count;
|
404 |
+
T g4 = grad_output_this_bin * w4 / count;
|
405 |
+
|
406 |
+
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
|
407 |
+
// atomic add is not needed for now since it is single threaded
|
408 |
+
add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
|
409 |
+
add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
|
410 |
+
add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
|
411 |
+
add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
|
412 |
+
} // if
|
413 |
+
} // ix
|
414 |
+
} // iy
|
415 |
+
} // for
|
416 |
+
} // ROIAlignRotatedBackward
|
417 |
+
|
418 |
+
at::Tensor ROIAlignRotated_forward_cpu(
|
419 |
+
const at::Tensor& input,
|
420 |
+
const at::Tensor& rois,
|
421 |
+
const float spatial_scale,
|
422 |
+
const int pooled_height,
|
423 |
+
const int pooled_width,
|
424 |
+
const int sampling_ratio) {
|
425 |
+
AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
|
426 |
+
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
|
427 |
+
|
428 |
+
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
|
429 |
+
|
430 |
+
at::CheckedFrom c = "ROIAlign_forward_cpu";
|
431 |
+
at::checkAllSameType(c, {input_t, rois_t});
|
432 |
+
|
433 |
+
auto num_rois = rois.size(0);
|
434 |
+
auto channels = input.size(1);
|
435 |
+
auto height = input.size(2);
|
436 |
+
auto width = input.size(3);
|
437 |
+
|
438 |
+
at::Tensor output = at::zeros(
|
439 |
+
{num_rois, channels, pooled_height, pooled_width}, input.options());
|
440 |
+
|
441 |
+
auto output_size = num_rois * pooled_height * pooled_width * channels;
|
442 |
+
|
443 |
+
if (output.numel() == 0) {
|
444 |
+
return output;
|
445 |
+
}
|
446 |
+
|
447 |
+
auto input_ = input.contiguous(), rois_ = rois.contiguous();
|
448 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
449 |
+
input.scalar_type(), "ROIAlignRotated_forward", [&] {
|
450 |
+
ROIAlignRotatedForward<scalar_t>(
|
451 |
+
output_size,
|
452 |
+
input_.data_ptr<scalar_t>(),
|
453 |
+
spatial_scale,
|
454 |
+
channels,
|
455 |
+
height,
|
456 |
+
width,
|
457 |
+
pooled_height,
|
458 |
+
pooled_width,
|
459 |
+
sampling_ratio,
|
460 |
+
rois_.data_ptr<scalar_t>(),
|
461 |
+
output.data_ptr<scalar_t>());
|
462 |
+
});
|
463 |
+
return output;
|
464 |
+
}
|
465 |
+
|
466 |
+
at::Tensor ROIAlignRotated_backward_cpu(
|
467 |
+
const at::Tensor& grad,
|
468 |
+
const at::Tensor& rois,
|
469 |
+
const float spatial_scale,
|
470 |
+
const int pooled_height,
|
471 |
+
const int pooled_width,
|
472 |
+
const int batch_size,
|
473 |
+
const int channels,
|
474 |
+
const int height,
|
475 |
+
const int width,
|
476 |
+
const int sampling_ratio) {
|
477 |
+
AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
|
478 |
+
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
|
479 |
+
|
480 |
+
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
|
481 |
+
|
482 |
+
at::CheckedFrom c = "ROIAlignRotated_backward_cpu";
|
483 |
+
at::checkAllSameType(c, {grad_t, rois_t});
|
484 |
+
|
485 |
+
at::Tensor grad_input =
|
486 |
+
at::zeros({batch_size, channels, height, width}, grad.options());
|
487 |
+
|
488 |
+
// handle possibly empty gradients
|
489 |
+
if (grad.numel() == 0) {
|
490 |
+
return grad_input;
|
491 |
+
}
|
492 |
+
|
493 |
+
// get stride values to ensure indexing into gradients is correct.
|
494 |
+
int n_stride = grad.stride(0);
|
495 |
+
int c_stride = grad.stride(1);
|
496 |
+
int h_stride = grad.stride(2);
|
497 |
+
int w_stride = grad.stride(3);
|
498 |
+
|
499 |
+
auto rois_ = rois.contiguous();
|
500 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
501 |
+
grad.scalar_type(), "ROIAlignRotated_forward", [&] {
|
502 |
+
ROIAlignRotatedBackward<scalar_t>(
|
503 |
+
grad.numel(),
|
504 |
+
grad.data_ptr<scalar_t>(),
|
505 |
+
spatial_scale,
|
506 |
+
channels,
|
507 |
+
height,
|
508 |
+
width,
|
509 |
+
pooled_height,
|
510 |
+
pooled_width,
|
511 |
+
sampling_ratio,
|
512 |
+
grad_input.data_ptr<scalar_t>(),
|
513 |
+
rois_.data_ptr<scalar_t>(),
|
514 |
+
n_stride,
|
515 |
+
c_stride,
|
516 |
+
h_stride,
|
517 |
+
w_stride);
|
518 |
+
});
|
519 |
+
return grad_input;
|
520 |
+
}
|
521 |
+
|
522 |
+
} // namespace detectron2
|
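For reference, here is a small Python sketch of the bilinear-interpolation neighbours and weights that `pre_calc_for_bilinear_interpolate` and `bilinear_interpolate_gradient` compute for each sampling point, mirroring the clamping rules in the C++ above (illustrative only):

```python
# Illustrative re-implementation of the per-point bilinear weights used above.
def bilinear_weights(y, x, height, width):
    """Return [(row, col, weight)] for one sampling point, as in the C++ above."""
    if y < -1.0 or y > height or x < -1.0 or x > width:
        return []  # out of bounds: the kernel stores four zero weights
    y = max(y, 0.0)
    x = max(x, 0.0)
    y_low, x_low = int(y), int(x)
    if y_low >= height - 1:
        y_high = y_low = height - 1
        y = float(y_low)
    else:
        y_high = y_low + 1
    if x_low >= width - 1:
        x_high = x_low = width - 1
        x = float(x_low)
    else:
        x_high = x_low + 1
    ly, lx = y - y_low, x - x_low
    hy, hx = 1.0 - ly, 1.0 - lx
    # Same order as the C++ code: (low,low), (low,high), (high,low), (high,high)
    return [(y_low, x_low, hy * hx), (y_low, x_high, hy * lx),
            (y_high, x_low, ly * hx), (y_high, x_high, ly * lx)]

print(bilinear_weights(2.25, 3.5, height=8, width=8))
# [(2, 3, 0.375), (2, 4, 0.375), (3, 3, 0.125), (3, 4, 0.125)]
```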
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu
ADDED
@@ -0,0 +1,443 @@
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#include <ATen/ATen.h>
|
3 |
+
#include <ATen/cuda/CUDAContext.h>
|
4 |
+
#include <c10/cuda/CUDAGuard.h>
|
5 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
6 |
+
|
7 |
+
// TODO make it in a common file
|
8 |
+
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
9 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
|
10 |
+
i += blockDim.x * gridDim.x)
|
11 |
+
|
12 |
+
// Note: this implementation originates from the Caffe2 ROIAlignRotated Op
|
13 |
+
// and PyTorch ROIAlign (non-rotated) Op implementations.
|
14 |
+
// The key difference between this implementation and those ones is
|
15 |
+
// we don't do "legacy offset" in this version, as there aren't many previous
|
16 |
+
// works, if any, using the "legacy" ROIAlignRotated Op.
|
17 |
+
// This would make the interface a bit cleaner.
|
18 |
+
|
19 |
+
namespace detectron2 {
|
20 |
+
|
21 |
+
namespace {
|
22 |
+
|
23 |
+
template <typename T>
|
24 |
+
__device__ T bilinear_interpolate(
|
25 |
+
const T* input,
|
26 |
+
const int height,
|
27 |
+
const int width,
|
28 |
+
T y,
|
29 |
+
T x) {
|
30 |
+
// deal with cases that inverse elements are out of feature map boundary
|
31 |
+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
32 |
+
// empty
|
33 |
+
return 0;
|
34 |
+
}
|
35 |
+
|
36 |
+
if (y < 0) {
|
37 |
+
y = 0;
|
38 |
+
}
|
39 |
+
|
40 |
+
if (x < 0) {
|
41 |
+
x = 0;
|
42 |
+
}
|
43 |
+
|
44 |
+
int y_low = (int)y;
|
45 |
+
int x_low = (int)x;
|
46 |
+
int y_high;
|
47 |
+
int x_high;
|
48 |
+
|
49 |
+
if (y_low >= height - 1) {
|
50 |
+
y_high = y_low = height - 1;
|
51 |
+
y = (T)y_low;
|
52 |
+
} else {
|
53 |
+
y_high = y_low + 1;
|
54 |
+
}
|
55 |
+
|
56 |
+
if (x_low >= width - 1) {
|
57 |
+
x_high = x_low = width - 1;
|
58 |
+
x = (T)x_low;
|
59 |
+
} else {
|
60 |
+
x_high = x_low + 1;
|
61 |
+
}
|
62 |
+
|
63 |
+
T ly = y - y_low;
|
64 |
+
T lx = x - x_low;
|
65 |
+
T hy = 1. - ly, hx = 1. - lx;
|
66 |
+
// do bilinear interpolation
|
67 |
+
T v1 = input[y_low * width + x_low];
|
68 |
+
T v2 = input[y_low * width + x_high];
|
69 |
+
T v3 = input[y_high * width + x_low];
|
70 |
+
T v4 = input[y_high * width + x_high];
|
71 |
+
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
72 |
+
|
73 |
+
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
74 |
+
|
75 |
+
return val;
|
76 |
+
}
|
77 |
+
|
78 |
+
template <typename T>
|
79 |
+
__device__ void bilinear_interpolate_gradient(
|
80 |
+
const int height,
|
81 |
+
const int width,
|
82 |
+
T y,
|
83 |
+
T x,
|
84 |
+
T& w1,
|
85 |
+
T& w2,
|
86 |
+
T& w3,
|
87 |
+
T& w4,
|
88 |
+
int& x_low,
|
89 |
+
int& x_high,
|
90 |
+
int& y_low,
|
91 |
+
int& y_high) {
|
92 |
+
// deal with cases that inverse elements are out of feature map boundary
|
93 |
+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
94 |
+
// empty
|
95 |
+
w1 = w2 = w3 = w4 = 0.;
|
96 |
+
x_low = x_high = y_low = y_high = -1;
|
97 |
+
return;
|
98 |
+
}
|
99 |
+
|
100 |
+
if (y < 0) {
|
101 |
+
y = 0;
|
102 |
+
}
|
103 |
+
|
104 |
+
if (x < 0) {
|
105 |
+
x = 0;
|
106 |
+
}
|
107 |
+
|
108 |
+
y_low = (int)y;
|
109 |
+
x_low = (int)x;
|
110 |
+
|
111 |
+
if (y_low >= height - 1) {
|
112 |
+
y_high = y_low = height - 1;
|
113 |
+
y = (T)y_low;
|
114 |
+
} else {
|
115 |
+
y_high = y_low + 1;
|
116 |
+
}
|
117 |
+
|
118 |
+
if (x_low >= width - 1) {
|
119 |
+
x_high = x_low = width - 1;
|
120 |
+
x = (T)x_low;
|
121 |
+
} else {
|
122 |
+
x_high = x_low + 1;
|
123 |
+
}
|
124 |
+
|
125 |
+
T ly = y - y_low;
|
126 |
+
T lx = x - x_low;
|
127 |
+
T hy = 1. - ly, hx = 1. - lx;
|
128 |
+
|
129 |
+
// reference in forward
|
130 |
+
// T v1 = input[y_low * width + x_low];
|
131 |
+
// T v2 = input[y_low * width + x_high];
|
132 |
+
// T v3 = input[y_high * width + x_low];
|
133 |
+
// T v4 = input[y_high * width + x_high];
|
134 |
+
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
135 |
+
|
136 |
+
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
137 |
+
|
138 |
+
return;
|
139 |
+
}
|
140 |
+
|
141 |
+
} // namespace
|
142 |
+
|
143 |
+
template <typename T>
|
144 |
+
__global__ void RoIAlignRotatedForward(
|
145 |
+
const int nthreads,
|
146 |
+
const T* input,
|
147 |
+
const T spatial_scale,
|
148 |
+
const int channels,
|
149 |
+
const int height,
|
150 |
+
const int width,
|
151 |
+
const int pooled_height,
|
152 |
+
const int pooled_width,
|
153 |
+
const int sampling_ratio,
|
154 |
+
const T* rois,
|
155 |
+
T* top_data) {
|
156 |
+
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
157 |
+
// (n, c, ph, pw) is an element in the pooled output
|
158 |
+
int pw = index % pooled_width;
|
159 |
+
int ph = (index / pooled_width) % pooled_height;
|
160 |
+
int c = (index / pooled_width / pooled_height) % channels;
|
161 |
+
int n = index / pooled_width / pooled_height / channels;
|
162 |
+
|
163 |
+
const T* current_roi = rois + n * 6;
|
164 |
+
int roi_batch_ind = current_roi[0];
|
165 |
+
|
166 |
+
// Do not use rounding; this implementation detail is critical
|
167 |
+
// ROIAlignRotated supports align == true, i.e., continuous coordinate
|
168 |
+
// by default, thus the 0.5 offset
|
169 |
+
T offset = (T)0.5;
|
170 |
+
T roi_center_w = current_roi[1] * spatial_scale - offset;
|
171 |
+
T roi_center_h = current_roi[2] * spatial_scale - offset;
|
172 |
+
T roi_width = current_roi[3] * spatial_scale;
|
173 |
+
T roi_height = current_roi[4] * spatial_scale;
|
174 |
+
T theta = current_roi[5] * M_PI / 180.0;
|
175 |
+
T cos_theta = cos(theta);
|
176 |
+
T sin_theta = sin(theta);
|
177 |
+
|
178 |
+
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
179 |
+
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
180 |
+
|
181 |
+
const T* offset_input =
|
182 |
+
input + (roi_batch_ind * channels + c) * height * width;
|
183 |
+
|
184 |
+
// We use roi_bin_grid to sample the grid and mimic integral
|
185 |
+
int roi_bin_grid_h = (sampling_ratio > 0)
|
186 |
+
? sampling_ratio
|
187 |
+
: ceil(roi_height / pooled_height); // e.g., = 2
|
188 |
+
int roi_bin_grid_w =
|
189 |
+
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
|
190 |
+
|
191 |
+
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
|
192 |
+
// Appropriate translation needs to be applied after.
|
193 |
+
T roi_start_h = -roi_height / 2.0;
|
194 |
+
T roi_start_w = -roi_width / 2.0;
|
195 |
+
|
196 |
+
// We do average (integral) pooling inside a bin
|
197 |
+
const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
|
198 |
+
|
199 |
+
T output_val = 0.;
|
200 |
+
for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
|
201 |
+
{
|
202 |
+
const T yy = roi_start_h + ph * bin_size_h +
|
203 |
+
static_cast<T>(iy + .5f) * bin_size_h /
|
204 |
+
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
205 |
+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
|
206 |
+
const T xx = roi_start_w + pw * bin_size_w +
|
207 |
+
static_cast<T>(ix + .5f) * bin_size_w /
|
208 |
+
static_cast<T>(roi_bin_grid_w);
|
209 |
+
|
210 |
+
// Rotate by theta around the center and translate
|
211 |
+
T y = yy * cos_theta - xx * sin_theta + roi_center_h;
|
212 |
+
T x = yy * sin_theta + xx * cos_theta + roi_center_w;
|
213 |
+
|
214 |
+
T val = bilinear_interpolate(offset_input, height, width, y, x);
|
215 |
+
output_val += val;
|
216 |
+
}
|
217 |
+
}
|
218 |
+
output_val /= count;
|
219 |
+
|
220 |
+
top_data[index] = output_val;
|
221 |
+
}
|
222 |
+
}
|
223 |
+
|
224 |
+
template <typename T>
|
225 |
+
__global__ void RoIAlignRotatedBackwardFeature(
|
226 |
+
const int nthreads,
|
227 |
+
const T* top_diff,
|
228 |
+
const int num_rois,
|
229 |
+
const T spatial_scale,
|
230 |
+
const int channels,
|
231 |
+
const int height,
|
232 |
+
const int width,
|
233 |
+
const int pooled_height,
|
234 |
+
const int pooled_width,
|
235 |
+
const int sampling_ratio,
|
236 |
+
T* bottom_diff,
|
237 |
+
const T* rois) {
|
238 |
+
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
239 |
+
// (n, c, ph, pw) is an element in the pooled output
|
240 |
+
int pw = index % pooled_width;
|
241 |
+
int ph = (index / pooled_width) % pooled_height;
|
242 |
+
int c = (index / pooled_width / pooled_height) % channels;
|
243 |
+
int n = index / pooled_width / pooled_height / channels;
|
244 |
+
|
245 |
+
const T* current_roi = rois + n * 6;
|
246 |
+
int roi_batch_ind = current_roi[0];
|
247 |
+
|
248 |
+
// Do not use rounding; this implementation detail is critical
|
249 |
+
// ROIAlignRotated supports align == true, i.e., continuous coordinate
|
250 |
+
// by default, thus the 0.5 offset
|
251 |
+
T offset = (T)0.5;
|
252 |
+
T roi_center_w = current_roi[1] * spatial_scale - offset;
|
253 |
+
T roi_center_h = current_roi[2] * spatial_scale - offset;
|
254 |
+
T roi_width = current_roi[3] * spatial_scale;
|
255 |
+
T roi_height = current_roi[4] * spatial_scale;
|
256 |
+
T theta = current_roi[5] * M_PI / 180.0;
|
257 |
+
T cos_theta = cos(theta);
|
258 |
+
T sin_theta = sin(theta);
|
259 |
+
|
260 |
+
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
261 |
+
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
262 |
+
|
263 |
+
T* offset_bottom_diff =
|
264 |
+
bottom_diff + (roi_batch_ind * channels + c) * height * width;
|
265 |
+
|
266 |
+
int top_offset = (n * channels + c) * pooled_height * pooled_width;
|
267 |
+
const T* offset_top_diff = top_diff + top_offset;
|
268 |
+
const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
|
269 |
+
|
270 |
+
// We use roi_bin_grid to sample the grid and mimic integral
|
271 |
+
int roi_bin_grid_h = (sampling_ratio > 0)
|
272 |
+
? sampling_ratio
|
273 |
+
        : ceil(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
    // Appropriate translation needs to be applied after.
    T roi_start_h = -roi_height / 2.0;
    T roi_start_w = -roi_width / 2.0;

    // We do average (integral) pooling inside a bin
    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

    for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
    {
      const T yy = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T xx = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        // Rotate by theta around the center and translate
        T y = yy * cos_theta - xx * sin_theta + roi_center_h;
        T x = yy * sin_theta + xx * cos_theta + roi_center_w;

        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient(
            height, width, y, x, w1, w2, w3, w4, x_low, x_high, y_low, y_high);

        T g1 = top_diff_this_bin * w1 / count;
        T g2 = top_diff_this_bin * w2 / count;
        T g3 = top_diff_this_bin * w3 / count;
        T g4 = top_diff_this_bin * w4 / count;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          atomicAdd(
              offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1));
          atomicAdd(
              offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2));
          atomicAdd(
              offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
          atomicAdd(
              offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
        } // if
      } // ix
    } // iy
  } // CUDA_1D_KERNEL_LOOP
} // RoIAlignRotatedBackward

at::Tensor ROIAlignRotated_forward_cuda(
    const at::Tensor& input,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio) {
  AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "ROIAlignRotated_forward_cuda";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});
  at::cuda::CUDAGuard device_guard(input.device());

  auto num_rois = rois.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);

  auto output = at::empty(
      {num_rois, channels, pooled_height, pooled_width}, input.options());
  auto output_size = num_rois * pooled_height * pooled_width * channels;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(
      at::cuda::ATenCeilDiv(
          static_cast<int64_t>(output_size), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  if (output.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return output;
  }

  auto input_ = input.contiguous(), rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "ROIAlignRotated_forward", [&] {
        RoIAlignRotatedForward<scalar_t><<<grid, block, 0, stream>>>(
            output_size,
            input_.data_ptr<scalar_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            sampling_ratio,
            rois_.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>());
      });
  cudaDeviceSynchronize();
  AT_CUDA_CHECK(cudaGetLastError());
  return output;
}

// TODO remove the dependency on input and use instead its sizes -> save memory
at::Tensor ROIAlignRotated_backward_cuda(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int batch_size,
    const int channels,
    const int height,
    const int width,
    const int sampling_ratio) {
  AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
  at::CheckedFrom c = "ROIAlign_backward_cuda";
  at::checkAllSameGPU(c, {grad_t, rois_t});
  at::checkAllSameType(c, {grad_t, rois_t});
  at::cuda::CUDAGuard device_guard(grad.device());

  auto num_rois = rois.size(0);
  auto grad_input =
      at::zeros({batch_size, channels, height, width}, grad.options());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(
      at::cuda::ATenCeilDiv(
          static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return grad_input;
  }

  auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES(
      grad.scalar_type(), "ROIAlignRotated_backward", [&] {
        RoIAlignRotatedBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
            grad.numel(),
            grad_.data_ptr<scalar_t>(),
            num_rois,
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            sampling_ratio,
            grad_input.data_ptr<scalar_t>(),
            rois_.data_ptr<scalar_t>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return grad_input;
}

} // namespace detectron2

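The backward kernel above scatters each sampled gradient onto the four pixels surrounding (y, x) using bilinear weights. A small standalone sketch, not part of the commit, of the standard weight convention those w1..w4 follow (coordinates are made up; the four weights always sum to 1):

// bilinear_weights_demo.cpp (hypothetical, host-only illustration)
#include <cmath>
#include <cstdio>

int main() {
  float y = 2.3f, x = 5.7f;
  int y_low = static_cast<int>(std::floor(y));
  int x_low = static_cast<int>(std::floor(x));
  float ly = y - y_low, lx = x - x_low; // fractional offsets
  float hy = 1.f - ly, hx = 1.f - lx;
  // w1 -> (y_low, x_low), w2 -> (y_low, x_high), w3 -> (y_high, x_low), w4 -> (y_high, x_high)
  float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
  std::printf("sum = %f\n", w1 + w2 + w3 + w4); // prints 1.000000
  return 0;
}
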
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
ADDED
@@ -0,0 +1,35 @@
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <torch/types.h>

namespace detectron2 {

at::Tensor box_iou_rotated_cpu(
    const at::Tensor& boxes1,
    const at::Tensor& boxes2);

#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor box_iou_rotated_cuda(
    const at::Tensor& boxes1,
    const at::Tensor& boxes2);
#endif

// Interface for Python
// inline is needed to prevent multiple function definitions when this header is
// included by different cpps
inline at::Tensor box_iou_rotated(
    const at::Tensor& boxes1,
    const at::Tensor& boxes2) {
  assert(boxes1.device().is_cuda() == boxes2.device().is_cuda());
  if (boxes1.device().is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
    return box_iou_rotated_cuda(boxes1.contiguous(), boxes2.contiguous());
#else
    AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
  }

  return box_iou_rotated_cpu(boxes1.contiguous(), boxes2.contiguous());
}

} // namespace detectron2

extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
ADDED
@@ -0,0 +1,39 @@
// Copyright (c) Facebook, Inc. and its affiliates.
#include "box_iou_rotated.h"
#include "box_iou_rotated_utils.h"

namespace detectron2 {

template <typename T>
void box_iou_rotated_cpu_kernel(
    const at::Tensor& boxes1,
    const at::Tensor& boxes2,
    at::Tensor& ious) {
  auto num_boxes1 = boxes1.size(0);
  auto num_boxes2 = boxes2.size(0);

  for (int i = 0; i < num_boxes1; i++) {
    for (int j = 0; j < num_boxes2; j++) {
      ious[i * num_boxes2 + j] = single_box_iou_rotated<T>(
          boxes1[i].data_ptr<T>(), boxes2[j].data_ptr<T>());
    }
  }
}

at::Tensor box_iou_rotated_cpu(
    // input must be contiguous:
    const at::Tensor& boxes1,
    const at::Tensor& boxes2) {
  auto num_boxes1 = boxes1.size(0);
  auto num_boxes2 = boxes2.size(0);
  at::Tensor ious =
      at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));

  box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious);

  // reshape from 1d array to 2d array
  auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
  return ious.reshape(shape);
}

} // namespace detectron2

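For reference, a minimal host-side sketch of how the box_iou_rotated() dispatcher declared above could be exercised once these sources are compiled and linked against libtorch (a CPU-only build is enough). The file name and tensor values are illustrative and not part of the commit:

// iou_demo.cpp (hypothetical)
#include <iostream>
#include <torch/torch.h>
#include "box_iou_rotated.h"

int main() {
  // Each row is (x_ctr, y_ctr, w, h, angle_in_degrees).
  auto boxes1 = torch::tensor({{0.0f, 0.0f, 2.0f, 2.0f, 0.0f}});
  auto boxes2 = torch::tensor({{1.0f, 0.0f, 2.0f, 2.0f, 0.0f}});
  // Overlap is 1x2 = 2 and the union is 4 + 4 - 2 = 6, so IoU should be ~0.333.
  at::Tensor ious = detectron2::box_iou_rotated(boxes1, boxes2); // shape (1, 1)
  std::cout << ious << std::endl;
  return 0;
}
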
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
ADDED
@@ -0,0 +1,130 @@
// Copyright (c) Facebook, Inc. and its affiliates.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include "box_iou_rotated_utils.h"

namespace detectron2 {

// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;

template <typename T>
__global__ void box_iou_rotated_cuda_kernel(
    const int n_boxes1,
    const int n_boxes2,
    const T* dev_boxes1,
    const T* dev_boxes2,
    T* dev_ious) {
  const int row_start = blockIdx.x * blockDim.x;
  const int col_start = blockIdx.y * blockDim.y;

  const int row_size = min(n_boxes1 - row_start, blockDim.x);
  const int col_size = min(n_boxes2 - col_start, blockDim.y);

  __shared__ float block_boxes1[BLOCK_DIM_X * 5];
  __shared__ float block_boxes2[BLOCK_DIM_Y * 5];

  // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
  if (threadIdx.x < row_size && threadIdx.y == 0) {
    block_boxes1[threadIdx.x * 5 + 0] =
        dev_boxes1[(row_start + threadIdx.x) * 5 + 0];
    block_boxes1[threadIdx.x * 5 + 1] =
        dev_boxes1[(row_start + threadIdx.x) * 5 + 1];
    block_boxes1[threadIdx.x * 5 + 2] =
        dev_boxes1[(row_start + threadIdx.x) * 5 + 2];
    block_boxes1[threadIdx.x * 5 + 3] =
        dev_boxes1[(row_start + threadIdx.x) * 5 + 3];
    block_boxes1[threadIdx.x * 5 + 4] =
        dev_boxes1[(row_start + threadIdx.x) * 5 + 4];
  }

  if (threadIdx.x < col_size && threadIdx.y == 0) {
    block_boxes2[threadIdx.x * 5 + 0] =
        dev_boxes2[(col_start + threadIdx.x) * 5 + 0];
    block_boxes2[threadIdx.x * 5 + 1] =
        dev_boxes2[(col_start + threadIdx.x) * 5 + 1];
    block_boxes2[threadIdx.x * 5 + 2] =
        dev_boxes2[(col_start + threadIdx.x) * 5 + 2];
    block_boxes2[threadIdx.x * 5 + 3] =
        dev_boxes2[(col_start + threadIdx.x) * 5 + 3];
    block_boxes2[threadIdx.x * 5 + 4] =
        dev_boxes2[(col_start + threadIdx.x) * 5 + 4];
  }
  __syncthreads();

  if (threadIdx.x < row_size && threadIdx.y < col_size) {
    int offset = (row_start + threadIdx.x) * n_boxes2 + col_start + threadIdx.y;
    dev_ious[offset] = single_box_iou_rotated<T>(
        block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
  }
}

at::Tensor box_iou_rotated_cuda(
    // input must be contiguous
    const at::Tensor& boxes1,
    const at::Tensor& boxes2) {
  using scalar_t = float;
  AT_ASSERTM(
      boxes1.scalar_type() == at::kFloat, "boxes1 must be a float tensor");
  AT_ASSERTM(
      boxes2.scalar_type() == at::kFloat, "boxes2 must be a float tensor");
  AT_ASSERTM(boxes1.is_cuda(), "boxes1 must be a CUDA tensor");
  AT_ASSERTM(boxes2.is_cuda(), "boxes2 must be a CUDA tensor");
  at::cuda::CUDAGuard device_guard(boxes1.device());

  auto num_boxes1 = boxes1.size(0);
  auto num_boxes2 = boxes2.size(0);

  at::Tensor ious =
      at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));

  bool transpose = false;
  if (num_boxes1 > 0 && num_boxes2 > 0) {
    scalar_t *data1 = boxes1.data_ptr<scalar_t>(),
             *data2 = boxes2.data_ptr<scalar_t>();

    if (num_boxes2 > 65535 * BLOCK_DIM_Y) {
      AT_ASSERTM(
          num_boxes1 <= 65535 * BLOCK_DIM_Y,
          "Too many boxes for box_iou_rotated_cuda!");
      // x dim is allowed to be large, but y dim cannot,
      // so we transpose the two to avoid "invalid configuration argument"
      // error. We assume one of them is small. Otherwise the result is hard to
      // fit in memory anyway.
      std::swap(num_boxes1, num_boxes2);
      std::swap(data1, data2);
      transpose = true;
    }

    const int blocks_x =
        at::cuda::ATenCeilDiv(static_cast<int>(num_boxes1), BLOCK_DIM_X);
    const int blocks_y =
        at::cuda::ATenCeilDiv(static_cast<int>(num_boxes2), BLOCK_DIM_Y);

    dim3 blocks(blocks_x, blocks_y);
    dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    box_iou_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        num_boxes1,
        num_boxes2,
        data1,
        data2,
        (scalar_t*)ious.data_ptr<scalar_t>());

    AT_CUDA_CHECK(cudaGetLastError());
  }

  // reshape from 1d array to 2d array
  auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
  if (transpose) {
    return ious.view(shape).t();
  } else {
    return ious.view(shape);
  }
}

} // namespace detectron2

extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
ADDED
@@ -0,0 +1,370 @@
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once

#include <cassert>
#include <cmath>

#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
#else
#include <algorithm>
#define HOST_DEVICE
#define HOST_DEVICE_INLINE HOST_DEVICE inline
#endif

namespace detectron2 {

namespace {

template <typename T>
struct RotatedBox {
  T x_ctr, y_ctr, w, h, a;
};

template <typename T>
struct Point {
  T x, y;
  HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
  HOST_DEVICE_INLINE Point operator+(const Point& p) const {
    return Point(x + p.x, y + p.y);
  }
  HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
    x += p.x;
    y += p.y;
    return *this;
  }
  HOST_DEVICE_INLINE Point operator-(const Point& p) const {
    return Point(x - p.x, y - p.y);
  }
  HOST_DEVICE_INLINE Point operator*(const T coeff) const {
    return Point(x * coeff, y * coeff);
  }
};

template <typename T>
HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
  return A.x * B.x + A.y * B.y;
}

// R: result type. can be different from input type
template <typename T, typename R = T>
HOST_DEVICE_INLINE R cross_2d(const Point<T>& A, const Point<T>& B) {
  return static_cast<R>(A.x) * static_cast<R>(B.y) -
      static_cast<R>(B.x) * static_cast<R>(A.y);
}

template <typename T>
HOST_DEVICE_INLINE void get_rotated_vertices(
    const RotatedBox<T>& box,
    Point<T> (&pts)[4]) {
  // M_PI / 180. == 0.01745329251
  double theta = box.a * 0.01745329251;
  T cosTheta2 = (T)cos(theta) * 0.5f;
  T sinTheta2 = (T)sin(theta) * 0.5f;

  // y: top --> down; x: left --> right
  pts[0].x = box.x_ctr + sinTheta2 * box.h + cosTheta2 * box.w;
  pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
  pts[1].x = box.x_ctr - sinTheta2 * box.h + cosTheta2 * box.w;
  pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
  pts[2].x = 2 * box.x_ctr - pts[0].x;
  pts[2].y = 2 * box.y_ctr - pts[0].y;
  pts[3].x = 2 * box.x_ctr - pts[1].x;
  pts[3].y = 2 * box.y_ctr - pts[1].y;
}

template <typename T>
HOST_DEVICE_INLINE int get_intersection_points(
    const Point<T> (&pts1)[4],
    const Point<T> (&pts2)[4],
    Point<T> (&intersections)[24]) {
  // Line vector
  // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
  Point<T> vec1[4], vec2[4];
  for (int i = 0; i < 4; i++) {
    vec1[i] = pts1[(i + 1) % 4] - pts1[i];
    vec2[i] = pts2[(i + 1) % 4] - pts2[i];
  }

  // When computing the intersection area, it doesn't hurt if we have
  // more (duplicated/approximate) intersections/vertices than needed,
  // while it can cause drastic difference if we miss an intersection/vertex.
  // Therefore, we add an epsilon to relax the comparisons between
  // the float point numbers that decide the intersection points.
  double EPS = 1e-5;

  // Line test - test all line combos for intersection
  int num = 0; // number of intersections
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      // Solve for 2x2 Ax=b
      T det = cross_2d<T>(vec2[j], vec1[i]);

      // This takes care of parallel lines
      if (fabs(det) <= 1e-14) {
        continue;
      }

      auto vec12 = pts2[j] - pts1[i];

      T t1 = cross_2d<T>(vec2[j], vec12) / det;
      T t2 = cross_2d<T>(vec1[i], vec12) / det;

      if (t1 > -EPS && t1 < 1.0f + EPS && t2 > -EPS && t2 < 1.0f + EPS) {
        intersections[num++] = pts1[i] + vec1[i] * t1;
      }
    }
  }

  // Check for vertices of rect1 inside rect2
  {
    const auto& AB = vec2[0];
    const auto& DA = vec2[3];
    auto ABdotAB = dot_2d<T>(AB, AB);
    auto ADdotAD = dot_2d<T>(DA, DA);
    for (int i = 0; i < 4; i++) {
      // assume ABCD is the rectangle, and P is the point to be judged
      // P is inside ABCD iff. P's projection on AB lies within AB
      // and P's projection on AD lies within AD

      auto AP = pts1[i] - pts2[0];

      auto APdotAB = dot_2d<T>(AP, AB);
      auto APdotAD = -dot_2d<T>(AP, DA);

      if ((APdotAB > -EPS) && (APdotAD > -EPS) && (APdotAB < ABdotAB + EPS) &&
          (APdotAD < ADdotAD + EPS)) {
        intersections[num++] = pts1[i];
      }
    }
  }

  // Reverse the check - check for vertices of rect2 inside rect1
  {
    const auto& AB = vec1[0];
    const auto& DA = vec1[3];
    auto ABdotAB = dot_2d<T>(AB, AB);
    auto ADdotAD = dot_2d<T>(DA, DA);
    for (int i = 0; i < 4; i++) {
      auto AP = pts2[i] - pts1[0];

      auto APdotAB = dot_2d<T>(AP, AB);
      auto APdotAD = -dot_2d<T>(AP, DA);

      if ((APdotAB > -EPS) && (APdotAD > -EPS) && (APdotAB < ABdotAB + EPS) &&
          (APdotAD < ADdotAD + EPS)) {
        intersections[num++] = pts2[i];
      }
    }
  }

  return num;
}

template <typename T>
HOST_DEVICE_INLINE int convex_hull_graham(
    const Point<T> (&p)[24],
    const int& num_in,
    Point<T> (&q)[24],
    bool shift_to_zero = false) {
  assert(num_in >= 2);

  // Step 1:
  // Find point with minimum y
  // if more than 1 points have the same minimum y,
  // pick the one with the minimum x.
  int t = 0;
  for (int i = 1; i < num_in; i++) {
    if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
      t = i;
    }
  }
  auto& start = p[t]; // starting point

  // Step 2:
  // Subtract starting point from every points (for sorting in the next step)
  for (int i = 0; i < num_in; i++) {
    q[i] = p[i] - start;
  }

  // Swap the starting point to position 0
  auto tmp = q[0];
  q[0] = q[t];
  q[t] = tmp;

  // Step 3:
  // Sort point 1 ~ num_in according to their relative cross-product values
  // (essentially sorting according to angles)
  // If the angles are the same, sort according to their distance to origin
  T dist[24];
#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
  // compute distance to origin before sort, and sort them together with the
  // points
  for (int i = 0; i < num_in; i++) {
    dist[i] = dot_2d<T>(q[i], q[i]);
  }

  // CUDA version
  // In the future, we can potentially use thrust
  // for sorting here to improve speed (though not guaranteed)
  for (int i = 1; i < num_in - 1; i++) {
    for (int j = i + 1; j < num_in; j++) {
      T crossProduct = cross_2d<T>(q[i], q[j]);
      if ((crossProduct < -1e-6) ||
          (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
        auto q_tmp = q[i];
        q[i] = q[j];
        q[j] = q_tmp;
        auto dist_tmp = dist[i];
        dist[i] = dist[j];
        dist[j] = dist_tmp;
      }
    }
  }
#else
  // CPU version
  std::sort(
      q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
        T temp = cross_2d<T>(A, B);
        if (fabs(temp) < 1e-6) {
          return dot_2d<T>(A, A) < dot_2d<T>(B, B);
        } else {
          return temp > 0;
        }
      });
  // compute distance to origin after sort, since the points are now different.
  for (int i = 0; i < num_in; i++) {
    dist[i] = dot_2d<T>(q[i], q[i]);
  }
#endif

  // Step 4:
  // Make sure there are at least 2 points (that don't overlap with each other)
  // in the stack
  int k; // index of the non-overlapped second point
  for (k = 1; k < num_in; k++) {
    if (dist[k] > 1e-8) {
      break;
    }
  }
  if (k == num_in) {
    // We reach the end, which means the convex hull is just one point
    q[0] = p[t];
    return 1;
  }
  q[1] = q[k];
  int m = 2; // 2 points in the stack
  // Step 5:
  // Finally we can start the scanning process.
  // When a non-convex relationship between the 3 points is found
  // (either concave shape or duplicated points),
  // we pop the previous point from the stack
  // until the 3-point relationship is convex again, or
  // until the stack only contains two points
  for (int i = k + 1; i < num_in; i++) {
    while (m > 1) {
      auto q1 = q[i] - q[m - 2], q2 = q[m - 1] - q[m - 2];
      // cross_2d() uses FMA and therefore computes round(round(q1.x*q2.y) -
      // q2.x*q1.y) So it may not return 0 even when q1==q2. Therefore we
      // compare round(q1.x*q2.y) and round(q2.x*q1.y) directly. (round means
      // round to nearest floating point).
      if (q1.x * q2.y >= q2.x * q1.y)
        m--;
      else
        break;
    }
    // Using double also helps, but float can solve the issue for now.
    // while (m > 1 && cross_2d<T, double>(q[i] - q[m - 2], q[m - 1] - q[m - 2])
    // >= 0) {
    //     m--;
    // }
    q[m++] = q[i];
  }

  // Step 6 (Optional):
  // In general sense we need the original coordinates, so we
  // need to shift the points back (reverting Step 2)
  // But if we're only interested in getting the area/perimeter of the shape
  // We can simply return.
  if (!shift_to_zero) {
    for (int i = 0; i < m; i++) {
      q[i] += start;
    }
  }

  return m;
}

template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
  if (m <= 2) {
    return 0;
  }

  T area = 0;
  for (int i = 1; i < m - 1; i++) {
    area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
  }

  return area / 2.0;
}

template <typename T>
HOST_DEVICE_INLINE T rotated_boxes_intersection(
    const RotatedBox<T>& box1,
    const RotatedBox<T>& box2) {
  // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
  // from rotated_rect_intersection_pts
  Point<T> intersectPts[24], orderedPts[24];

  Point<T> pts1[4];
  Point<T> pts2[4];
  get_rotated_vertices<T>(box1, pts1);
  get_rotated_vertices<T>(box2, pts2);

  int num = get_intersection_points<T>(pts1, pts2, intersectPts);

  if (num <= 2) {
    return 0.0;
  }

  // Convex Hull to order the intersection points in clockwise order and find
  // the contour area.
  int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
  return polygon_area<T>(orderedPts, num_convex);
}

} // namespace

template <typename T>
HOST_DEVICE_INLINE T
single_box_iou_rotated(T const* const box1_raw, T const* const box2_raw) {
  // shift center to the middle point to achieve higher precision in result
  RotatedBox<T> box1, box2;
  auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
  auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
  box1.x_ctr = box1_raw[0] - center_shift_x;
  box1.y_ctr = box1_raw[1] - center_shift_y;
  box1.w = box1_raw[2];
  box1.h = box1_raw[3];
  box1.a = box1_raw[4];
  box2.x_ctr = box2_raw[0] - center_shift_x;
  box2.y_ctr = box2_raw[1] - center_shift_y;
  box2.w = box2_raw[2];
  box2.h = box2_raw[3];
  box2.a = box2_raw[4];

  T area1 = box1.w * box1.h;
  T area2 = box2.w * box2.h;
  if (area1 < 1e-14 || area2 < 1e-14) {
    return 0.f;
  }

  T intersection = rotated_boxes_intersection<T>(box1, box2);
  T iou = intersection / (area1 + area2 - intersection);
  return iou;
}

} // namespace detectron2

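Because the helpers above are header-only and compile for the host as well as the device, they can be sanity-checked without CUDA or libtorch. A minimal sketch, not part of the commit, with made-up boxes: a 4x2 box and the same box rotated by 90 degrees overlap in a 2x2 square, so the expected IoU is 4 / (8 + 8 - 4) = 1/3:

// iou_utils_demo.cpp (hypothetical, host-only check of single_box_iou_rotated)
#include <cstdio>
#include "box_iou_rotated_utils.h"

int main() {
  // Boxes are (x_ctr, y_ctr, w, h, angle_in_degrees).
  float a[5] = {0.f, 0.f, 4.f, 2.f, 0.f};  // axis-aligned 4x2 box
  float b[5] = {0.f, 0.f, 4.f, 2.f, 90.f}; // same box rotated by 90 degrees
  float iou = detectron2::single_box_iou_rotated<float>(a, b);
  std::printf("IoU = %f\n", iou); // ~0.333333
  return 0;
}
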
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/cocoeval/cocoeval.cpp
ADDED
@@ -0,0 +1,507 @@
// Copyright (c) Facebook, Inc. and its affiliates.
#include "cocoeval.h"
#include <time.h>
#include <algorithm>
#include <cstdint>
#include <numeric>

using namespace pybind11::literals;

namespace detectron2 {

namespace COCOeval {

// Sort detections from highest score to lowest, such that
// detection_instances[detection_sorted_indices[t]] >=
// detection_instances[detection_sorted_indices[t+1]]. Use stable_sort to match
// original COCO API
void SortInstancesByDetectionScore(
    const std::vector<InstanceAnnotation>& detection_instances,
    std::vector<uint64_t>* detection_sorted_indices) {
  detection_sorted_indices->resize(detection_instances.size());
  std::iota(
      detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
  std::stable_sort(
      detection_sorted_indices->begin(),
      detection_sorted_indices->end(),
      [&detection_instances](size_t j1, size_t j2) {
        return detection_instances[j1].score > detection_instances[j2].score;
      });
}

// Partition the ground truth objects based on whether or not to ignore them
// based on area
void SortInstancesByIgnore(
    const std::array<double, 2>& area_range,
    const std::vector<InstanceAnnotation>& ground_truth_instances,
    std::vector<uint64_t>* ground_truth_sorted_indices,
    std::vector<bool>* ignores) {
  ignores->clear();
  ignores->reserve(ground_truth_instances.size());
  for (auto o : ground_truth_instances) {
    ignores->push_back(
        o.ignore || o.area < area_range[0] || o.area > area_range[1]);
  }

  ground_truth_sorted_indices->resize(ground_truth_instances.size());
  std::iota(
      ground_truth_sorted_indices->begin(),
      ground_truth_sorted_indices->end(),
      0);
  std::stable_sort(
      ground_truth_sorted_indices->begin(),
      ground_truth_sorted_indices->end(),
      [&ignores](size_t j1, size_t j2) {
        return (int)(*ignores)[j1] < (int)(*ignores)[j2];
      });
}

// For each IOU threshold, greedily match each detected instance to a ground
// truth instance (if possible) and store the results
void MatchDetectionsToGroundTruth(
    const std::vector<InstanceAnnotation>& detection_instances,
    const std::vector<uint64_t>& detection_sorted_indices,
    const std::vector<InstanceAnnotation>& ground_truth_instances,
    const std::vector<uint64_t>& ground_truth_sorted_indices,
    const std::vector<bool>& ignores,
    const std::vector<std::vector<double>>& ious,
    const std::vector<double>& iou_thresholds,
    const std::array<double, 2>& area_range,
    ImageEvaluation* results) {
  // Initialize memory to store return data matches and ignore
  const int num_iou_thresholds = iou_thresholds.size();
  const int num_ground_truth = ground_truth_sorted_indices.size();
  const int num_detections = detection_sorted_indices.size();
  std::vector<uint64_t> ground_truth_matches(
      num_iou_thresholds * num_ground_truth, 0);
  std::vector<uint64_t>& detection_matches = results->detection_matches;
  std::vector<bool>& detection_ignores = results->detection_ignores;
  std::vector<bool>& ground_truth_ignores = results->ground_truth_ignores;
  detection_matches.resize(num_iou_thresholds * num_detections, 0);
  detection_ignores.resize(num_iou_thresholds * num_detections, false);
  ground_truth_ignores.resize(num_ground_truth);
  for (auto g = 0; g < num_ground_truth; ++g) {
    ground_truth_ignores[g] = ignores[ground_truth_sorted_indices[g]];
  }

  for (auto t = 0; t < num_iou_thresholds; ++t) {
    for (auto d = 0; d < num_detections; ++d) {
      // information about best match so far (match=-1 -> unmatched)
      double best_iou = std::min(iou_thresholds[t], 1 - 1e-10);
      int match = -1;
      for (auto g = 0; g < num_ground_truth; ++g) {
        // if this ground truth instance is already matched and not a
        // crowd, it cannot be matched to another detection
        if (ground_truth_matches[t * num_ground_truth + g] > 0 &&
            !ground_truth_instances[ground_truth_sorted_indices[g]].is_crowd) {
          continue;
        }

        // if detected instance matched to a regular ground truth
        // instance, we can break on the first ground truth instance
        // tagged as ignore (because they are sorted by the ignore tag)
        if (match >= 0 && !ground_truth_ignores[match] &&
            ground_truth_ignores[g]) {
          break;
        }

        // if IOU overlap is the best so far, store the match appropriately
        if (ious[d][ground_truth_sorted_indices[g]] >= best_iou) {
          best_iou = ious[d][ground_truth_sorted_indices[g]];
          match = g;
        }
      }
      // if match was made, store id of match for both detection and
      // ground truth
      if (match >= 0) {
        detection_ignores[t * num_detections + d] = ground_truth_ignores[match];
        detection_matches[t * num_detections + d] =
            ground_truth_instances[ground_truth_sorted_indices[match]].id;
        ground_truth_matches[t * num_ground_truth + match] =
            detection_instances[detection_sorted_indices[d]].id;
      }

      // set unmatched detections outside of area range to ignore
      const InstanceAnnotation& detection =
          detection_instances[detection_sorted_indices[d]];
      detection_ignores[t * num_detections + d] =
          detection_ignores[t * num_detections + d] ||
          (detection_matches[t * num_detections + d] == 0 &&
           (detection.area < area_range[0] || detection.area > area_range[1]));
    }
  }

  // store detection score results
  results->detection_scores.resize(detection_sorted_indices.size());
  for (size_t d = 0; d < detection_sorted_indices.size(); ++d) {
    results->detection_scores[d] =
        detection_instances[detection_sorted_indices[d]].score;
  }
}

std::vector<ImageEvaluation> EvaluateImages(
    const std::vector<std::array<double, 2>>& area_ranges,
    int max_detections,
    const std::vector<double>& iou_thresholds,
    const ImageCategoryInstances<std::vector<double>>& image_category_ious,
    const ImageCategoryInstances<InstanceAnnotation>&
        image_category_ground_truth_instances,
    const ImageCategoryInstances<InstanceAnnotation>&
        image_category_detection_instances) {
  const int num_area_ranges = area_ranges.size();
  const int num_images = image_category_ground_truth_instances.size();
  const int num_categories =
      image_category_ious.size() > 0 ? image_category_ious[0].size() : 0;
  std::vector<uint64_t> detection_sorted_indices;
  std::vector<uint64_t> ground_truth_sorted_indices;
  std::vector<bool> ignores;
  std::vector<ImageEvaluation> results_all(
      num_images * num_area_ranges * num_categories);

  // Store results for each image, category, and area range combination. Results
  // for each IOU threshold are packed into the same ImageEvaluation object
  for (auto i = 0; i < num_images; ++i) {
    for (auto c = 0; c < num_categories; ++c) {
      const std::vector<InstanceAnnotation>& ground_truth_instances =
          image_category_ground_truth_instances[i][c];
      const std::vector<InstanceAnnotation>& detection_instances =
          image_category_detection_instances[i][c];

      SortInstancesByDetectionScore(
          detection_instances, &detection_sorted_indices);
      if ((int)detection_sorted_indices.size() > max_detections) {
        detection_sorted_indices.resize(max_detections);
      }

      for (size_t a = 0; a < area_ranges.size(); ++a) {
        SortInstancesByIgnore(
            area_ranges[a],
            ground_truth_instances,
            &ground_truth_sorted_indices,
            &ignores);

        MatchDetectionsToGroundTruth(
            detection_instances,
            detection_sorted_indices,
            ground_truth_instances,
            ground_truth_sorted_indices,
            ignores,
            image_category_ious[i][c],
            iou_thresholds,
            area_ranges[a],
            &results_all
                [c * num_area_ranges * num_images + a * num_images + i]);
      }
    }
  }

  return results_all;
}

// Convert a python list to a vector
template <typename T>
std::vector<T> list_to_vec(const py::list& l) {
  std::vector<T> v(py::len(l));
  for (int i = 0; i < (int)py::len(l); ++i) {
    v[i] = l[i].cast<T>();
  }
  return v;
}

// Helper function to Accumulate()
// Considers the evaluation results applicable to a particular category, area
// range, and max_detections parameter setting, which begin at
// evaluations[evaluation_index]. Extracts a sorted list of length n of all
// applicable detection instances concatenated across all images in the dataset,
// which are represented by the outputs evaluation_indices, detection_scores,
// image_detection_indices, and detection_sorted_indices--all of which are
// length n. evaluation_indices[i] stores the applicable index into
// evaluations[] for instance i, which has detection score detection_score[i],
// and is the image_detection_indices[i]'th of the list of detections
// for the image containing i. detection_sorted_indices[] defines a sorted
// permutation of the 3 other outputs
int BuildSortedDetectionList(
    const std::vector<ImageEvaluation>& evaluations,
    const int64_t evaluation_index,
    const int64_t num_images,
    const int max_detections,
    std::vector<uint64_t>* evaluation_indices,
    std::vector<double>* detection_scores,
    std::vector<uint64_t>* detection_sorted_indices,
    std::vector<uint64_t>* image_detection_indices) {
  assert(evaluations.size() >= evaluation_index + num_images);

  // Extract a list of object instances of the applicable category, area
  // range, and max detections requirements such that they can be sorted
  image_detection_indices->clear();
  evaluation_indices->clear();
  detection_scores->clear();
  image_detection_indices->reserve(num_images * max_detections);
  evaluation_indices->reserve(num_images * max_detections);
  detection_scores->reserve(num_images * max_detections);
  int num_valid_ground_truth = 0;
  for (auto i = 0; i < num_images; ++i) {
    const ImageEvaluation& evaluation = evaluations[evaluation_index + i];

    for (int d = 0;
         d < (int)evaluation.detection_scores.size() && d < max_detections;
         ++d) { // detected instances
      evaluation_indices->push_back(evaluation_index + i);
      image_detection_indices->push_back(d);
      detection_scores->push_back(evaluation.detection_scores[d]);
    }
    for (auto ground_truth_ignore : evaluation.ground_truth_ignores) {
      if (!ground_truth_ignore) {
        ++num_valid_ground_truth;
      }
    }
  }

  // Sort detections by decreasing score, using stable sort to match
  // python implementation
  detection_sorted_indices->resize(detection_scores->size());
  std::iota(
      detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
  std::stable_sort(
      detection_sorted_indices->begin(),
      detection_sorted_indices->end(),
      [&detection_scores](size_t j1, size_t j2) {
        return (*detection_scores)[j1] > (*detection_scores)[j2];
      });

  return num_valid_ground_truth;
}

// Helper function to Accumulate()
// Compute a precision recall curve given a sorted list of detected instances
// encoded in evaluations, evaluation_indices, detection_scores,
// detection_sorted_indices, image_detection_indices (see
// BuildSortedDetectionList()). Using vectors precisions and recalls
// and temporary storage, output the results into precisions_out, recalls_out,
// and scores_out, which are large buffers containing many precision/recall
// curves for all possible parameter settings, with precisions_out_index and
// recalls_out_index defining the applicable indices to store results.
void ComputePrecisionRecallCurve(
    const int64_t precisions_out_index,
    const int64_t precisions_out_stride,
    const int64_t recalls_out_index,
    const std::vector<double>& recall_thresholds,
    const int iou_threshold_index,
    const int num_iou_thresholds,
    const int num_valid_ground_truth,
    const std::vector<ImageEvaluation>& evaluations,
    const std::vector<uint64_t>& evaluation_indices,
    const std::vector<double>& detection_scores,
    const std::vector<uint64_t>& detection_sorted_indices,
    const std::vector<uint64_t>& image_detection_indices,
    std::vector<double>* precisions,
    std::vector<double>* recalls,
    std::vector<double>* precisions_out,
    std::vector<double>* scores_out,
    std::vector<double>* recalls_out) {
  assert(recalls_out->size() > recalls_out_index);

  // Compute precision/recall for each instance in the sorted list of detections
  int64_t true_positives_sum = 0, false_positives_sum = 0;
  precisions->clear();
  recalls->clear();
  precisions->reserve(detection_sorted_indices.size());
  recalls->reserve(detection_sorted_indices.size());
  assert(!evaluations.empty() || detection_sorted_indices.empty());
  for (auto detection_sorted_index : detection_sorted_indices) {
    const ImageEvaluation& evaluation =
        evaluations[evaluation_indices[detection_sorted_index]];
    const auto num_detections =
        evaluation.detection_matches.size() / num_iou_thresholds;
    const auto detection_index = iou_threshold_index * num_detections +
        image_detection_indices[detection_sorted_index];
    assert(evaluation.detection_matches.size() > detection_index);
    assert(evaluation.detection_ignores.size() > detection_index);
    const int64_t detection_match =
        evaluation.detection_matches[detection_index];
    const bool detection_ignores =
        evaluation.detection_ignores[detection_index];
    const auto true_positive = detection_match > 0 && !detection_ignores;
    const auto false_positive = detection_match == 0 && !detection_ignores;
    if (true_positive) {
      ++true_positives_sum;
    }
    if (false_positive) {
      ++false_positives_sum;
    }

    const double recall =
        static_cast<double>(true_positives_sum) / num_valid_ground_truth;
    recalls->push_back(recall);
    const int64_t num_valid_detections =
        true_positives_sum + false_positives_sum;
    const double precision = num_valid_detections > 0
        ? static_cast<double>(true_positives_sum) / num_valid_detections
        : 0.0;
    precisions->push_back(precision);
  }

  (*recalls_out)[recalls_out_index] = !recalls->empty() ? recalls->back() : 0;

  for (int64_t i = static_cast<int64_t>(precisions->size()) - 1; i > 0; --i) {
    if ((*precisions)[i] > (*precisions)[i - 1]) {
      (*precisions)[i - 1] = (*precisions)[i];
    }
  }

  // Sample the per instance precision/recall list at each recall threshold
  for (size_t r = 0; r < recall_thresholds.size(); ++r) {
    // first index in recalls >= recall_thresholds[r]
    std::vector<double>::iterator low = std::lower_bound(
        recalls->begin(), recalls->end(), recall_thresholds[r]);
    size_t precisions_index = low - recalls->begin();

    const auto results_ind = precisions_out_index + r * precisions_out_stride;
    assert(results_ind < precisions_out->size());
    assert(results_ind < scores_out->size());
    if (precisions_index < precisions->size()) {
      (*precisions_out)[results_ind] = (*precisions)[precisions_index];
      (*scores_out)[results_ind] =
          detection_scores[detection_sorted_indices[precisions_index]];
    } else {
      (*precisions_out)[results_ind] = 0;
      (*scores_out)[results_ind] = 0;
    }
  }
}
py::dict Accumulate(
    const py::object& params,
    const std::vector<ImageEvaluation>& evaluations) {
  const std::vector<double> recall_thresholds =
      list_to_vec<double>(params.attr("recThrs"));
  const std::vector<int> max_detections =
      list_to_vec<int>(params.attr("maxDets"));
  const int num_iou_thresholds = py::len(params.attr("iouThrs"));
  const int num_recall_thresholds = py::len(params.attr("recThrs"));
  const int num_categories = params.attr("useCats").cast<int>() == 1
      ? py::len(params.attr("catIds"))
      : 1;
  const int num_area_ranges = py::len(params.attr("areaRng"));
  const int num_max_detections = py::len(params.attr("maxDets"));
  const int num_images = py::len(params.attr("imgIds"));

  std::vector<double> precisions_out(
      num_iou_thresholds * num_recall_thresholds * num_categories *
          num_area_ranges * num_max_detections,
      -1);
  std::vector<double> recalls_out(
      num_iou_thresholds * num_categories * num_area_ranges *
          num_max_detections,
      -1);
  std::vector<double> scores_out(
      num_iou_thresholds * num_recall_thresholds * num_categories *
          num_area_ranges * num_max_detections,
      -1);

  // Consider the list of all detected instances in the entire dataset in one
  // large list. evaluation_indices, detection_scores,
  // image_detection_indices, and detection_sorted_indices all have the same
  // length as this list, such that each entry corresponds to one detected
  // instance
  std::vector<uint64_t> evaluation_indices; // indices into evaluations[]
  std::vector<double> detection_scores; // detection scores of each instance
  std::vector<uint64_t> detection_sorted_indices; // sorted indices of all
                                                  // instances in the dataset
  std::vector<uint64_t>
      image_detection_indices; // indices into the list of detected instances in
                               // the same image as each instance
  std::vector<double> precisions, recalls;

  for (auto c = 0; c < num_categories; ++c) {
    for (auto a = 0; a < num_area_ranges; ++a) {
      for (auto m = 0; m < num_max_detections; ++m) {
        // The COCO PythonAPI assumes evaluations[] (the return value of
        // COCOeval::EvaluateImages()) is one long list storing results for each
        // combination of category, area range, and image id, with categories in
        // the outermost loop and images in the innermost loop.
        const int64_t evaluations_index =
            c * num_area_ranges * num_images + a * num_images;
        int num_valid_ground_truth = BuildSortedDetectionList(
            evaluations,
            evaluations_index,
            num_images,
            max_detections[m],
            &evaluation_indices,
            &detection_scores,
            &detection_sorted_indices,
            &image_detection_indices);

        if (num_valid_ground_truth == 0) {
          continue;
        }

        for (auto t = 0; t < num_iou_thresholds; ++t) {
          // recalls_out is a flattened vector representing a
          // num_iou_thresholds X num_categories X num_area_ranges X
          // num_max_detections matrix
          const int64_t recalls_out_index =
              t * num_categories * num_area_ranges * num_max_detections +
              c * num_area_ranges * num_max_detections +
              a * num_max_detections + m;

          // precisions_out and scores_out are flattened vectors
          // representing a num_iou_thresholds X num_recall_thresholds X
          // num_categories X num_area_ranges X num_max_detections matrix
          const int64_t precisions_out_stride =
              num_categories * num_area_ranges * num_max_detections;
          const int64_t precisions_out_index = t * num_recall_thresholds *
                  num_categories * num_area_ranges * num_max_detections +
              c * num_area_ranges * num_max_detections +
              a * num_max_detections + m;

          ComputePrecisionRecallCurve(
              precisions_out_index,
              precisions_out_stride,
              recalls_out_index,
              recall_thresholds,
              t,
              num_iou_thresholds,
              num_valid_ground_truth,
              evaluations,
              evaluation_indices,
              detection_scores,
              detection_sorted_indices,
              image_detection_indices,
              &precisions,
              &recalls,
              &precisions_out,
              &scores_out,
              &recalls_out);
        }
      }
    }
  }

  time_t rawtime;
  struct tm local_time;
  std::array<char, 200> buffer;
  time(&rawtime);
#ifdef _WIN32
  localtime_s(&local_time, &rawtime);
#else
  localtime_r(&rawtime, &local_time);
#endif
  strftime(
      buffer.data(), 200, "%Y-%m-%d %H:%M:%S", &local_time);
  return py::dict(
      "params"_a = params,
      "counts"_a = std::vector<int64_t>(
          {num_iou_thresholds,
           num_recall_thresholds,
           num_categories,
           num_area_ranges,
           num_max_detections}),
      "date"_a = buffer,
      "precision"_a = precisions_out,
      "recall"_a = recalls_out,
      "scores"_a = scores_out);
}

} // namespace COCOeval

} // namespace detectron2

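Accumulate() above packs its results into flat buffers whose logical shape is [num_iou_thresholds, num_recall_thresholds, num_categories, num_area_ranges, num_max_detections] (the "counts" entry of the returned dict). A hypothetical helper, not part of the commit, that mirrors the indexing math used for the "precision" buffer:

#include <cstdint>

// Index of precision[t][r][c][a][m] in the flattened buffer returned by
// Accumulate(); R, K, A, M are the sizes of the last four axes.
inline int64_t precision_index(
    int t, int r, int c, int a, int m, int R, int K, int A, int M) {
  return (((static_cast<int64_t>(t) * R + r) * K + c) * A + a) * M + m;
}
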
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/cocoeval/cocoeval.h
ADDED
@@ -0,0 +1,88 @@
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#pragma once
|
3 |
+
|
4 |
+
#include <pybind11/numpy.h>
|
5 |
+
#include <pybind11/pybind11.h>
|
6 |
+
#include <pybind11/stl.h>
|
7 |
+
#include <pybind11/stl_bind.h>
#include <vector>

namespace py = pybind11;

namespace detectron2 {

namespace COCOeval {

// Annotation data for a single object instance in an image
struct InstanceAnnotation {
  InstanceAnnotation(
      uint64_t id,
      double score,
      double area,
      bool is_crowd,
      bool ignore)
      : id{id}, score{score}, area{area}, is_crowd{is_crowd}, ignore{ignore} {}
  uint64_t id;
  double score = 0.;
  double area = 0.;
  bool is_crowd = false;
  bool ignore = false;
};

// Stores intermediate results for evaluating detection results for a single
// image that has D detected instances and G ground truth instances. This stores
// matches between detected and ground truth instances
struct ImageEvaluation {
  // For each of the D detected instances, the id of the matched ground truth
  // instance, or 0 if unmatched
  std::vector<uint64_t> detection_matches;

  // The detection score of each of the D detected instances
  std::vector<double> detection_scores;

  // Marks whether or not each of G instances was ignored from evaluation (e.g.,
  // because it's outside area_range)
  std::vector<bool> ground_truth_ignores;

  // Marks whether or not each of D instances was ignored from evaluation (e.g.,
  // because it's outside aRng)
  std::vector<bool> detection_ignores;
};

template <class T>
using ImageCategoryInstances = std::vector<std::vector<std::vector<T>>>;

// C++ implementation of COCO API cocoeval.py::COCOeval.evaluateImg(). For each
// combination of image, category, area range settings, and IOU thresholds to
// evaluate, it matches detected instances to ground truth instances and stores
// the results into a vector of ImageEvaluation results, which will be
// interpreted by the COCOeval::Accumulate() function to produce precision-recall
// curves. The parameters of nested vectors have the following semantics:
//   image_category_ious[i][c][d][g] is the intersection over union of the d'th
//     detected instance and g'th ground truth instance of
//     category category_ids[c] in image image_ids[i]
//   image_category_ground_truth_instances[i][c] is a vector of ground truth
//     instances in image image_ids[i] of category category_ids[c]
//   image_category_detection_instances[i][c] is a vector of detected
//     instances in image image_ids[i] of category category_ids[c]
std::vector<ImageEvaluation> EvaluateImages(
    const std::vector<std::array<double, 2>>& area_ranges, // vector of 2-tuples
    int max_detections,
    const std::vector<double>& iou_thresholds,
    const ImageCategoryInstances<std::vector<double>>& image_category_ious,
    const ImageCategoryInstances<InstanceAnnotation>&
        image_category_ground_truth_instances,
    const ImageCategoryInstances<InstanceAnnotation>&
        image_category_detection_instances);

// C++ implementation of COCOeval.accumulate(), which generates precision
// recall curves for each set of category, IOU threshold, detection area range,
// and max number of detections parameters. It is assumed that the parameter
// evaluations is the return value of the function COCOeval::EvaluateImages(),
// which was called with the same parameter settings params
py::dict Accumulate(
    const py::object& params,
    const std::vector<ImageEvaluation>& evaluations);

} // namespace COCOeval
} // namespace detectron2
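The comment block above documents how the nested ImageCategoryInstances arguments are indexed. As an illustration only, the sketch below drives EvaluateImages() for a single image and category; the include path "cocoeval/cocoeval.h", the toy IoU value, and the annotation ids are assumptions, and the translation unit would need pybind11 headers and the cocoeval implementation to build.

// Sketch: one image, one category, one ground-truth box and one detection.
#include "cocoeval/cocoeval.h" // assumed include path for the header above

#include <array>
#include <vector>

int main() {
  using namespace detectron2::COCOeval;

  // image_category_*_instances[i][c] -> instances of category c in image i
  ImageCategoryInstances<InstanceAnnotation> ground_truth{
      {{InstanceAnnotation(/*id=*/1, /*score=*/1.0, /*area=*/100.0,
                           /*is_crowd=*/false, /*ignore=*/false)}}};
  ImageCategoryInstances<InstanceAnnotation> detections{
      {{InstanceAnnotation(/*id=*/2, /*score=*/0.9, /*area=*/100.0,
                           /*is_crowd=*/false, /*ignore=*/false)}}};
  // image_category_ious[i][c][d][g]: IoU of detection d with ground truth g
  ImageCategoryInstances<std::vector<double>> ious{{{{0.8}}}};

  std::vector<std::array<double, 2>> area_ranges{{0.0, 1e10}};
  std::vector<double> iou_thresholds{0.5, 0.75};

  std::vector<ImageEvaluation> evals = EvaluateImages(
      area_ranges, /*max_detections=*/100, iou_thresholds, ious,
      ground_truth, detections);
  // Each ImageEvaluation records, per detected instance, the id of the matched
  // ground-truth instance (0 if unmatched), its score, and the ignore flags.
  return evals.empty() ? 1 : 0;
}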
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/cuda_version.cu
ADDED
@@ -0,0 +1,26 @@
// Copyright (c) Facebook, Inc. and its affiliates.

#include <cuda_runtime_api.h>

namespace detectron2 {
int get_cudart_version() {
// Not a ROCM platform: Either HIP is not used, or
// it is used, but platform is not ROCM (i.e. it is CUDA)
#if !defined(__HIP_PLATFORM_HCC__)
  return CUDART_VERSION;
#else
  int version = 0;

#if HIP_VERSION_MAJOR != 0
  // Create a convention similar to that of CUDA, as assumed by other
  // parts of the code.

  version = HIP_VERSION_MINOR;
  version += (HIP_VERSION_MAJOR * 100);
#else
  hipRuntimeGetVersion(&version);
#endif
  return version;
#endif
}
} // namespace detectron2
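For reference, a hypothetical helper (not taken from the diff) that unpacks the ROCm-style value produced by the branch above; plain CUDA builds instead return CUDART_VERSION unchanged.

#include <cstdio>

// Illustration only: under the ROCm branch the version is packed as
// HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR, so e.g. HIP 3.7 reads back as 307.
void print_hip_style_version(int version) {
  std::printf("%d.%d\n", version / 100, version % 100); // 307 -> "3.7"
}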
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/deformable/deform_conv.h
ADDED
@@ -0,0 +1,377 @@
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#pragma once
|
3 |
+
#include <torch/types.h>
|
4 |
+
|
5 |
+
namespace detectron2 {
|
6 |
+
|
7 |
+
#if defined(WITH_CUDA) || defined(WITH_HIP)
|
8 |
+
int deform_conv_forward_cuda(
|
9 |
+
at::Tensor input,
|
10 |
+
at::Tensor weight,
|
11 |
+
at::Tensor offset,
|
12 |
+
at::Tensor output,
|
13 |
+
at::Tensor columns,
|
14 |
+
at::Tensor ones,
|
15 |
+
int kW,
|
16 |
+
int kH,
|
17 |
+
int dW,
|
18 |
+
int dH,
|
19 |
+
int padW,
|
20 |
+
int padH,
|
21 |
+
int dilationW,
|
22 |
+
int dilationH,
|
23 |
+
int group,
|
24 |
+
int deformable_group,
|
25 |
+
int im2col_step);
|
26 |
+
|
27 |
+
int deform_conv_backward_input_cuda(
|
28 |
+
at::Tensor input,
|
29 |
+
at::Tensor offset,
|
30 |
+
at::Tensor gradOutput,
|
31 |
+
at::Tensor gradInput,
|
32 |
+
at::Tensor gradOffset,
|
33 |
+
at::Tensor weight,
|
34 |
+
at::Tensor columns,
|
35 |
+
int kW,
|
36 |
+
int kH,
|
37 |
+
int dW,
|
38 |
+
int dH,
|
39 |
+
int padW,
|
40 |
+
int padH,
|
41 |
+
int dilationW,
|
42 |
+
int dilationH,
|
43 |
+
int group,
|
44 |
+
int deformable_group,
|
45 |
+
int im2col_step);
|
46 |
+
|
47 |
+
int deform_conv_backward_parameters_cuda(
|
48 |
+
at::Tensor input,
|
49 |
+
at::Tensor offset,
|
50 |
+
at::Tensor gradOutput,
|
51 |
+
at::Tensor gradWeight, // at::Tensor gradBias,
|
52 |
+
at::Tensor columns,
|
53 |
+
at::Tensor ones,
|
54 |
+
int kW,
|
55 |
+
int kH,
|
56 |
+
int dW,
|
57 |
+
int dH,
|
58 |
+
int padW,
|
59 |
+
int padH,
|
60 |
+
int dilationW,
|
61 |
+
int dilationH,
|
62 |
+
int group,
|
63 |
+
int deformable_group,
|
64 |
+
float scale,
|
65 |
+
int im2col_step);
|
66 |
+
|
67 |
+
void modulated_deform_conv_cuda_forward(
|
68 |
+
at::Tensor input,
|
69 |
+
at::Tensor weight,
|
70 |
+
at::Tensor bias,
|
71 |
+
at::Tensor ones,
|
72 |
+
at::Tensor offset,
|
73 |
+
at::Tensor mask,
|
74 |
+
at::Tensor output,
|
75 |
+
at::Tensor columns,
|
76 |
+
int kernel_h,
|
77 |
+
int kernel_w,
|
78 |
+
const int stride_h,
|
79 |
+
const int stride_w,
|
80 |
+
const int pad_h,
|
81 |
+
const int pad_w,
|
82 |
+
const int dilation_h,
|
83 |
+
const int dilation_w,
|
84 |
+
const int group,
|
85 |
+
const int deformable_group,
|
86 |
+
const bool with_bias);
|
87 |
+
|
88 |
+
void modulated_deform_conv_cuda_backward(
|
89 |
+
at::Tensor input,
|
90 |
+
at::Tensor weight,
|
91 |
+
at::Tensor bias,
|
92 |
+
at::Tensor ones,
|
93 |
+
at::Tensor offset,
|
94 |
+
at::Tensor mask,
|
95 |
+
at::Tensor columns,
|
96 |
+
at::Tensor grad_input,
|
97 |
+
at::Tensor grad_weight,
|
98 |
+
at::Tensor grad_bias,
|
99 |
+
at::Tensor grad_offset,
|
100 |
+
at::Tensor grad_mask,
|
101 |
+
at::Tensor grad_output,
|
102 |
+
int kernel_h,
|
103 |
+
int kernel_w,
|
104 |
+
int stride_h,
|
105 |
+
int stride_w,
|
106 |
+
int pad_h,
|
107 |
+
int pad_w,
|
108 |
+
int dilation_h,
|
109 |
+
int dilation_w,
|
110 |
+
int group,
|
111 |
+
int deformable_group,
|
112 |
+
const bool with_bias);
|
113 |
+
|
114 |
+
#endif
|
115 |
+
|
116 |
+
inline int deform_conv_forward(
|
117 |
+
at::Tensor input,
|
118 |
+
at::Tensor weight,
|
119 |
+
at::Tensor offset,
|
120 |
+
at::Tensor output,
|
121 |
+
at::Tensor columns,
|
122 |
+
at::Tensor ones,
|
123 |
+
int kW,
|
124 |
+
int kH,
|
125 |
+
int dW,
|
126 |
+
int dH,
|
127 |
+
int padW,
|
128 |
+
int padH,
|
129 |
+
int dilationW,
|
130 |
+
int dilationH,
|
131 |
+
int group,
|
132 |
+
int deformable_group,
|
133 |
+
int im2col_step) {
|
134 |
+
if (input.is_cuda()) {
|
135 |
+
#if defined(WITH_CUDA) || defined(WITH_HIP)
|
136 |
+
TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
|
137 |
+
TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
|
138 |
+
return deform_conv_forward_cuda(
|
139 |
+
input,
|
140 |
+
weight,
|
141 |
+
offset,
|
142 |
+
output,
|
143 |
+
columns,
|
144 |
+
ones,
|
145 |
+
kW,
|
146 |
+
kH,
|
147 |
+
dW,
|
148 |
+
dH,
|
149 |
+
padW,
|
150 |
+
padH,
|
151 |
+
dilationW,
|
152 |
+
dilationH,
|
153 |
+
group,
|
154 |
+
deformable_group,
|
155 |
+
im2col_step);
|
156 |
+
#else
|
157 |
+
AT_ERROR("Detectron2 is not compiled with GPU support!");
|
158 |
+
#endif
|
159 |
+
}
|
160 |
+
AT_ERROR("This operator is not implemented on CPU");
|
161 |
+
}
|
162 |
+
|
163 |
+
inline int deform_conv_backward_input(
|
164 |
+
at::Tensor input,
|
165 |
+
at::Tensor offset,
|
166 |
+
at::Tensor gradOutput,
|
167 |
+
at::Tensor gradInput,
|
168 |
+
at::Tensor gradOffset,
|
169 |
+
at::Tensor weight,
|
170 |
+
at::Tensor columns,
|
171 |
+
int kW,
|
172 |
+
int kH,
|
173 |
+
int dW,
|
174 |
+
int dH,
|
175 |
+
int padW,
|
176 |
+
int padH,
|
177 |
+
int dilationW,
|
178 |
+
int dilationH,
|
179 |
+
int group,
|
180 |
+
int deformable_group,
|
181 |
+
int im2col_step) {
|
182 |
+
if (gradOutput.is_cuda()) {
|
183 |
+
#if defined(WITH_CUDA) || defined(WITH_HIP)
|
184 |
+
TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!");
|
185 |
+
TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
|
186 |
+
TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
|
187 |
+
return deform_conv_backward_input_cuda(
|
188 |
+
input,
|
189 |
+
offset,
|
190 |
+
gradOutput,
|
191 |
+
gradInput,
|
192 |
+
gradOffset,
|
193 |
+
weight,
|
194 |
+
columns,
|
195 |
+
kW,
|
196 |
+
kH,
|
197 |
+
dW,
|
198 |
+
dH,
|
199 |
+
padW,
|
200 |
+
padH,
|
201 |
+
dilationW,
|
202 |
+
dilationH,
|
203 |
+
group,
|
204 |
+
deformable_group,
|
205 |
+
im2col_step);
|
206 |
+
#else
|
207 |
+
AT_ERROR("Detectron2 is not compiled with GPU support!");
|
208 |
+
#endif
|
209 |
+
}
|
210 |
+
AT_ERROR("This operator is not implemented on CPU");
|
211 |
+
}
|
212 |
+
|
213 |
+
inline int deform_conv_backward_filter(
|
214 |
+
at::Tensor input,
|
215 |
+
at::Tensor offset,
|
216 |
+
at::Tensor gradOutput,
|
217 |
+
at::Tensor gradWeight, // at::Tensor gradBias,
|
218 |
+
at::Tensor columns,
|
219 |
+
at::Tensor ones,
|
220 |
+
int kW,
|
221 |
+
int kH,
|
222 |
+
int dW,
|
223 |
+
int dH,
|
224 |
+
int padW,
|
225 |
+
int padH,
|
226 |
+
int dilationW,
|
227 |
+
int dilationH,
|
228 |
+
int group,
|
229 |
+
int deformable_group,
|
230 |
+
float scale,
|
231 |
+
int im2col_step) {
|
232 |
+
if (gradOutput.is_cuda()) {
|
233 |
+
#if defined(WITH_CUDA) || defined(WITH_HIP)
|
234 |
+
TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!");
|
235 |
+
TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
|
236 |
+
return deform_conv_backward_parameters_cuda(
|
237 |
+
input,
|
238 |
+
offset,
|
239 |
+
gradOutput,
|
240 |
+
gradWeight,
|
241 |
+
columns,
|
242 |
+
ones,
|
243 |
+
kW,
|
244 |
+
kH,
|
245 |
+
dW,
|
246 |
+
dH,
|
247 |
+
padW,
|
248 |
+
padH,
|
249 |
+
dilationW,
|
250 |
+
dilationH,
|
251 |
+
group,
|
252 |
+
deformable_group,
|
253 |
+
scale,
|
254 |
+
im2col_step);
|
255 |
+
#else
|
256 |
+
AT_ERROR("Detectron2 is not compiled with GPU support!");
|
257 |
+
#endif
|
258 |
+
}
|
259 |
+
AT_ERROR("This operator is not implemented on CPU");
|
260 |
+
}
|
261 |
+
|
262 |
+
inline void modulated_deform_conv_forward(
|
263 |
+
at::Tensor input,
|
264 |
+
at::Tensor weight,
|
265 |
+
at::Tensor bias,
|
266 |
+
at::Tensor ones,
|
267 |
+
at::Tensor offset,
|
268 |
+
at::Tensor mask,
|
269 |
+
at::Tensor output,
|
270 |
+
at::Tensor columns,
|
271 |
+
int kernel_h,
|
272 |
+
int kernel_w,
|
273 |
+
const int stride_h,
|
274 |
+
const int stride_w,
|
275 |
+
const int pad_h,
|
276 |
+
const int pad_w,
|
277 |
+
const int dilation_h,
|
278 |
+
const int dilation_w,
|
279 |
+
const int group,
|
280 |
+
const int deformable_group,
|
281 |
+
const bool with_bias) {
|
282 |
+
if (input.is_cuda()) {
|
283 |
+
#if defined(WITH_CUDA) || defined(WITH_HIP)
|
284 |
+
TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
|
285 |
+
TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!");
|
286 |
+
TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
|
287 |
+
return modulated_deform_conv_cuda_forward(
|
288 |
+
input,
|
289 |
+
weight,
|
290 |
+
bias,
|
291 |
+
ones,
|
292 |
+
offset,
|
293 |
+
mask,
|
294 |
+
output,
|
295 |
+
columns,
|
296 |
+
kernel_h,
|
297 |
+
kernel_w,
|
298 |
+
stride_h,
|
299 |
+
stride_w,
|
300 |
+
pad_h,
|
301 |
+
pad_w,
|
302 |
+
dilation_h,
|
303 |
+
dilation_w,
|
304 |
+
group,
|
305 |
+
deformable_group,
|
306 |
+
with_bias);
|
307 |
+
#else
|
308 |
+
AT_ERROR("Detectron2 is not compiled with GPU support!");
|
309 |
+
#endif
|
310 |
+
}
|
311 |
+
AT_ERROR("This operator is not implemented on CPU");
|
312 |
+
}
|
313 |
+
|
314 |
+
inline void modulated_deform_conv_backward(
|
315 |
+
at::Tensor input,
|
316 |
+
at::Tensor weight,
|
317 |
+
at::Tensor bias,
|
318 |
+
at::Tensor ones,
|
319 |
+
at::Tensor offset,
|
320 |
+
at::Tensor mask,
|
321 |
+
at::Tensor columns,
|
322 |
+
at::Tensor grad_input,
|
323 |
+
at::Tensor grad_weight,
|
324 |
+
at::Tensor grad_bias,
|
325 |
+
at::Tensor grad_offset,
|
326 |
+
at::Tensor grad_mask,
|
327 |
+
at::Tensor grad_output,
|
328 |
+
int kernel_h,
|
329 |
+
int kernel_w,
|
330 |
+
int stride_h,
|
331 |
+
int stride_w,
|
332 |
+
int pad_h,
|
333 |
+
int pad_w,
|
334 |
+
int dilation_h,
|
335 |
+
int dilation_w,
|
336 |
+
int group,
|
337 |
+
int deformable_group,
|
338 |
+
const bool with_bias) {
|
339 |
+
if (grad_output.is_cuda()) {
|
340 |
+
#if defined(WITH_CUDA) || defined(WITH_HIP)
|
341 |
+
TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!");
|
342 |
+
TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
|
343 |
+
TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!");
|
344 |
+
TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
|
345 |
+
return modulated_deform_conv_cuda_backward(
|
346 |
+
input,
|
347 |
+
weight,
|
348 |
+
bias,
|
349 |
+
ones,
|
350 |
+
offset,
|
351 |
+
mask,
|
352 |
+
columns,
|
353 |
+
grad_input,
|
354 |
+
grad_weight,
|
355 |
+
grad_bias,
|
356 |
+
grad_offset,
|
357 |
+
grad_mask,
|
358 |
+
grad_output,
|
359 |
+
kernel_h,
|
360 |
+
kernel_w,
|
361 |
+
stride_h,
|
362 |
+
stride_w,
|
363 |
+
pad_h,
|
364 |
+
pad_w,
|
365 |
+
dilation_h,
|
366 |
+
dilation_w,
|
367 |
+
group,
|
368 |
+
deformable_group,
|
369 |
+
with_bias);
|
370 |
+
#else
|
371 |
+
AT_ERROR("Detectron2 is not compiled with GPU support!");
|
372 |
+
#endif
|
373 |
+
}
|
374 |
+
AT_ERROR("This operator is not implemented on CPU");
|
375 |
+
}
|
376 |
+
|
377 |
+
} // namespace detectron2
|
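The inline wrappers in this header all share the same guard-and-dispatch shape; reduced to a single hypothetical one-argument operator (the names below are made up, only the pattern is taken from the code above) it looks like this:

#include <torch/types.h>

// Hypothetical operator, shown only to isolate the dispatch pattern.
int my_op_cuda(at::Tensor input); // defined in a .cu file when GPU support is built

inline int my_op(at::Tensor input) {
  if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
    // Forward to the CUDA/HIP implementation compiled into the extension.
    return my_op_cuda(input);
#else
    AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
  }
  AT_ERROR("This operator is not implemented on CPU");
}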
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/deformable/deform_conv_cuda.cu
ADDED
@@ -0,0 +1,1223 @@
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
// modified from
|
4 |
+
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
|
5 |
+
// Original license: Apache 2.0
|
6 |
+
|
7 |
+
// modify from
|
8 |
+
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
9 |
+
// Original license: Apache 2.0
|
10 |
+
|
11 |
+
#include <torch/types.h>
|
12 |
+
|
13 |
+
#include "deform_conv.h"
|
14 |
+
|
15 |
+
#include <cmath>
|
16 |
+
#include <vector>
|
17 |
+
|
18 |
+
namespace detectron2 {
|
19 |
+
|
20 |
+
void deformable_im2col(
|
21 |
+
const at::Tensor data_im,
|
22 |
+
const at::Tensor data_offset,
|
23 |
+
const int channels,
|
24 |
+
const int height,
|
25 |
+
const int width,
|
26 |
+
const int ksize_h,
|
27 |
+
const int ksize_w,
|
28 |
+
const int pad_h,
|
29 |
+
const int pad_w,
|
30 |
+
const int stride_h,
|
31 |
+
const int stride_w,
|
32 |
+
const int dilation_h,
|
33 |
+
const int dilation_w,
|
34 |
+
const int parallel_imgs,
|
35 |
+
const int deformable_group,
|
36 |
+
at::Tensor data_col);
|
37 |
+
|
38 |
+
void deformable_col2im(
|
39 |
+
const at::Tensor data_col,
|
40 |
+
const at::Tensor data_offset,
|
41 |
+
const int channels,
|
42 |
+
const int height,
|
43 |
+
const int width,
|
44 |
+
const int ksize_h,
|
45 |
+
const int ksize_w,
|
46 |
+
const int pad_h,
|
47 |
+
const int pad_w,
|
48 |
+
const int stride_h,
|
49 |
+
const int stride_w,
|
50 |
+
const int dilation_h,
|
51 |
+
const int dilation_w,
|
52 |
+
const int parallel_imgs,
|
53 |
+
const int deformable_group,
|
54 |
+
at::Tensor grad_im);
|
55 |
+
|
56 |
+
void deformable_col2im_coord(
|
57 |
+
const at::Tensor data_col,
|
58 |
+
const at::Tensor data_im,
|
59 |
+
const at::Tensor data_offset,
|
60 |
+
const int channels,
|
61 |
+
const int height,
|
62 |
+
const int width,
|
63 |
+
const int ksize_h,
|
64 |
+
const int ksize_w,
|
65 |
+
const int pad_h,
|
66 |
+
const int pad_w,
|
67 |
+
const int stride_h,
|
68 |
+
const int stride_w,
|
69 |
+
const int dilation_h,
|
70 |
+
const int dilation_w,
|
71 |
+
const int parallel_imgs,
|
72 |
+
const int deformable_group,
|
73 |
+
at::Tensor grad_offset);
|
74 |
+
|
75 |
+
void modulated_deformable_im2col_cuda(
|
76 |
+
const at::Tensor data_im,
|
77 |
+
const at::Tensor data_offset,
|
78 |
+
const at::Tensor data_mask,
|
79 |
+
const int batch_size,
|
80 |
+
const int channels,
|
81 |
+
const int height_im,
|
82 |
+
const int width_im,
|
83 |
+
const int height_col,
|
84 |
+
const int width_col,
|
85 |
+
const int kernel_h,
|
86 |
+
const int kenerl_w,
|
87 |
+
const int pad_h,
|
88 |
+
const int pad_w,
|
89 |
+
const int stride_h,
|
90 |
+
const int stride_w,
|
91 |
+
const int dilation_h,
|
92 |
+
const int dilation_w,
|
93 |
+
const int deformable_group,
|
94 |
+
at::Tensor data_col);
|
95 |
+
|
96 |
+
void modulated_deformable_col2im_cuda(
|
97 |
+
const at::Tensor data_col,
|
98 |
+
const at::Tensor data_offset,
|
99 |
+
const at::Tensor data_mask,
|
100 |
+
const int batch_size,
|
101 |
+
const int channels,
|
102 |
+
const int height_im,
|
103 |
+
const int width_im,
|
104 |
+
const int height_col,
|
105 |
+
const int width_col,
|
106 |
+
const int kernel_h,
|
107 |
+
const int kenerl_w,
|
108 |
+
const int pad_h,
|
109 |
+
const int pad_w,
|
110 |
+
const int stride_h,
|
111 |
+
const int stride_w,
|
112 |
+
const int dilation_h,
|
113 |
+
const int dilation_w,
|
114 |
+
const int deformable_group,
|
115 |
+
at::Tensor grad_im);
|
116 |
+
|
117 |
+
void modulated_deformable_col2im_coord_cuda(
|
118 |
+
const at::Tensor data_col,
|
119 |
+
const at::Tensor data_im,
|
120 |
+
const at::Tensor data_offset,
|
121 |
+
const at::Tensor data_mask,
|
122 |
+
const int batch_size,
|
123 |
+
const int channels,
|
124 |
+
const int height_im,
|
125 |
+
const int width_im,
|
126 |
+
const int height_col,
|
127 |
+
const int width_col,
|
128 |
+
const int kernel_h,
|
129 |
+
const int kenerl_w,
|
130 |
+
const int pad_h,
|
131 |
+
const int pad_w,
|
132 |
+
const int stride_h,
|
133 |
+
const int stride_w,
|
134 |
+
const int dilation_h,
|
135 |
+
const int dilation_w,
|
136 |
+
const int deformable_group,
|
137 |
+
at::Tensor grad_offset,
|
138 |
+
at::Tensor grad_mask);
|
139 |
+
|
140 |
+
void shape_check(
|
141 |
+
at::Tensor input,
|
142 |
+
at::Tensor offset,
|
143 |
+
at::Tensor* gradOutput,
|
144 |
+
at::Tensor weight,
|
145 |
+
int kH,
|
146 |
+
int kW,
|
147 |
+
int dH,
|
148 |
+
int dW,
|
149 |
+
int padH,
|
150 |
+
int padW,
|
151 |
+
int dilationH,
|
152 |
+
int dilationW,
|
153 |
+
int group,
|
154 |
+
int deformable_group) {
|
155 |
+
TORCH_CHECK(
|
156 |
+
weight.ndimension() == 4,
|
157 |
+
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
|
158 |
+
"but got: %s",
|
159 |
+
weight.ndimension());
|
160 |
+
|
161 |
+
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
162 |
+
|
163 |
+
TORCH_CHECK(
|
164 |
+
kW > 0 && kH > 0,
|
165 |
+
"kernel size should be greater than zero, but got kH: %d kW: %d",
|
166 |
+
kH,
|
167 |
+
kW);
|
168 |
+
|
169 |
+
TORCH_CHECK(
|
170 |
+
(weight.size(2) == kH && weight.size(3) == kW),
|
171 |
+
"kernel size should be consistent with weight, ",
|
172 |
+
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d",
|
173 |
+
kH,
|
174 |
+
kW,
|
175 |
+
weight.size(2),
|
176 |
+
weight.size(3));
|
177 |
+
|
178 |
+
TORCH_CHECK(
|
179 |
+
dW > 0 && dH > 0,
|
180 |
+
"stride should be greater than zero, but got dH: %d dW: %d",
|
181 |
+
dH,
|
182 |
+
dW);
|
183 |
+
|
184 |
+
TORCH_CHECK(
|
185 |
+
dilationW > 0 && dilationH > 0,
|
186 |
+
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
|
187 |
+
dilationH,
|
188 |
+
dilationW);
|
189 |
+
|
190 |
+
int ndim = input.ndimension();
|
191 |
+
int dimf = 0;
|
192 |
+
int dimh = 1;
|
193 |
+
int dimw = 2;
|
194 |
+
|
195 |
+
if (ndim == 4) {
|
196 |
+
dimf++;
|
197 |
+
dimh++;
|
198 |
+
dimw++;
|
199 |
+
}
|
200 |
+
|
201 |
+
TORCH_CHECK(
|
202 |
+
ndim == 3 || ndim == 4,
|
203 |
+
"3D or 4D input tensor expected but got: %s",
|
204 |
+
ndim);
|
205 |
+
|
206 |
+
long nInputPlane = weight.size(1) * group;
|
207 |
+
long inputHeight = input.size(dimh);
|
208 |
+
long inputWidth = input.size(dimw);
|
209 |
+
long nOutputPlane = weight.size(0);
|
210 |
+
long outputHeight =
|
211 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
212 |
+
long outputWidth =
|
213 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
214 |
+
|
215 |
+
TORCH_CHECK(
|
216 |
+
nInputPlane % deformable_group == 0,
|
217 |
+
"input channels must divide deformable group size");
|
218 |
+
|
219 |
+
if (outputWidth < 1 || outputHeight < 1)
|
220 |
+
AT_ERROR(
|
221 |
+
"Given input size: (%ld x %ld x %ld). "
|
222 |
+
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
|
223 |
+
nInputPlane,
|
224 |
+
inputHeight,
|
225 |
+
inputWidth,
|
226 |
+
nOutputPlane,
|
227 |
+
outputHeight,
|
228 |
+
outputWidth);
|
229 |
+
|
230 |
+
TORCH_CHECK(
|
231 |
+
input.size(1) == nInputPlane,
|
232 |
+
"invalid number of input planes, expected: %d, but got: %d",
|
233 |
+
nInputPlane,
|
234 |
+
input.size(1));
|
235 |
+
|
236 |
+
TORCH_CHECK(
|
237 |
+
(inputHeight + 2 * padH >= kH && inputWidth + 2 * padW >= kW),
|
238 |
+
"input image is smaller than kernel");
|
239 |
+
|
240 |
+
TORCH_CHECK(
|
241 |
+
(offset.size(2) == outputHeight && offset.size(3) == outputWidth),
|
242 |
+
"invalid spatial size of offset, expected height: %d width: %d, but "
|
243 |
+
"got height: %d width: %d",
|
244 |
+
outputHeight,
|
245 |
+
outputWidth,
|
246 |
+
offset.size(2),
|
247 |
+
offset.size(3));
|
248 |
+
|
249 |
+
TORCH_CHECK(
|
250 |
+
(offset.size(1) == deformable_group * 2 * kH * kW),
|
251 |
+
"invalid number of channels of offset");
|
252 |
+
|
253 |
+
if (gradOutput != NULL) {
|
254 |
+
TORCH_CHECK(
|
255 |
+
gradOutput->size(dimf) == nOutputPlane,
|
256 |
+
"invalid number of gradOutput planes, expected: %d, but got: %d",
|
257 |
+
nOutputPlane,
|
258 |
+
gradOutput->size(dimf));
|
259 |
+
|
260 |
+
TORCH_CHECK(
|
261 |
+
(gradOutput->size(dimh) == outputHeight &&
|
262 |
+
gradOutput->size(dimw) == outputWidth),
|
263 |
+
"invalid size of gradOutput, expected height: %d width: %d , but "
|
264 |
+
"got height: %d width: %d",
|
265 |
+
outputHeight,
|
266 |
+
outputWidth,
|
267 |
+
gradOutput->size(dimh),
|
268 |
+
gradOutput->size(dimw));
|
269 |
+
}
|
270 |
+
}
|
271 |
+
|
272 |
+
int deform_conv_forward_cuda(
|
273 |
+
at::Tensor input,
|
274 |
+
at::Tensor weight,
|
275 |
+
at::Tensor offset,
|
276 |
+
at::Tensor output,
|
277 |
+
at::Tensor columns,
|
278 |
+
at::Tensor ones,
|
279 |
+
int kW,
|
280 |
+
int kH,
|
281 |
+
int dW,
|
282 |
+
int dH,
|
283 |
+
int padW,
|
284 |
+
int padH,
|
285 |
+
int dilationW,
|
286 |
+
int dilationH,
|
287 |
+
int group,
|
288 |
+
int deformable_group,
|
289 |
+
int im2col_step) {
|
290 |
+
// todo: resize columns to include im2col: done
|
291 |
+
// todo: add im2col_step as input
|
292 |
+
// todo: add new output buffer and transpose it to output (or directly
|
293 |
+
// transpose output) todo: possibly change data indexing because of
|
294 |
+
// parallel_imgs
|
295 |
+
|
296 |
+
shape_check(
|
297 |
+
input,
|
298 |
+
offset,
|
299 |
+
NULL,
|
300 |
+
weight,
|
301 |
+
kH,
|
302 |
+
kW,
|
303 |
+
dH,
|
304 |
+
dW,
|
305 |
+
padH,
|
306 |
+
padW,
|
307 |
+
dilationH,
|
308 |
+
dilationW,
|
309 |
+
group,
|
310 |
+
deformable_group);
|
311 |
+
|
312 |
+
input = input.contiguous();
|
313 |
+
offset = offset.contiguous();
|
314 |
+
weight = weight.contiguous();
|
315 |
+
|
316 |
+
int batch = 1;
|
317 |
+
if (input.ndimension() == 3) {
|
318 |
+
// Force batch
|
319 |
+
batch = 0;
|
320 |
+
input.unsqueeze_(0);
|
321 |
+
offset.unsqueeze_(0);
|
322 |
+
}
|
323 |
+
|
324 |
+
// todo: assert batchsize dividable by im2col_step
|
325 |
+
|
326 |
+
long batchSize = input.size(0);
|
327 |
+
long nInputPlane = input.size(1);
|
328 |
+
long inputHeight = input.size(2);
|
329 |
+
long inputWidth = input.size(3);
|
330 |
+
|
331 |
+
long nOutputPlane = weight.size(0);
|
332 |
+
|
333 |
+
long outputWidth =
|
334 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
335 |
+
long outputHeight =
|
336 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
337 |
+
|
338 |
+
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
339 |
+
|
340 |
+
output = output.view(
|
341 |
+
{batchSize / im2col_step,
|
342 |
+
im2col_step,
|
343 |
+
nOutputPlane,
|
344 |
+
outputHeight,
|
345 |
+
outputWidth});
|
346 |
+
columns = at::zeros(
|
347 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
348 |
+
input.options());
|
349 |
+
|
350 |
+
if (ones.ndimension() != 2 ||
|
351 |
+
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
|
352 |
+
ones = at::ones({outputHeight, outputWidth}, input.options());
|
353 |
+
}
|
354 |
+
|
355 |
+
input = input.view(
|
356 |
+
{batchSize / im2col_step,
|
357 |
+
im2col_step,
|
358 |
+
nInputPlane,
|
359 |
+
inputHeight,
|
360 |
+
inputWidth});
|
361 |
+
offset = offset.view(
|
362 |
+
{batchSize / im2col_step,
|
363 |
+
im2col_step,
|
364 |
+
deformable_group * 2 * kH * kW,
|
365 |
+
outputHeight,
|
366 |
+
outputWidth});
|
367 |
+
|
368 |
+
at::Tensor output_buffer = at::zeros(
|
369 |
+
{batchSize / im2col_step,
|
370 |
+
nOutputPlane,
|
371 |
+
im2col_step * outputHeight,
|
372 |
+
outputWidth},
|
373 |
+
output.options());
|
374 |
+
|
375 |
+
output_buffer = output_buffer.view(
|
376 |
+
{output_buffer.size(0),
|
377 |
+
group,
|
378 |
+
output_buffer.size(1) / group,
|
379 |
+
output_buffer.size(2),
|
380 |
+
output_buffer.size(3)});
|
381 |
+
|
382 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
383 |
+
deformable_im2col(
|
384 |
+
input[elt],
|
385 |
+
offset[elt],
|
386 |
+
nInputPlane,
|
387 |
+
inputHeight,
|
388 |
+
inputWidth,
|
389 |
+
kH,
|
390 |
+
kW,
|
391 |
+
padH,
|
392 |
+
padW,
|
393 |
+
dH,
|
394 |
+
dW,
|
395 |
+
dilationH,
|
396 |
+
dilationW,
|
397 |
+
im2col_step,
|
398 |
+
deformable_group,
|
399 |
+
columns);
|
400 |
+
|
401 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
402 |
+
weight = weight.view(
|
403 |
+
{group,
|
404 |
+
weight.size(0) / group,
|
405 |
+
weight.size(1),
|
406 |
+
weight.size(2),
|
407 |
+
weight.size(3)});
|
408 |
+
|
409 |
+
for (int g = 0; g < group; g++) {
|
410 |
+
output_buffer[elt][g] = output_buffer[elt][g]
|
411 |
+
.flatten(1)
|
412 |
+
.addmm_(weight[g].flatten(1), columns[g])
|
413 |
+
.view_as(output_buffer[elt][g]);
|
414 |
+
}
|
415 |
+
}
|
416 |
+
|
417 |
+
output_buffer = output_buffer.view(
|
418 |
+
{output_buffer.size(0),
|
419 |
+
output_buffer.size(1) * output_buffer.size(2),
|
420 |
+
output_buffer.size(3),
|
421 |
+
output_buffer.size(4)});
|
422 |
+
|
423 |
+
output_buffer = output_buffer.view(
|
424 |
+
{batchSize / im2col_step,
|
425 |
+
nOutputPlane,
|
426 |
+
im2col_step,
|
427 |
+
outputHeight,
|
428 |
+
outputWidth});
|
429 |
+
output_buffer.transpose_(1, 2);
|
430 |
+
output.copy_(output_buffer);
|
431 |
+
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
432 |
+
|
433 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
434 |
+
offset = offset.view(
|
435 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
436 |
+
|
437 |
+
if (batch == 0) {
|
438 |
+
output = output.view({nOutputPlane, outputHeight, outputWidth});
|
439 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
440 |
+
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
441 |
+
}
|
442 |
+
|
443 |
+
return 1;
|
444 |
+
}
|
445 |
+
|
446 |
+
int deform_conv_backward_input_cuda(
|
447 |
+
at::Tensor input,
|
448 |
+
at::Tensor offset,
|
449 |
+
at::Tensor gradOutput,
|
450 |
+
at::Tensor gradInput,
|
451 |
+
at::Tensor gradOffset,
|
452 |
+
at::Tensor weight,
|
453 |
+
at::Tensor columns,
|
454 |
+
int kW,
|
455 |
+
int kH,
|
456 |
+
int dW,
|
457 |
+
int dH,
|
458 |
+
int padW,
|
459 |
+
int padH,
|
460 |
+
int dilationW,
|
461 |
+
int dilationH,
|
462 |
+
int group,
|
463 |
+
int deformable_group,
|
464 |
+
int im2col_step) {
|
465 |
+
shape_check(
|
466 |
+
input,
|
467 |
+
offset,
|
468 |
+
&gradOutput,
|
469 |
+
weight,
|
470 |
+
kH,
|
471 |
+
kW,
|
472 |
+
dH,
|
473 |
+
dW,
|
474 |
+
padH,
|
475 |
+
padW,
|
476 |
+
dilationH,
|
477 |
+
dilationW,
|
478 |
+
group,
|
479 |
+
deformable_group);
|
480 |
+
|
481 |
+
input = input.contiguous();
|
482 |
+
offset = offset.contiguous();
|
483 |
+
gradOutput = gradOutput.contiguous();
|
484 |
+
weight = weight.contiguous();
|
485 |
+
|
486 |
+
int batch = 1;
|
487 |
+
|
488 |
+
if (input.ndimension() == 3) {
|
489 |
+
// Force batch
|
490 |
+
batch = 0;
|
491 |
+
input = input.view({1, input.size(0), input.size(1), input.size(2)});
|
492 |
+
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
|
493 |
+
gradOutput = gradOutput.view(
|
494 |
+
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
495 |
+
}
|
496 |
+
|
497 |
+
long batchSize = input.size(0);
|
498 |
+
long nInputPlane = input.size(1);
|
499 |
+
long inputHeight = input.size(2);
|
500 |
+
long inputWidth = input.size(3);
|
501 |
+
|
502 |
+
long nOutputPlane = weight.size(0);
|
503 |
+
|
504 |
+
long outputWidth =
|
505 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
506 |
+
long outputHeight =
|
507 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
508 |
+
|
509 |
+
TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
|
510 |
+
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
511 |
+
columns = at::zeros(
|
512 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
513 |
+
input.options());
|
514 |
+
|
515 |
+
// change order of grad output
|
516 |
+
gradOutput = gradOutput.view(
|
517 |
+
{batchSize / im2col_step,
|
518 |
+
im2col_step,
|
519 |
+
nOutputPlane,
|
520 |
+
outputHeight,
|
521 |
+
outputWidth});
|
522 |
+
gradOutput.transpose_(1, 2);
|
523 |
+
|
524 |
+
gradInput = gradInput.view(
|
525 |
+
{batchSize / im2col_step,
|
526 |
+
im2col_step,
|
527 |
+
nInputPlane,
|
528 |
+
inputHeight,
|
529 |
+
inputWidth});
|
530 |
+
input = input.view(
|
531 |
+
{batchSize / im2col_step,
|
532 |
+
im2col_step,
|
533 |
+
nInputPlane,
|
534 |
+
inputHeight,
|
535 |
+
inputWidth});
|
536 |
+
gradOffset = gradOffset.view(
|
537 |
+
{batchSize / im2col_step,
|
538 |
+
im2col_step,
|
539 |
+
deformable_group * 2 * kH * kW,
|
540 |
+
outputHeight,
|
541 |
+
outputWidth});
|
542 |
+
offset = offset.view(
|
543 |
+
{batchSize / im2col_step,
|
544 |
+
im2col_step,
|
545 |
+
deformable_group * 2 * kH * kW,
|
546 |
+
outputHeight,
|
547 |
+
outputWidth});
|
548 |
+
|
549 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
550 |
+
// divide into groups
|
551 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
552 |
+
weight = weight.view(
|
553 |
+
{group,
|
554 |
+
weight.size(0) / group,
|
555 |
+
weight.size(1),
|
556 |
+
weight.size(2),
|
557 |
+
weight.size(3)});
|
558 |
+
gradOutput = gradOutput.view(
|
559 |
+
{gradOutput.size(0),
|
560 |
+
group,
|
561 |
+
gradOutput.size(1) / group,
|
562 |
+
gradOutput.size(2),
|
563 |
+
gradOutput.size(3),
|
564 |
+
gradOutput.size(4)});
|
565 |
+
|
566 |
+
for (int g = 0; g < group; g++) {
|
567 |
+
columns[g] = columns[g].addmm_(
|
568 |
+
weight[g].flatten(1).transpose(0, 1),
|
569 |
+
gradOutput[elt][g].flatten(1),
|
570 |
+
0.0f,
|
571 |
+
1.0f);
|
572 |
+
}
|
573 |
+
|
574 |
+
columns =
|
575 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
576 |
+
gradOutput = gradOutput.view(
|
577 |
+
{gradOutput.size(0),
|
578 |
+
gradOutput.size(1) * gradOutput.size(2),
|
579 |
+
gradOutput.size(3),
|
580 |
+
gradOutput.size(4),
|
581 |
+
gradOutput.size(5)});
|
582 |
+
|
583 |
+
deformable_col2im_coord(
|
584 |
+
columns,
|
585 |
+
input[elt],
|
586 |
+
offset[elt],
|
587 |
+
nInputPlane,
|
588 |
+
inputHeight,
|
589 |
+
inputWidth,
|
590 |
+
kH,
|
591 |
+
kW,
|
592 |
+
padH,
|
593 |
+
padW,
|
594 |
+
dH,
|
595 |
+
dW,
|
596 |
+
dilationH,
|
597 |
+
dilationW,
|
598 |
+
im2col_step,
|
599 |
+
deformable_group,
|
600 |
+
gradOffset[elt]);
|
601 |
+
|
602 |
+
deformable_col2im(
|
603 |
+
columns,
|
604 |
+
offset[elt],
|
605 |
+
nInputPlane,
|
606 |
+
inputHeight,
|
607 |
+
inputWidth,
|
608 |
+
kH,
|
609 |
+
kW,
|
610 |
+
padH,
|
611 |
+
padW,
|
612 |
+
dH,
|
613 |
+
dW,
|
614 |
+
dilationH,
|
615 |
+
dilationW,
|
616 |
+
im2col_step,
|
617 |
+
deformable_group,
|
618 |
+
gradInput[elt]);
|
619 |
+
}
|
620 |
+
|
621 |
+
gradOutput.transpose_(1, 2);
|
622 |
+
gradOutput =
|
623 |
+
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
624 |
+
|
625 |
+
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
626 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
627 |
+
gradOffset = gradOffset.view(
|
628 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
629 |
+
offset = offset.view(
|
630 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
631 |
+
|
632 |
+
if (batch == 0) {
|
633 |
+
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
634 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
635 |
+
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
|
636 |
+
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
637 |
+
gradOffset =
|
638 |
+
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
|
639 |
+
}
|
640 |
+
|
641 |
+
return 1;
|
642 |
+
}
|
643 |
+
|
644 |
+
int deform_conv_backward_parameters_cuda(
|
645 |
+
at::Tensor input,
|
646 |
+
at::Tensor offset,
|
647 |
+
at::Tensor gradOutput,
|
648 |
+
at::Tensor gradWeight, // at::Tensor gradBias,
|
649 |
+
at::Tensor columns,
|
650 |
+
at::Tensor ones,
|
651 |
+
int kW,
|
652 |
+
int kH,
|
653 |
+
int dW,
|
654 |
+
int dH,
|
655 |
+
int padW,
|
656 |
+
int padH,
|
657 |
+
int dilationW,
|
658 |
+
int dilationH,
|
659 |
+
int group,
|
660 |
+
int deformable_group,
|
661 |
+
float scale,
|
662 |
+
int im2col_step) {
|
663 |
+
// todo: transpose and reshape outGrad
|
664 |
+
// todo: reshape columns
|
665 |
+
// todo: add im2col_step as input
|
666 |
+
|
667 |
+
shape_check(
|
668 |
+
input,
|
669 |
+
offset,
|
670 |
+
&gradOutput,
|
671 |
+
gradWeight,
|
672 |
+
kH,
|
673 |
+
kW,
|
674 |
+
dH,
|
675 |
+
dW,
|
676 |
+
padH,
|
677 |
+
padW,
|
678 |
+
dilationH,
|
679 |
+
dilationW,
|
680 |
+
group,
|
681 |
+
deformable_group);
|
682 |
+
|
683 |
+
input = input.contiguous();
|
684 |
+
offset = offset.contiguous();
|
685 |
+
gradOutput = gradOutput.contiguous();
|
686 |
+
|
687 |
+
int batch = 1;
|
688 |
+
|
689 |
+
if (input.ndimension() == 3) {
|
690 |
+
// Force batch
|
691 |
+
batch = 0;
|
692 |
+
input = input.view(
|
693 |
+
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
|
694 |
+
gradOutput = gradOutput.view(
|
695 |
+
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
696 |
+
}
|
697 |
+
|
698 |
+
long batchSize = input.size(0);
|
699 |
+
long nInputPlane = input.size(1);
|
700 |
+
long inputHeight = input.size(2);
|
701 |
+
long inputWidth = input.size(3);
|
702 |
+
|
703 |
+
long nOutputPlane = gradWeight.size(0);
|
704 |
+
|
705 |
+
long outputWidth =
|
706 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
707 |
+
long outputHeight =
|
708 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
709 |
+
|
710 |
+
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
711 |
+
|
712 |
+
columns = at::zeros(
|
713 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
714 |
+
input.options());
|
715 |
+
|
716 |
+
gradOutput = gradOutput.view(
|
717 |
+
{batchSize / im2col_step,
|
718 |
+
im2col_step,
|
719 |
+
nOutputPlane,
|
720 |
+
outputHeight,
|
721 |
+
outputWidth});
|
722 |
+
gradOutput.transpose_(1, 2);
|
723 |
+
|
724 |
+
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
|
725 |
+
gradOutputBuffer = gradOutputBuffer.view(
|
726 |
+
{batchSize / im2col_step,
|
727 |
+
nOutputPlane,
|
728 |
+
im2col_step,
|
729 |
+
outputHeight,
|
730 |
+
outputWidth});
|
731 |
+
gradOutputBuffer.copy_(gradOutput);
|
732 |
+
// gradOutput is not contiguous, so we do reshape (instead of view) next
|
733 |
+
gradOutputBuffer = gradOutputBuffer.reshape(
|
734 |
+
{batchSize / im2col_step,
|
735 |
+
nOutputPlane,
|
736 |
+
im2col_step * outputHeight,
|
737 |
+
outputWidth});
|
738 |
+
|
739 |
+
gradOutput.transpose_(1, 2);
|
740 |
+
gradOutput =
|
741 |
+
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
742 |
+
|
743 |
+
input = input.view(
|
744 |
+
{batchSize / im2col_step,
|
745 |
+
im2col_step,
|
746 |
+
nInputPlane,
|
747 |
+
inputHeight,
|
748 |
+
inputWidth});
|
749 |
+
offset = offset.view(
|
750 |
+
{batchSize / im2col_step,
|
751 |
+
im2col_step,
|
752 |
+
deformable_group * 2 * kH * kW,
|
753 |
+
outputHeight,
|
754 |
+
outputWidth});
|
755 |
+
|
756 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
757 |
+
deformable_im2col(
|
758 |
+
input[elt],
|
759 |
+
offset[elt],
|
760 |
+
nInputPlane,
|
761 |
+
inputHeight,
|
762 |
+
inputWidth,
|
763 |
+
kH,
|
764 |
+
kW,
|
765 |
+
padH,
|
766 |
+
padW,
|
767 |
+
dH,
|
768 |
+
dW,
|
769 |
+
dilationH,
|
770 |
+
dilationW,
|
771 |
+
im2col_step,
|
772 |
+
deformable_group,
|
773 |
+
columns);
|
774 |
+
|
775 |
+
// divide into group
|
776 |
+
gradOutputBuffer = gradOutputBuffer.view(
|
777 |
+
{gradOutputBuffer.size(0),
|
778 |
+
group,
|
779 |
+
gradOutputBuffer.size(1) / group,
|
780 |
+
gradOutputBuffer.size(2),
|
781 |
+
gradOutputBuffer.size(3)});
|
782 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
783 |
+
gradWeight = gradWeight.view(
|
784 |
+
{group,
|
785 |
+
gradWeight.size(0) / group,
|
786 |
+
gradWeight.size(1),
|
787 |
+
gradWeight.size(2),
|
788 |
+
gradWeight.size(3)});
|
789 |
+
|
790 |
+
for (int g = 0; g < group; g++) {
|
791 |
+
gradWeight[g] = gradWeight[g]
|
792 |
+
.flatten(1)
|
793 |
+
.addmm_(
|
794 |
+
gradOutputBuffer[elt][g].flatten(1),
|
795 |
+
columns[g].transpose(1, 0),
|
796 |
+
1.0,
|
797 |
+
scale)
|
798 |
+
.view_as(gradWeight[g]);
|
799 |
+
}
|
800 |
+
gradOutputBuffer = gradOutputBuffer.view(
|
801 |
+
{gradOutputBuffer.size(0),
|
802 |
+
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
|
803 |
+
gradOutputBuffer.size(3),
|
804 |
+
gradOutputBuffer.size(4)});
|
805 |
+
columns =
|
806 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
807 |
+
gradWeight = gradWeight.view(
|
808 |
+
{gradWeight.size(0) * gradWeight.size(1),
|
809 |
+
gradWeight.size(2),
|
810 |
+
gradWeight.size(3),
|
811 |
+
gradWeight.size(4)});
|
812 |
+
}
|
813 |
+
|
814 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
815 |
+
offset = offset.view(
|
816 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
817 |
+
|
818 |
+
if (batch == 0) {
|
819 |
+
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
820 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
821 |
+
}
|
822 |
+
|
823 |
+
return 1;
|
824 |
+
}
|
825 |
+
|
826 |
+
void modulated_deform_conv_cuda_forward(
|
827 |
+
at::Tensor input,
|
828 |
+
at::Tensor weight,
|
829 |
+
at::Tensor bias,
|
830 |
+
at::Tensor ones,
|
831 |
+
at::Tensor offset,
|
832 |
+
at::Tensor mask,
|
833 |
+
at::Tensor output,
|
834 |
+
at::Tensor columns,
|
835 |
+
int kernel_h,
|
836 |
+
int kernel_w,
|
837 |
+
const int stride_h,
|
838 |
+
const int stride_w,
|
839 |
+
const int pad_h,
|
840 |
+
const int pad_w,
|
841 |
+
const int dilation_h,
|
842 |
+
const int dilation_w,
|
843 |
+
const int group,
|
844 |
+
const int deformable_group,
|
845 |
+
const bool with_bias) {
|
846 |
+
shape_check(
|
847 |
+
input,
|
848 |
+
offset,
|
849 |
+
NULL,
|
850 |
+
weight,
|
851 |
+
kernel_h,
|
852 |
+
kernel_w,
|
853 |
+
stride_h,
|
854 |
+
stride_w,
|
855 |
+
pad_h,
|
856 |
+
pad_w,
|
857 |
+
dilation_h,
|
858 |
+
dilation_w,
|
859 |
+
group,
|
860 |
+
deformable_group);
|
861 |
+
|
862 |
+
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
863 |
+
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
864 |
+
|
865 |
+
const int batch = input.size(0);
|
866 |
+
const int channels = input.size(1);
|
867 |
+
const int height = input.size(2);
|
868 |
+
const int width = input.size(3);
|
869 |
+
|
870 |
+
const int channels_out = weight.size(0);
|
871 |
+
const int channels_kernel = weight.size(1);
|
872 |
+
const int kernel_h_ = weight.size(2);
|
873 |
+
const int kernel_w_ = weight.size(3);
|
874 |
+
|
875 |
+
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
876 |
+
AT_ERROR(
|
877 |
+
"Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
878 |
+
kernel_h_,
|
879 |
+
kernel_w,
|
880 |
+
kernel_h_,
|
881 |
+
kernel_w_);
|
882 |
+
if (channels != channels_kernel * group)
|
883 |
+
AT_ERROR(
|
884 |
+
"Input shape and kernel channels wont match: (%d vs %d).",
|
885 |
+
channels,
|
886 |
+
channels_kernel * group);
|
887 |
+
|
888 |
+
const int height_out =
|
889 |
+
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
890 |
+
const int width_out =
|
891 |
+
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
892 |
+
|
893 |
+
// mask shape check
|
894 |
+
TORCH_CHECK(
|
895 |
+
(mask.size(2) == height_out && mask.size(3) == width_out),
|
896 |
+
"invalid spatial size of mask, expected height: %d width: %d, but "
|
897 |
+
"got height: %d width: %d",
|
898 |
+
height_out,
|
899 |
+
width_out,
|
900 |
+
mask.size(2),
|
901 |
+
mask.size(3));
|
902 |
+
|
903 |
+
TORCH_CHECK(
|
904 |
+
(mask.size(1) == deformable_group * kernel_h * kernel_w),
|
905 |
+
"invalid number of channels of mask");
|
906 |
+
|
907 |
+
if (ones.ndimension() != 2 ||
|
908 |
+
ones.size(0) * ones.size(1) < height_out * width_out) {
|
909 |
+
// Resize plane and fill with ones...
|
910 |
+
ones = at::ones({height_out, width_out}, input.options());
|
911 |
+
}
|
912 |
+
|
913 |
+
// resize output
|
914 |
+
output = output.view({batch, channels_out, height_out, width_out}).zero_();
|
915 |
+
// resize temporary columns
|
916 |
+
columns = at::zeros(
|
917 |
+
{channels * kernel_h * kernel_w, 1 * height_out * width_out},
|
918 |
+
input.options());
|
919 |
+
|
920 |
+
output = output.view(
|
921 |
+
{output.size(0),
|
922 |
+
group,
|
923 |
+
output.size(1) / group,
|
924 |
+
output.size(2),
|
925 |
+
output.size(3)});
|
926 |
+
|
927 |
+
for (int b = 0; b < batch; b++) {
|
928 |
+
modulated_deformable_im2col_cuda(
|
929 |
+
input[b],
|
930 |
+
offset[b],
|
931 |
+
mask[b],
|
932 |
+
1,
|
933 |
+
channels,
|
934 |
+
height,
|
935 |
+
width,
|
936 |
+
height_out,
|
937 |
+
width_out,
|
938 |
+
kernel_h,
|
939 |
+
kernel_w,
|
940 |
+
pad_h,
|
941 |
+
pad_w,
|
942 |
+
stride_h,
|
943 |
+
stride_w,
|
944 |
+
dilation_h,
|
945 |
+
dilation_w,
|
946 |
+
deformable_group,
|
947 |
+
columns);
|
948 |
+
|
949 |
+
// divide into group
|
950 |
+
weight = weight.view(
|
951 |
+
{group,
|
952 |
+
weight.size(0) / group,
|
953 |
+
weight.size(1),
|
954 |
+
weight.size(2),
|
955 |
+
weight.size(3)});
|
956 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
957 |
+
|
958 |
+
for (int g = 0; g < group; g++) {
|
959 |
+
output[b][g] = output[b][g]
|
960 |
+
.flatten(1)
|
961 |
+
.addmm_(weight[g].flatten(1), columns[g])
|
962 |
+
.view_as(output[b][g]);
|
963 |
+
}
|
964 |
+
|
965 |
+
weight = weight.view(
|
966 |
+
{weight.size(0) * weight.size(1),
|
967 |
+
weight.size(2),
|
968 |
+
weight.size(3),
|
969 |
+
weight.size(4)});
|
970 |
+
columns =
|
971 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
972 |
+
}
|
973 |
+
|
974 |
+
output = output.view(
|
975 |
+
{output.size(0),
|
976 |
+
output.size(1) * output.size(2),
|
977 |
+
output.size(3),
|
978 |
+
output.size(4)});
|
979 |
+
|
980 |
+
if (with_bias) {
|
981 |
+
output += bias.view({1, bias.size(0), 1, 1});
|
982 |
+
}
|
983 |
+
}
|
984 |
+
|
985 |
+
void modulated_deform_conv_cuda_backward(
|
986 |
+
at::Tensor input,
|
987 |
+
at::Tensor weight,
|
988 |
+
at::Tensor bias,
|
989 |
+
at::Tensor ones,
|
990 |
+
at::Tensor offset,
|
991 |
+
at::Tensor mask,
|
992 |
+
at::Tensor columns,
|
993 |
+
at::Tensor grad_input,
|
994 |
+
at::Tensor grad_weight,
|
995 |
+
at::Tensor grad_bias,
|
996 |
+
at::Tensor grad_offset,
|
997 |
+
at::Tensor grad_mask,
|
998 |
+
at::Tensor grad_output,
|
999 |
+
int kernel_h,
|
1000 |
+
int kernel_w,
|
1001 |
+
int stride_h,
|
1002 |
+
int stride_w,
|
1003 |
+
int pad_h,
|
1004 |
+
int pad_w,
|
1005 |
+
int dilation_h,
|
1006 |
+
int dilation_w,
|
1007 |
+
int group,
|
1008 |
+
int deformable_group,
|
1009 |
+
const bool with_bias) {
|
1010 |
+
shape_check(
|
1011 |
+
input,
|
1012 |
+
offset,
|
1013 |
+
&grad_output,
|
1014 |
+
weight,
|
1015 |
+
kernel_h,
|
1016 |
+
kernel_w,
|
1017 |
+
stride_h,
|
1018 |
+
stride_w,
|
1019 |
+
pad_h,
|
1020 |
+
pad_w,
|
1021 |
+
dilation_h,
|
1022 |
+
dilation_w,
|
1023 |
+
group,
|
1024 |
+
deformable_group);
|
1025 |
+
|
1026 |
+
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
1027 |
+
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
1028 |
+
|
1029 |
+
const int batch = input.size(0);
|
1030 |
+
const int channels = input.size(1);
|
1031 |
+
const int height = input.size(2);
|
1032 |
+
const int width = input.size(3);
|
1033 |
+
|
1034 |
+
const int channels_kernel = weight.size(1);
|
1035 |
+
const int kernel_h_ = weight.size(2);
|
1036 |
+
const int kernel_w_ = weight.size(3);
|
1037 |
+
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
1038 |
+
AT_ERROR(
|
1039 |
+
"Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
1040 |
+
kernel_h_,
|
1041 |
+
kernel_w,
|
1042 |
+
kernel_h_,
|
1043 |
+
kernel_w_);
|
1044 |
+
if (channels != channels_kernel * group)
|
1045 |
+
AT_ERROR(
|
1046 |
+
"Input shape and kernel channels wont match: (%d vs %d).",
|
1047 |
+
channels,
|
1048 |
+
channels_kernel * group);
|
1049 |
+
|
1050 |
+
const int height_out =
|
1051 |
+
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
1052 |
+
const int width_out =
|
1053 |
+
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
1054 |
+
|
1055 |
+
// mask shape check
|
1056 |
+
TORCH_CHECK(
|
1057 |
+
(mask.size(2) == height_out && mask.size(3) == width_out),
|
1058 |
+
"invalid spatial size of mask, expected height: %d width: %d, but "
|
1059 |
+
"got height: %d width: %d",
|
1060 |
+
height_out,
|
1061 |
+
width_out,
|
1062 |
+
mask.size(2),
|
1063 |
+
mask.size(3));
|
1064 |
+
|
1065 |
+
TORCH_CHECK(
|
1066 |
+
(mask.size(1) == deformable_group * kernel_h * kernel_w),
|
1067 |
+
"invalid number of channels of mask");
|
1068 |
+
|
1069 |
+
if (ones.ndimension() != 2 ||
|
1070 |
+
ones.size(0) * ones.size(1) < height_out * width_out) {
|
1071 |
+
// Resize plane and fill with ones...
|
1072 |
+
ones = at::ones({height_out, width_out}, input.options());
|
1073 |
+
}
|
1074 |
+
|
1075 |
+
grad_input = grad_input.view({batch, channels, height, width});
|
1076 |
+
columns = at::zeros(
|
1077 |
+
{channels * kernel_h * kernel_w, height_out * width_out},
|
1078 |
+
input.options());
|
1079 |
+
|
1080 |
+
grad_output = grad_output.view(
|
1081 |
+
{grad_output.size(0),
|
1082 |
+
group,
|
1083 |
+
grad_output.size(1) / group,
|
1084 |
+
grad_output.size(2),
|
1085 |
+
grad_output.size(3)});
|
1086 |
+
|
1087 |
+
for (int b = 0; b < batch; b++) {
|
1088 |
+
// divide int group
|
1089 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
1090 |
+
weight = weight.view(
|
1091 |
+
{group,
|
1092 |
+
weight.size(0) / group,
|
1093 |
+
weight.size(1),
|
1094 |
+
weight.size(2),
|
1095 |
+
weight.size(3)});
|
1096 |
+
|
1097 |
+
for (int g = 0; g < group; g++) {
|
1098 |
+
columns[g].addmm_(
|
1099 |
+
weight[g].flatten(1).transpose(0, 1),
|
1100 |
+
grad_output[b][g].flatten(1),
|
1101 |
+
0.0f,
|
1102 |
+
1.0f);
|
1103 |
+
}
|
1104 |
+
|
1105 |
+
columns =
|
1106 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
1107 |
+
weight = weight.view(
|
1108 |
+
{weight.size(0) * weight.size(1),
|
1109 |
+
weight.size(2),
|
1110 |
+
weight.size(3),
|
1111 |
+
weight.size(4)});
|
1112 |
+
|
1113 |
+
// gradient w.r.t. input coordinate data
|
1114 |
+
modulated_deformable_col2im_coord_cuda(
|
1115 |
+
columns,
|
1116 |
+
input[b],
|
1117 |
+
offset[b],
|
1118 |
+
mask[b],
|
1119 |
+
1,
|
1120 |
+
channels,
|
1121 |
+
height,
|
1122 |
+
width,
|
1123 |
+
height_out,
|
1124 |
+
width_out,
|
1125 |
+
kernel_h,
|
1126 |
+
kernel_w,
|
1127 |
+
pad_h,
|
1128 |
+
pad_w,
|
1129 |
+
stride_h,
|
1130 |
+
stride_w,
|
1131 |
+
dilation_h,
|
1132 |
+
dilation_w,
|
1133 |
+
deformable_group,
|
1134 |
+
grad_offset[b],
|
1135 |
+
grad_mask[b]);
|
1136 |
+
// gradient w.r.t. input data
|
1137 |
+
modulated_deformable_col2im_cuda(
|
1138 |
+
columns,
|
1139 |
+
offset[b],
|
1140 |
+
mask[b],
|
1141 |
+
1,
|
1142 |
+
channels,
|
1143 |
+
height,
|
1144 |
+
width,
|
1145 |
+
height_out,
|
1146 |
+
width_out,
|
1147 |
+
kernel_h,
|
1148 |
+
kernel_w,
|
1149 |
+
pad_h,
|
1150 |
+
pad_w,
|
1151 |
+
stride_h,
|
1152 |
+
stride_w,
|
1153 |
+
dilation_h,
|
1154 |
+
dilation_w,
|
1155 |
+
deformable_group,
|
1156 |
+
grad_input[b]);
|
1157 |
+
|
1158 |
+
// gradient w.r.t. weight, dWeight should accumulate across the batch and
|
1159 |
+
// group
|
1160 |
+
modulated_deformable_im2col_cuda(
|
1161 |
+
input[b],
|
1162 |
+
offset[b],
|
1163 |
+
mask[b],
|
1164 |
+
1,
|
1165 |
+
channels,
|
1166 |
+
height,
|
1167 |
+
width,
|
1168 |
+
height_out,
|
1169 |
+
width_out,
|
1170 |
+
kernel_h,
|
1171 |
+
kernel_w,
|
1172 |
+
pad_h,
|
1173 |
+
pad_w,
|
1174 |
+
stride_h,
|
1175 |
+
stride_w,
|
1176 |
+
dilation_h,
|
1177 |
+
dilation_w,
|
1178 |
+
deformable_group,
|
1179 |
+
columns);
|
1180 |
+
|
1181 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
1182 |
+
grad_weight = grad_weight.view(
|
1183 |
+
{group,
|
1184 |
+
grad_weight.size(0) / group,
|
1185 |
+
grad_weight.size(1),
|
1186 |
+
grad_weight.size(2),
|
1187 |
+
grad_weight.size(3)});
|
1188 |
+
if (with_bias)
|
1189 |
+
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
|
1190 |
+
|
1191 |
+
for (int g = 0; g < group; g++) {
|
1192 |
+
grad_weight[g] =
|
1193 |
+
grad_weight[g]
|
1194 |
+
.flatten(1)
|
1195 |
+
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
|
1196 |
+
.view_as(grad_weight[g]);
|
1197 |
+
if (with_bias) {
|
1198 |
+
grad_bias[g] =
|
1199 |
+
grad_bias[g]
|
1200 |
+
.view({-1, 1})
|
1201 |
+
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
|
1202 |
+
.view(-1);
|
1203 |
+
}
|
1204 |
+
}
|
1205 |
+
|
1206 |
+
columns =
|
1207 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
1208 |
+
grad_weight = grad_weight.view(
|
1209 |
+
{grad_weight.size(0) * grad_weight.size(1),
|
1210 |
+
grad_weight.size(2),
|
1211 |
+
grad_weight.size(3),
|
1212 |
+
grad_weight.size(4)});
|
1213 |
+
if (with_bias)
|
1214 |
+
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
|
1215 |
+
}
|
1216 |
+
grad_output = grad_output.view(
|
1217 |
+
{grad_output.size(0) * grad_output.size(1),
|
1218 |
+
grad_output.size(2),
|
1219 |
+
grad_output.size(3),
|
1220 |
+
grad_output.size(4)});
|
1221 |
+
}
|
1222 |
+
|
1223 |
+
} // namespace detectron2
|
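The output spatial size used throughout this file (in shape_check and both forward paths) is the standard dilated-convolution formula. A small self-contained check of that arithmetic, with made-up sizes, for reference:

#include <cassert>

// Same arithmetic as the outputHeight/outputWidth computations above:
// out = (in + 2*pad - (dilation*(k-1)+1)) / stride + 1   (integer division)
inline long conv_output_size(long in, int k, int pad, int stride, int dilation) {
  return (in + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1;
}

int main() {
  // A 64x64 input with a 3x3 kernel, pad 1, stride 1, dilation 1 keeps its size.
  assert(conv_output_size(64, 3, 1, 1, 1) == 64);
  // Stride 2 halves it (rounding down): 64 -> 32.
  assert(conv_output_size(64, 3, 1, 2, 1) == 32);
  return 0;
}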
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu
ADDED
@@ -0,0 +1,1288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
// modified from
|
4 |
+
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
|
5 |
+
// Original license: Apache 2.0
|
6 |
+
// clang-format off
|
7 |
+
|
8 |
+
// modify from
|
9 |
+
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
|
10 |
+
|
11 |
+
/*!
|
12 |
+
******************* BEGIN Caffe Copyright Notice and Disclaimer *****************
|
13 |
+
*
|
14 |
+
* COPYRIGHT
|
15 |
+
*
|
16 |
+
* All contributions by the University of California:
|
17 |
+
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
|
18 |
+
* All rights reserved.
|
19 |
+
*
|
20 |
+
* All other contributions:
|
21 |
+
* Copyright (c) 2014-2017, the respective contributors
|
22 |
+
* All rights reserved.
|
23 |
+
*
|
24 |
+
* Caffe uses a shared copyright model: each contributor holds copyright over
|
25 |
+
* their contributions to Caffe. The project versioning records all such
|
26 |
+
* contribution and copyright details. If a contributor wants to further mark
|
27 |
+
* their specific copyright on a particular contribution, they should indicate
|
28 |
+
* their copyright solely in the commit message of the change when it is
|
29 |
+
* committed.
|
30 |
+
*
|
31 |
+
* LICENSE
|
32 |
+
*
|
33 |
+
* Redistribution and use in source and binary forms, with or without
|
34 |
+
* modification, are permitted provided that the following conditions are met:
|
35 |
+
*
|
36 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
37 |
+
* list of conditions and the following disclaimer.
|
38 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
39 |
+
* this list of conditions and the following disclaimer in the documentation
|
40 |
+
* and/or other materials provided with the distribution.
|
41 |
+
*
|
42 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
43 |
+
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
44 |
+
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
45 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
|
46 |
+
*FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
47 |
+
*DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
48 |
+
*SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
49 |
+
*CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
50 |
+
*OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
51 |
+
*OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
52 |
+
*
|
53 |
+
* CONTRIBUTION AGREEMENT
|
54 |
+
*
|
55 |
+
* By contributing to the BVLC/caffe repository through pull-request, comment,
|
56 |
+
* or otherwise, the contributor releases their content to the
|
57 |
+
* license and copyright terms herein.
|
58 |
+
*
|
59 |
+
***************** END Caffe Copyright Notice and Disclaimer *********************
|
60 |
+
*
|
61 |
+
* Copyright (c) 2018 Microsoft
|
62 |
+
* Licensed under The MIT License [see LICENSE for details]
|
63 |
+
* \file modulated_deformable_im2col.cuh
|
64 |
+
* \brief Function definitions of converting an image to
|
65 |
+
* column matrix based on kernel, padding, dilation, and offset.
|
66 |
+
* These functions are mainly used in deformable convolution operators.
|
67 |
+
* \ref: https://arxiv.org/abs/1703.06211
|
68 |
+
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
|
69 |
+
*/
|
70 |
+
|
71 |
+
#include <ATen/ATen.h>
|
72 |
+
#include <c10/cuda/CUDAGuard.h>
|
73 |
+
#include <float.h>
|
74 |
+
#include <math.h>
|
75 |
+
#include <stdio.h>
|
76 |
+
#include <THC/THCAtomics.cuh>
|
77 |
+
|
78 |
+
using namespace at;
|
79 |
+
|
80 |
+
#define CUDA_KERNEL_LOOP(i, n) \
|
81 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
82 |
+
i += blockDim.x * gridDim.x)
|
83 |
+
|
84 |
+
|
85 |
+
namespace {
|
86 |
+
|
87 |
+
const int CUDA_NUM_THREADS = 1024;
|
88 |
+
const int kMaxGridNum = 65535;
|
89 |
+
|
90 |
+
inline int GET_BLOCKS(const int N) {
|
91 |
+
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
|
92 |
+
}
|
93 |
+
|
94 |
+
}
|
95 |
+
|
96 |
+
template <typename scalar_t>
|
97 |
+
__device__ scalar_t deformable_im2col_bilinear(
|
98 |
+
const scalar_t* bottom_data,
|
99 |
+
const int data_width,
|
100 |
+
const int height,
|
101 |
+
const int width,
|
102 |
+
scalar_t h,
|
103 |
+
scalar_t w) {
|
104 |
+
int h_low = floor(h);
|
105 |
+
int w_low = floor(w);
|
106 |
+
int h_high = h_low + 1;
|
107 |
+
int w_high = w_low + 1;
|
108 |
+
|
109 |
+
scalar_t lh = h - h_low;
|
110 |
+
scalar_t lw = w - w_low;
|
111 |
+
scalar_t hh = 1 - lh, hw = 1 - lw;
|
112 |
+
|
113 |
+
scalar_t v1 = 0;
|
114 |
+
if (h_low >= 0 && w_low >= 0)
|
115 |
+
v1 = bottom_data[h_low * data_width + w_low];
|
116 |
+
scalar_t v2 = 0;
|
117 |
+
if (h_low >= 0 && w_high <= width - 1)
|
118 |
+
v2 = bottom_data[h_low * data_width + w_high];
|
119 |
+
scalar_t v3 = 0;
|
120 |
+
if (h_high <= height - 1 && w_low >= 0)
|
121 |
+
v3 = bottom_data[h_high * data_width + w_low];
|
122 |
+
scalar_t v4 = 0;
|
123 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
124 |
+
v4 = bottom_data[h_high * data_width + w_high];
|
125 |
+
|
126 |
+
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
127 |
+
|
128 |
+
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
129 |
+
return val;
|
130 |
+
}
|
131 |
+
|
132 |
+
template <typename scalar_t>
|
133 |
+
__device__ scalar_t get_gradient_weight(
|
134 |
+
scalar_t argmax_h,
|
135 |
+
scalar_t argmax_w,
|
136 |
+
const int h,
|
137 |
+
const int w,
|
138 |
+
const int height,
|
139 |
+
const int width) {
|
140 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
|
141 |
+
argmax_w >= width) {
|
142 |
+
// empty
|
143 |
+
return 0;
|
144 |
+
}
|
145 |
+
|
146 |
+
int argmax_h_low = floor(argmax_h);
|
147 |
+
int argmax_w_low = floor(argmax_w);
|
148 |
+
int argmax_h_high = argmax_h_low + 1;
|
149 |
+
int argmax_w_high = argmax_w_low + 1;
|
150 |
+
|
151 |
+
scalar_t weight = 0;
|
152 |
+
if (h == argmax_h_low && w == argmax_w_low)
|
153 |
+
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
|
154 |
+
if (h == argmax_h_low && w == argmax_w_high)
|
155 |
+
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
|
156 |
+
if (h == argmax_h_high && w == argmax_w_low)
|
157 |
+
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
|
158 |
+
if (h == argmax_h_high && w == argmax_w_high)
|
159 |
+
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
|
160 |
+
return weight;
|
161 |
+
}
|
162 |
+
|
163 |
+
template <typename scalar_t>
|
164 |
+
__device__ scalar_t get_coordinate_weight(
|
165 |
+
scalar_t argmax_h,
|
166 |
+
scalar_t argmax_w,
|
167 |
+
const int height,
|
168 |
+
const int width,
|
169 |
+
const scalar_t* im_data,
|
170 |
+
const int data_width,
|
171 |
+
const int bp_dir) {
|
172 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
|
173 |
+
argmax_w >= width) {
|
174 |
+
// empty
|
175 |
+
return 0;
|
176 |
+
}
|
177 |
+
|
178 |
+
int argmax_h_low = floor(argmax_h);
|
179 |
+
int argmax_w_low = floor(argmax_w);
|
180 |
+
int argmax_h_high = argmax_h_low + 1;
|
181 |
+
int argmax_w_high = argmax_w_low + 1;
|
182 |
+
|
183 |
+
scalar_t weight = 0;
|
184 |
+
|
185 |
+
if (bp_dir == 0) {
|
186 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
187 |
+
weight += -1 * (argmax_w_low + 1 - argmax_w) *
|
188 |
+
im_data[argmax_h_low * data_width + argmax_w_low];
|
189 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
190 |
+
weight += -1 * (argmax_w - argmax_w_low) *
|
191 |
+
im_data[argmax_h_low * data_width + argmax_w_high];
|
192 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
193 |
+
weight += (argmax_w_low + 1 - argmax_w) *
|
194 |
+
im_data[argmax_h_high * data_width + argmax_w_low];
|
195 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
196 |
+
weight += (argmax_w - argmax_w_low) *
|
197 |
+
im_data[argmax_h_high * data_width + argmax_w_high];
|
198 |
+
} else if (bp_dir == 1) {
|
199 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
200 |
+
weight += -1 * (argmax_h_low + 1 - argmax_h) *
|
201 |
+
im_data[argmax_h_low * data_width + argmax_w_low];
|
202 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
203 |
+
weight += (argmax_h_low + 1 - argmax_h) *
|
204 |
+
im_data[argmax_h_low * data_width + argmax_w_high];
|
205 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
206 |
+
weight += -1 * (argmax_h - argmax_h_low) *
|
207 |
+
im_data[argmax_h_high * data_width + argmax_w_low];
|
208 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
209 |
+
weight += (argmax_h - argmax_h_low) *
|
210 |
+
im_data[argmax_h_high * data_width + argmax_w_high];
|
211 |
+
}
|
212 |
+
|
213 |
+
return weight;
|
214 |
+
}
|
215 |
+
|
216 |
+
template <typename scalar_t>
|
217 |
+
__global__ void deformable_im2col_gpu_kernel(
|
218 |
+
const int n,
|
219 |
+
const scalar_t* data_im,
|
220 |
+
const scalar_t* data_offset,
|
221 |
+
const int height,
|
222 |
+
const int width,
|
223 |
+
const int kernel_h,
|
224 |
+
const int kernel_w,
|
225 |
+
const int pad_h,
|
226 |
+
const int pad_w,
|
227 |
+
const int stride_h,
|
228 |
+
const int stride_w,
|
229 |
+
const int dilation_h,
|
230 |
+
const int dilation_w,
|
231 |
+
const int channel_per_deformable_group,
|
232 |
+
const int batch_size,
|
233 |
+
const int num_channels,
|
234 |
+
const int deformable_group,
|
235 |
+
const int height_col,
|
236 |
+
const int width_col,
|
237 |
+
scalar_t* data_col) {
|
238 |
+
CUDA_KERNEL_LOOP(index, n) {
|
239 |
+
// index index of output matrix
|
240 |
+
const int w_col = index % width_col;
|
241 |
+
const int h_col = (index / width_col) % height_col;
|
242 |
+
const int b_col = (index / width_col / height_col) % batch_size;
|
243 |
+
const int c_im = (index / width_col / height_col) / batch_size;
|
244 |
+
const int c_col = c_im * kernel_h * kernel_w;
|
245 |
+
|
246 |
+
// compute deformable group index
|
247 |
+
const int deformable_group_index = c_im / channel_per_deformable_group;
|
248 |
+
|
249 |
+
const int h_in = h_col * stride_h - pad_h;
|
250 |
+
const int w_in = w_col * stride_w - pad_w;
|
251 |
+
scalar_t* data_col_ptr = data_col +
|
252 |
+
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
|
253 |
+
// const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) *
|
254 |
+
// height + h_in) * width + w_in;
|
255 |
+
const scalar_t* data_im_ptr =
|
256 |
+
data_im + (b_col * num_channels + c_im) * height * width;
|
257 |
+
const scalar_t* data_offset_ptr = data_offset +
|
258 |
+
(b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
|
259 |
+
kernel_w * height_col * width_col;
|
260 |
+
|
261 |
+
for (int i = 0; i < kernel_h; ++i) {
|
262 |
+
for (int j = 0; j < kernel_w; ++j) {
|
263 |
+
const int data_offset_h_ptr =
|
264 |
+
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
|
265 |
+
const int data_offset_w_ptr =
|
266 |
+
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
|
267 |
+
w_col;
|
268 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
269 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
270 |
+
scalar_t val = static_cast<scalar_t>(0);
|
271 |
+
const scalar_t h_im = h_in + i * dilation_h + offset_h;
|
272 |
+
const scalar_t w_im = w_in + j * dilation_w + offset_w;
|
273 |
+
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
|
274 |
+
// const scalar_t map_h = i * dilation_h + offset_h;
|
275 |
+
// const scalar_t map_w = j * dilation_w + offset_w;
|
276 |
+
// const int cur_height = height - h_in;
|
277 |
+
// const int cur_width = width - w_in;
|
278 |
+
// val = deformable_im2col_bilinear(data_im_ptr, width, cur_height,
|
279 |
+
// cur_width, map_h, map_w);
|
280 |
+
val = deformable_im2col_bilinear(
|
281 |
+
data_im_ptr, width, height, width, h_im, w_im);
|
282 |
+
}
|
283 |
+
*data_col_ptr = val;
|
284 |
+
data_col_ptr += batch_size * height_col * width_col;
|
285 |
+
}
|
286 |
+
}
|
287 |
+
}
|
288 |
+
}
|
289 |
+
|
290 |
+
|
291 |
+
template <typename scalar_t>
|
292 |
+
__global__ void deformable_col2im_gpu_kernel(
|
293 |
+
const int n,
|
294 |
+
const scalar_t* data_col,
|
295 |
+
const scalar_t* data_offset,
|
296 |
+
const int channels,
|
297 |
+
const int height,
|
298 |
+
const int width,
|
299 |
+
const int kernel_h,
|
300 |
+
const int kernel_w,
|
301 |
+
const int pad_h,
|
302 |
+
const int pad_w,
|
303 |
+
const int stride_h,
|
304 |
+
const int stride_w,
|
305 |
+
const int dilation_h,
|
306 |
+
const int dilation_w,
|
307 |
+
const int channel_per_deformable_group,
|
308 |
+
const int batch_size,
|
309 |
+
const int deformable_group,
|
310 |
+
const int height_col,
|
311 |
+
const int width_col,
|
312 |
+
scalar_t* grad_im) {
|
313 |
+
CUDA_KERNEL_LOOP(index, n) {
|
314 |
+
const int j = (index / width_col / height_col / batch_size) % kernel_w;
|
315 |
+
const int i =
|
316 |
+
(index / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
317 |
+
const int c =
|
318 |
+
index / width_col / height_col / batch_size / kernel_w / kernel_h;
|
319 |
+
// compute the start and end of the output
|
320 |
+
|
321 |
+
const int deformable_group_index = c / channel_per_deformable_group;
|
322 |
+
|
323 |
+
int w_out = index % width_col;
|
324 |
+
int h_out = (index / width_col) % height_col;
|
325 |
+
int b = (index / width_col / height_col) % batch_size;
|
326 |
+
int w_in = w_out * stride_w - pad_w;
|
327 |
+
int h_in = h_out * stride_h - pad_h;
|
328 |
+
|
329 |
+
const scalar_t* data_offset_ptr = data_offset +
|
330 |
+
(b * deformable_group + deformable_group_index) * 2 * kernel_h *
|
331 |
+
kernel_w * height_col * width_col;
|
332 |
+
const int data_offset_h_ptr =
|
333 |
+
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
|
334 |
+
const int data_offset_w_ptr =
|
335 |
+
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
|
336 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
337 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
338 |
+
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
|
339 |
+
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
|
340 |
+
|
341 |
+
const scalar_t cur_top_grad = data_col[index];
|
342 |
+
const int cur_h = (int)cur_inv_h_data;
|
343 |
+
const int cur_w = (int)cur_inv_w_data;
|
344 |
+
for (int dy = -2; dy <= 2; dy++) {
|
345 |
+
for (int dx = -2; dx <= 2; dx++) {
|
346 |
+
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
|
347 |
+
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
|
348 |
+
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
|
349 |
+
int cur_bottom_grad_pos =
|
350 |
+
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
|
351 |
+
scalar_t weight = get_gradient_weight(
|
352 |
+
cur_inv_h_data,
|
353 |
+
cur_inv_w_data,
|
354 |
+
cur_h + dy,
|
355 |
+
cur_w + dx,
|
356 |
+
height,
|
357 |
+
width);
|
358 |
+
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
|
359 |
+
}
|
360 |
+
}
|
361 |
+
}
|
362 |
+
}
|
363 |
+
}
|
364 |
+
|
365 |
+
|
366 |
+
template <typename scalar_t>
|
367 |
+
__global__ void deformable_col2im_coord_gpu_kernel(
|
368 |
+
const int n,
|
369 |
+
const scalar_t* data_col,
|
370 |
+
const scalar_t* data_im,
|
371 |
+
const scalar_t* data_offset,
|
372 |
+
const int channels,
|
373 |
+
const int height,
|
374 |
+
const int width,
|
375 |
+
const int kernel_h,
|
376 |
+
const int kernel_w,
|
377 |
+
const int pad_h,
|
378 |
+
const int pad_w,
|
379 |
+
const int stride_h,
|
380 |
+
const int stride_w,
|
381 |
+
const int dilation_h,
|
382 |
+
const int dilation_w,
|
383 |
+
const int channel_per_deformable_group,
|
384 |
+
const int batch_size,
|
385 |
+
const int offset_channels,
|
386 |
+
const int deformable_group,
|
387 |
+
const int height_col,
|
388 |
+
const int width_col,
|
389 |
+
scalar_t* grad_offset) {
|
390 |
+
CUDA_KERNEL_LOOP(index, n) {
|
391 |
+
scalar_t val = 0;
|
392 |
+
int w = index % width_col;
|
393 |
+
int h = (index / width_col) % height_col;
|
394 |
+
int c = (index / width_col / height_col) % offset_channels;
|
395 |
+
int b = (index / width_col / height_col) / offset_channels;
|
396 |
+
// compute the start and end of the output
|
397 |
+
|
398 |
+
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
|
399 |
+
const int col_step = kernel_h * kernel_w;
|
400 |
+
int cnt = 0;
|
401 |
+
const scalar_t* data_col_ptr = data_col +
|
402 |
+
deformable_group_index * channel_per_deformable_group * batch_size *
|
403 |
+
width_col * height_col;
|
404 |
+
const scalar_t* data_im_ptr = data_im +
|
405 |
+
(b * deformable_group + deformable_group_index) *
|
406 |
+
channel_per_deformable_group / kernel_h / kernel_w * height * width;
|
407 |
+
const scalar_t* data_offset_ptr = data_offset +
|
408 |
+
(b * deformable_group + deformable_group_index) * 2 * kernel_h *
|
409 |
+
kernel_w * height_col * width_col;
|
410 |
+
|
411 |
+
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
|
412 |
+
|
413 |
+
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
|
414 |
+
col_c += col_step) {
|
415 |
+
const int col_pos =
|
416 |
+
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
|
417 |
+
const int bp_dir = offset_c % 2;
|
418 |
+
|
419 |
+
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
|
420 |
+
int i =
|
421 |
+
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
422 |
+
int w_out = col_pos % width_col;
|
423 |
+
int h_out = (col_pos / width_col) % height_col;
|
424 |
+
int w_in = w_out * stride_w - pad_w;
|
425 |
+
int h_in = h_out * stride_h - pad_h;
|
426 |
+
const int data_offset_h_ptr =
|
427 |
+
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
|
428 |
+
const int data_offset_w_ptr =
|
429 |
+
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
|
430 |
+
w_out);
|
431 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
432 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
433 |
+
scalar_t inv_h = h_in + i * dilation_h + offset_h;
|
434 |
+
scalar_t inv_w = w_in + j * dilation_w + offset_w;
|
435 |
+
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
|
436 |
+
inv_h = inv_w = -2;
|
437 |
+
}
|
438 |
+
const scalar_t weight = get_coordinate_weight(
|
439 |
+
inv_h,
|
440 |
+
inv_w,
|
441 |
+
height,
|
442 |
+
width,
|
443 |
+
data_im_ptr + cnt * height * width,
|
444 |
+
width,
|
445 |
+
bp_dir);
|
446 |
+
val += weight * data_col_ptr[col_pos];
|
447 |
+
cnt += 1;
|
448 |
+
}
|
449 |
+
|
450 |
+
grad_offset[index] = val;
|
451 |
+
}
|
452 |
+
}
|
453 |
+
|
454 |
+
|
455 |
+
namespace detectron2 {
|
456 |
+
|
457 |
+
void deformable_im2col(
|
458 |
+
const at::Tensor data_im,
|
459 |
+
const at::Tensor data_offset,
|
460 |
+
const int channels,
|
461 |
+
const int height,
|
462 |
+
const int width,
|
463 |
+
const int ksize_h,
|
464 |
+
const int ksize_w,
|
465 |
+
const int pad_h,
|
466 |
+
const int pad_w,
|
467 |
+
const int stride_h,
|
468 |
+
const int stride_w,
|
469 |
+
const int dilation_h,
|
470 |
+
const int dilation_w,
|
471 |
+
const int parallel_imgs,
|
472 |
+
const int deformable_group,
|
473 |
+
at::Tensor data_col) {
|
474 |
+
// num_axes should be smaller than block size
|
475 |
+
// todo: check parallel_imgs is correctly passed in
|
476 |
+
int height_col =
|
477 |
+
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
478 |
+
int width_col =
|
479 |
+
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
480 |
+
int num_kernels = channels * height_col * width_col * parallel_imgs;
|
481 |
+
int channel_per_deformable_group = channels / deformable_group;
|
482 |
+
|
483 |
+
at::cuda::CUDAGuard device_guard(data_im.device());
|
484 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
485 |
+
|
486 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
487 |
+
data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
|
488 |
+
const scalar_t* data_im_ = data_im.data_ptr<scalar_t>();
|
489 |
+
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
|
490 |
+
scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
|
491 |
+
|
492 |
+
deformable_im2col_gpu_kernel<<<
|
493 |
+
GET_BLOCKS(num_kernels),
|
494 |
+
CUDA_NUM_THREADS,
|
495 |
+
0,
|
496 |
+
stream>>>(
|
497 |
+
num_kernels,
|
498 |
+
data_im_,
|
499 |
+
data_offset_,
|
500 |
+
height,
|
501 |
+
width,
|
502 |
+
ksize_h,
|
503 |
+
ksize_w,
|
504 |
+
pad_h,
|
505 |
+
pad_w,
|
506 |
+
stride_h,
|
507 |
+
stride_w,
|
508 |
+
dilation_h,
|
509 |
+
dilation_w,
|
510 |
+
channel_per_deformable_group,
|
511 |
+
parallel_imgs,
|
512 |
+
channels,
|
513 |
+
deformable_group,
|
514 |
+
height_col,
|
515 |
+
width_col,
|
516 |
+
data_col_);
|
517 |
+
}));
|
518 |
+
|
519 |
+
cudaError_t err = cudaGetLastError();
|
520 |
+
if (err != cudaSuccess) {
|
521 |
+
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
|
522 |
+
}
|
523 |
+
}
|
524 |
+
|
525 |
+
|
526 |
+
void deformable_col2im(
|
527 |
+
const at::Tensor data_col,
|
528 |
+
const at::Tensor data_offset,
|
529 |
+
const int channels,
|
530 |
+
const int height,
|
531 |
+
const int width,
|
532 |
+
const int ksize_h,
|
533 |
+
const int ksize_w,
|
534 |
+
const int pad_h,
|
535 |
+
const int pad_w,
|
536 |
+
const int stride_h,
|
537 |
+
const int stride_w,
|
538 |
+
const int dilation_h,
|
539 |
+
const int dilation_w,
|
540 |
+
const int parallel_imgs,
|
541 |
+
const int deformable_group,
|
542 |
+
at::Tensor grad_im) {
|
543 |
+
// todo: make sure parallel_imgs is passed in correctly
|
544 |
+
int height_col =
|
545 |
+
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
546 |
+
int width_col =
|
547 |
+
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
548 |
+
int num_kernels =
|
549 |
+
channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
|
550 |
+
int channel_per_deformable_group = channels / deformable_group;
|
551 |
+
|
552 |
+
at::cuda::CUDAGuard device_guard(data_col.device());
|
553 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
554 |
+
|
555 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
556 |
+
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
|
557 |
+
const scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
|
558 |
+
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
|
559 |
+
scalar_t* grad_im_ = grad_im.data_ptr<scalar_t>();
|
560 |
+
|
561 |
+
deformable_col2im_gpu_kernel<<<
|
562 |
+
GET_BLOCKS(num_kernels),
|
563 |
+
CUDA_NUM_THREADS,
|
564 |
+
0,
|
565 |
+
stream>>>(
|
566 |
+
num_kernels,
|
567 |
+
data_col_,
|
568 |
+
data_offset_,
|
569 |
+
channels,
|
570 |
+
height,
|
571 |
+
width,
|
572 |
+
ksize_h,
|
573 |
+
ksize_w,
|
574 |
+
pad_h,
|
575 |
+
pad_w,
|
576 |
+
stride_h,
|
577 |
+
stride_w,
|
578 |
+
dilation_h,
|
579 |
+
dilation_w,
|
580 |
+
channel_per_deformable_group,
|
581 |
+
parallel_imgs,
|
582 |
+
deformable_group,
|
583 |
+
height_col,
|
584 |
+
width_col,
|
585 |
+
grad_im_);
|
586 |
+
}));
|
587 |
+
|
588 |
+
cudaError_t err = cudaGetLastError();
|
589 |
+
if (err != cudaSuccess) {
|
590 |
+
printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
|
591 |
+
}
|
592 |
+
}
|
593 |
+
|
594 |
+
|
595 |
+
void deformable_col2im_coord(
|
596 |
+
const at::Tensor data_col,
|
597 |
+
const at::Tensor data_im,
|
598 |
+
const at::Tensor data_offset,
|
599 |
+
const int channels,
|
600 |
+
const int height,
|
601 |
+
const int width,
|
602 |
+
const int ksize_h,
|
603 |
+
const int ksize_w,
|
604 |
+
const int pad_h,
|
605 |
+
const int pad_w,
|
606 |
+
const int stride_h,
|
607 |
+
const int stride_w,
|
608 |
+
const int dilation_h,
|
609 |
+
const int dilation_w,
|
610 |
+
const int parallel_imgs,
|
611 |
+
const int deformable_group,
|
612 |
+
at::Tensor grad_offset) {
|
613 |
+
int height_col =
|
614 |
+
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
615 |
+
int width_col =
|
616 |
+
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
617 |
+
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w *
|
618 |
+
deformable_group * parallel_imgs;
|
619 |
+
int channel_per_deformable_group =
|
620 |
+
channels * ksize_h * ksize_w / deformable_group;
|
621 |
+
|
622 |
+
at::cuda::CUDAGuard device_guard(data_col.device());
|
623 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
624 |
+
|
625 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
626 |
+
data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
|
627 |
+
const scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
|
628 |
+
const scalar_t* data_im_ = data_im.data_ptr<scalar_t>();
|
629 |
+
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
|
630 |
+
scalar_t* grad_offset_ = grad_offset.data_ptr<scalar_t>();
|
631 |
+
|
632 |
+
deformable_col2im_coord_gpu_kernel<<<
|
633 |
+
GET_BLOCKS(num_kernels),
|
634 |
+
CUDA_NUM_THREADS,
|
635 |
+
0,
|
636 |
+
stream>>>(
|
637 |
+
num_kernels,
|
638 |
+
data_col_,
|
639 |
+
data_im_,
|
640 |
+
data_offset_,
|
641 |
+
channels,
|
642 |
+
height,
|
643 |
+
width,
|
644 |
+
ksize_h,
|
645 |
+
ksize_w,
|
646 |
+
pad_h,
|
647 |
+
pad_w,
|
648 |
+
stride_h,
|
649 |
+
stride_w,
|
650 |
+
dilation_h,
|
651 |
+
dilation_w,
|
652 |
+
channel_per_deformable_group,
|
653 |
+
parallel_imgs,
|
654 |
+
2 * ksize_h * ksize_w * deformable_group,
|
655 |
+
deformable_group,
|
656 |
+
height_col,
|
657 |
+
width_col,
|
658 |
+
grad_offset_);
|
659 |
+
}));
|
660 |
+
}
|
661 |
+
|
662 |
+
} // namespace detectron2
|
663 |
+
|
664 |
+
|
665 |
+
template <typename scalar_t>
|
666 |
+
__device__ scalar_t dmcn_im2col_bilinear(
|
667 |
+
const scalar_t* bottom_data,
|
668 |
+
const int data_width,
|
669 |
+
const int height,
|
670 |
+
const int width,
|
671 |
+
scalar_t h,
|
672 |
+
scalar_t w) {
|
673 |
+
int h_low = floor(h);
|
674 |
+
int w_low = floor(w);
|
675 |
+
int h_high = h_low + 1;
|
676 |
+
int w_high = w_low + 1;
|
677 |
+
|
678 |
+
scalar_t lh = h - h_low;
|
679 |
+
scalar_t lw = w - w_low;
|
680 |
+
scalar_t hh = 1 - lh, hw = 1 - lw;
|
681 |
+
|
682 |
+
scalar_t v1 = 0;
|
683 |
+
if (h_low >= 0 && w_low >= 0)
|
684 |
+
v1 = bottom_data[h_low * data_width + w_low];
|
685 |
+
scalar_t v2 = 0;
|
686 |
+
if (h_low >= 0 && w_high <= width - 1)
|
687 |
+
v2 = bottom_data[h_low * data_width + w_high];
|
688 |
+
scalar_t v3 = 0;
|
689 |
+
if (h_high <= height - 1 && w_low >= 0)
|
690 |
+
v3 = bottom_data[h_high * data_width + w_low];
|
691 |
+
scalar_t v4 = 0;
|
692 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
693 |
+
v4 = bottom_data[h_high * data_width + w_high];
|
694 |
+
|
695 |
+
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
696 |
+
|
697 |
+
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
698 |
+
return val;
|
699 |
+
}
|
700 |
+
|
701 |
+
template <typename scalar_t>
|
702 |
+
__device__ scalar_t dmcn_get_gradient_weight(
|
703 |
+
scalar_t argmax_h,
|
704 |
+
scalar_t argmax_w,
|
705 |
+
const int h,
|
706 |
+
const int w,
|
707 |
+
const int height,
|
708 |
+
const int width) {
|
709 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
|
710 |
+
argmax_w >= width) {
|
711 |
+
// empty
|
712 |
+
return 0;
|
713 |
+
}
|
714 |
+
|
715 |
+
int argmax_h_low = floor(argmax_h);
|
716 |
+
int argmax_w_low = floor(argmax_w);
|
717 |
+
int argmax_h_high = argmax_h_low + 1;
|
718 |
+
int argmax_w_high = argmax_w_low + 1;
|
719 |
+
|
720 |
+
scalar_t weight = 0;
|
721 |
+
if (h == argmax_h_low && w == argmax_w_low)
|
722 |
+
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
|
723 |
+
if (h == argmax_h_low && w == argmax_w_high)
|
724 |
+
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
|
725 |
+
if (h == argmax_h_high && w == argmax_w_low)
|
726 |
+
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
|
727 |
+
if (h == argmax_h_high && w == argmax_w_high)
|
728 |
+
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
|
729 |
+
return weight;
|
730 |
+
}
|
731 |
+
|
732 |
+
template <typename scalar_t>
|
733 |
+
__device__ scalar_t dmcn_get_coordinate_weight(
|
734 |
+
scalar_t argmax_h,
|
735 |
+
scalar_t argmax_w,
|
736 |
+
const int height,
|
737 |
+
const int width,
|
738 |
+
const scalar_t* im_data,
|
739 |
+
const int data_width,
|
740 |
+
const int bp_dir) {
|
741 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
|
742 |
+
argmax_w >= width) {
|
743 |
+
// empty
|
744 |
+
return 0;
|
745 |
+
}
|
746 |
+
|
747 |
+
int argmax_h_low = floor(argmax_h);
|
748 |
+
int argmax_w_low = floor(argmax_w);
|
749 |
+
int argmax_h_high = argmax_h_low + 1;
|
750 |
+
int argmax_w_high = argmax_w_low + 1;
|
751 |
+
|
752 |
+
scalar_t weight = 0;
|
753 |
+
|
754 |
+
if (bp_dir == 0) {
|
755 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
756 |
+
weight += -1 * (argmax_w_low + 1 - argmax_w) *
|
757 |
+
im_data[argmax_h_low * data_width + argmax_w_low];
|
758 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
759 |
+
weight += -1 * (argmax_w - argmax_w_low) *
|
760 |
+
im_data[argmax_h_low * data_width + argmax_w_high];
|
761 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
762 |
+
weight += (argmax_w_low + 1 - argmax_w) *
|
763 |
+
im_data[argmax_h_high * data_width + argmax_w_low];
|
764 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
765 |
+
weight += (argmax_w - argmax_w_low) *
|
766 |
+
im_data[argmax_h_high * data_width + argmax_w_high];
|
767 |
+
} else if (bp_dir == 1) {
|
768 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
769 |
+
weight += -1 * (argmax_h_low + 1 - argmax_h) *
|
770 |
+
im_data[argmax_h_low * data_width + argmax_w_low];
|
771 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
772 |
+
weight += (argmax_h_low + 1 - argmax_h) *
|
773 |
+
im_data[argmax_h_low * data_width + argmax_w_high];
|
774 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
775 |
+
weight += -1 * (argmax_h - argmax_h_low) *
|
776 |
+
im_data[argmax_h_high * data_width + argmax_w_low];
|
777 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
778 |
+
weight += (argmax_h - argmax_h_low) *
|
779 |
+
im_data[argmax_h_high * data_width + argmax_w_high];
|
780 |
+
}
|
781 |
+
|
782 |
+
return weight;
|
783 |
+
}
|
784 |
+
|
785 |
+
template <typename scalar_t>
|
786 |
+
__global__ void modulated_deformable_im2col_gpu_kernel(
|
787 |
+
const int n,
|
788 |
+
const scalar_t* data_im,
|
789 |
+
const scalar_t* data_offset,
|
790 |
+
const scalar_t* data_mask,
|
791 |
+
const int height,
|
792 |
+
const int width,
|
793 |
+
const int kernel_h,
|
794 |
+
const int kernel_w,
|
795 |
+
const int pad_h,
|
796 |
+
const int pad_w,
|
797 |
+
const int stride_h,
|
798 |
+
const int stride_w,
|
799 |
+
const int dilation_h,
|
800 |
+
const int dilation_w,
|
801 |
+
const int channel_per_deformable_group,
|
802 |
+
const int batch_size,
|
803 |
+
const int num_channels,
|
804 |
+
const int deformable_group,
|
805 |
+
const int height_col,
|
806 |
+
const int width_col,
|
807 |
+
scalar_t* data_col) {
|
808 |
+
CUDA_KERNEL_LOOP(index, n) {
|
809 |
+
// index index of output matrix
|
810 |
+
const int w_col = index % width_col;
|
811 |
+
const int h_col = (index / width_col) % height_col;
|
812 |
+
const int b_col = (index / width_col / height_col) % batch_size;
|
813 |
+
const int c_im = (index / width_col / height_col) / batch_size;
|
814 |
+
const int c_col = c_im * kernel_h * kernel_w;
|
815 |
+
|
816 |
+
// compute deformable group index
|
817 |
+
const int deformable_group_index = c_im / channel_per_deformable_group;
|
818 |
+
|
819 |
+
const int h_in = h_col * stride_h - pad_h;
|
820 |
+
const int w_in = w_col * stride_w - pad_w;
|
821 |
+
|
822 |
+
scalar_t* data_col_ptr = data_col +
|
823 |
+
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
|
824 |
+
// const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) *
|
825 |
+
// height + h_in) * width + w_in;
|
826 |
+
const scalar_t* data_im_ptr =
|
827 |
+
data_im + (b_col * num_channels + c_im) * height * width;
|
828 |
+
const scalar_t* data_offset_ptr = data_offset +
|
829 |
+
(b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
|
830 |
+
kernel_w * height_col * width_col;
|
831 |
+
|
832 |
+
const scalar_t* data_mask_ptr = data_mask +
|
833 |
+
(b_col * deformable_group + deformable_group_index) * kernel_h *
|
834 |
+
kernel_w * height_col * width_col;
|
835 |
+
|
836 |
+
for (int i = 0; i < kernel_h; ++i) {
|
837 |
+
for (int j = 0; j < kernel_w; ++j) {
|
838 |
+
const int data_offset_h_ptr =
|
839 |
+
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
|
840 |
+
const int data_offset_w_ptr =
|
841 |
+
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
|
842 |
+
w_col;
|
843 |
+
const int data_mask_hw_ptr =
|
844 |
+
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
|
845 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
846 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
847 |
+
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
848 |
+
scalar_t val = static_cast<scalar_t>(0);
|
849 |
+
const scalar_t h_im = h_in + i * dilation_h + offset_h;
|
850 |
+
const scalar_t w_im = w_in + j * dilation_w + offset_w;
|
851 |
+
// if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
|
852 |
+
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
|
853 |
+
// const float map_h = i * dilation_h + offset_h;
|
854 |
+
// const float map_w = j * dilation_w + offset_w;
|
855 |
+
// const int cur_height = height - h_in;
|
856 |
+
// const int cur_width = width - w_in;
|
857 |
+
// val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height,
|
858 |
+
// cur_width, map_h, map_w);
|
859 |
+
val = dmcn_im2col_bilinear(
|
860 |
+
data_im_ptr, width, height, width, h_im, w_im);
|
861 |
+
}
|
862 |
+
*data_col_ptr = val * mask;
|
863 |
+
data_col_ptr += batch_size * height_col * width_col;
|
864 |
+
// data_col_ptr += height_col * width_col;
|
865 |
+
}
|
866 |
+
}
|
867 |
+
}
|
868 |
+
}
|
869 |
+
|
870 |
+
template <typename scalar_t>
|
871 |
+
__global__ void modulated_deformable_col2im_gpu_kernel(
|
872 |
+
const int n,
|
873 |
+
const scalar_t* data_col,
|
874 |
+
const scalar_t* data_offset,
|
875 |
+
const scalar_t* data_mask,
|
876 |
+
const int channels,
|
877 |
+
const int height,
|
878 |
+
const int width,
|
879 |
+
const int kernel_h,
|
880 |
+
const int kernel_w,
|
881 |
+
const int pad_h,
|
882 |
+
const int pad_w,
|
883 |
+
const int stride_h,
|
884 |
+
const int stride_w,
|
885 |
+
const int dilation_h,
|
886 |
+
const int dilation_w,
|
887 |
+
const int channel_per_deformable_group,
|
888 |
+
const int batch_size,
|
889 |
+
const int deformable_group,
|
890 |
+
const int height_col,
|
891 |
+
const int width_col,
|
892 |
+
scalar_t* grad_im) {
|
893 |
+
CUDA_KERNEL_LOOP(index, n) {
|
894 |
+
const int j = (index / width_col / height_col / batch_size) % kernel_w;
|
895 |
+
const int i =
|
896 |
+
(index / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
897 |
+
const int c =
|
898 |
+
index / width_col / height_col / batch_size / kernel_w / kernel_h;
|
899 |
+
// compute the start and end of the output
|
900 |
+
|
901 |
+
const int deformable_group_index = c / channel_per_deformable_group;
|
902 |
+
|
903 |
+
int w_out = index % width_col;
|
904 |
+
int h_out = (index / width_col) % height_col;
|
905 |
+
int b = (index / width_col / height_col) % batch_size;
|
906 |
+
int w_in = w_out * stride_w - pad_w;
|
907 |
+
int h_in = h_out * stride_h - pad_h;
|
908 |
+
|
909 |
+
const scalar_t* data_offset_ptr = data_offset +
|
910 |
+
(b * deformable_group + deformable_group_index) * 2 * kernel_h *
|
911 |
+
kernel_w * height_col * width_col;
|
912 |
+
const scalar_t* data_mask_ptr = data_mask +
|
913 |
+
(b * deformable_group + deformable_group_index) * kernel_h * kernel_w *
|
914 |
+
height_col * width_col;
|
915 |
+
const int data_offset_h_ptr =
|
916 |
+
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
|
917 |
+
const int data_offset_w_ptr =
|
918 |
+
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
|
919 |
+
const int data_mask_hw_ptr =
|
920 |
+
((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
|
921 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
922 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
923 |
+
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
924 |
+
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
|
925 |
+
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
|
926 |
+
|
927 |
+
const scalar_t cur_top_grad = data_col[index] * mask;
|
928 |
+
const int cur_h = (int)cur_inv_h_data;
|
929 |
+
const int cur_w = (int)cur_inv_w_data;
|
930 |
+
for (int dy = -2; dy <= 2; dy++) {
|
931 |
+
for (int dx = -2; dx <= 2; dx++) {
|
932 |
+
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
|
933 |
+
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
|
934 |
+
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
|
935 |
+
int cur_bottom_grad_pos =
|
936 |
+
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
|
937 |
+
scalar_t weight = dmcn_get_gradient_weight(
|
938 |
+
cur_inv_h_data,
|
939 |
+
cur_inv_w_data,
|
940 |
+
cur_h + dy,
|
941 |
+
cur_w + dx,
|
942 |
+
height,
|
943 |
+
width);
|
944 |
+
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
|
945 |
+
}
|
946 |
+
}
|
947 |
+
}
|
948 |
+
}
|
949 |
+
}
|
950 |
+
|
951 |
+
template <typename scalar_t>
|
952 |
+
__global__ void modulated_deformable_col2im_coord_gpu_kernel(
|
953 |
+
const int n,
|
954 |
+
const scalar_t* data_col,
|
955 |
+
const scalar_t* data_im,
|
956 |
+
const scalar_t* data_offset,
|
957 |
+
const scalar_t* data_mask,
|
958 |
+
const int channels,
|
959 |
+
const int height,
|
960 |
+
const int width,
|
961 |
+
const int kernel_h,
|
962 |
+
const int kernel_w,
|
963 |
+
const int pad_h,
|
964 |
+
const int pad_w,
|
965 |
+
const int stride_h,
|
966 |
+
const int stride_w,
|
967 |
+
const int dilation_h,
|
968 |
+
const int dilation_w,
|
969 |
+
const int channel_per_deformable_group,
|
970 |
+
const int batch_size,
|
971 |
+
const int offset_channels,
|
972 |
+
const int deformable_group,
|
973 |
+
const int height_col,
|
974 |
+
const int width_col,
|
975 |
+
scalar_t* grad_offset,
|
976 |
+
scalar_t* grad_mask) {
|
977 |
+
CUDA_KERNEL_LOOP(index, n) {
|
978 |
+
scalar_t val = 0, mval = 0;
|
979 |
+
int w = index % width_col;
|
980 |
+
int h = (index / width_col) % height_col;
|
981 |
+
int c = (index / width_col / height_col) % offset_channels;
|
982 |
+
int b = (index / width_col / height_col) / offset_channels;
|
983 |
+
// compute the start and end of the output
|
984 |
+
|
985 |
+
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
|
986 |
+
const int col_step = kernel_h * kernel_w;
|
987 |
+
int cnt = 0;
|
988 |
+
const scalar_t* data_col_ptr = data_col +
|
989 |
+
deformable_group_index * channel_per_deformable_group * batch_size *
|
990 |
+
width_col * height_col;
|
991 |
+
const scalar_t* data_im_ptr = data_im +
|
992 |
+
(b * deformable_group + deformable_group_index) *
|
993 |
+
channel_per_deformable_group / kernel_h / kernel_w * height * width;
|
994 |
+
const scalar_t* data_offset_ptr = data_offset +
|
995 |
+
(b * deformable_group + deformable_group_index) * 2 * kernel_h *
|
996 |
+
kernel_w * height_col * width_col;
|
997 |
+
const scalar_t* data_mask_ptr = data_mask +
|
998 |
+
(b * deformable_group + deformable_group_index) * kernel_h * kernel_w *
|
999 |
+
height_col * width_col;
|
1000 |
+
|
1001 |
+
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
|
1002 |
+
|
1003 |
+
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
|
1004 |
+
col_c += col_step) {
|
1005 |
+
const int col_pos =
|
1006 |
+
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
|
1007 |
+
const int bp_dir = offset_c % 2;
|
1008 |
+
|
1009 |
+
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
|
1010 |
+
int i =
|
1011 |
+
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
1012 |
+
int w_out = col_pos % width_col;
|
1013 |
+
int h_out = (col_pos / width_col) % height_col;
|
1014 |
+
int w_in = w_out * stride_w - pad_w;
|
1015 |
+
int h_in = h_out * stride_h - pad_h;
|
1016 |
+
const int data_offset_h_ptr =
|
1017 |
+
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
|
1018 |
+
const int data_offset_w_ptr =
|
1019 |
+
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
|
1020 |
+
w_out);
|
1021 |
+
const int data_mask_hw_ptr =
|
1022 |
+
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
|
1023 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
1024 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
1025 |
+
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
1026 |
+
scalar_t inv_h = h_in + i * dilation_h + offset_h;
|
1027 |
+
scalar_t inv_w = w_in + j * dilation_w + offset_w;
|
1028 |
+
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
|
1029 |
+
inv_h = inv_w = -2;
|
1030 |
+
} else {
|
1031 |
+
mval += data_col_ptr[col_pos] *
|
1032 |
+
dmcn_im2col_bilinear(
|
1033 |
+
data_im_ptr + cnt * height * width,
|
1034 |
+
width,
|
1035 |
+
height,
|
1036 |
+
width,
|
1037 |
+
inv_h,
|
1038 |
+
inv_w);
|
1039 |
+
}
|
1040 |
+
const scalar_t weight = dmcn_get_coordinate_weight(
|
1041 |
+
inv_h,
|
1042 |
+
inv_w,
|
1043 |
+
height,
|
1044 |
+
width,
|
1045 |
+
data_im_ptr + cnt * height * width,
|
1046 |
+
width,
|
1047 |
+
bp_dir);
|
1048 |
+
val += weight * data_col_ptr[col_pos] * mask;
|
1049 |
+
cnt += 1;
|
1050 |
+
}
|
1051 |
+
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
|
1052 |
+
grad_offset[index] = val;
|
1053 |
+
if (offset_c % 2 == 0)
|
1054 |
+
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group +
|
1055 |
+
// deformable_group_index) * kernel_h * kernel_w + offset_c / 2) *
|
1056 |
+
// height_col + h) * width_col + w], mask_req, mval);
|
1057 |
+
grad_mask
|
1058 |
+
[(((b * deformable_group + deformable_group_index) * kernel_h *
|
1059 |
+
kernel_w +
|
1060 |
+
offset_c / 2) *
|
1061 |
+
height_col +
|
1062 |
+
h) *
|
1063 |
+
width_col +
|
1064 |
+
w] = mval;
|
1065 |
+
}
|
1066 |
+
}
|
1067 |
+
|
1068 |
+
|
1069 |
+
namespace detectron2 {
|
1070 |
+
|
1071 |
+
void modulated_deformable_im2col_cuda(
|
1072 |
+
const at::Tensor data_im,
|
1073 |
+
const at::Tensor data_offset,
|
1074 |
+
const at::Tensor data_mask,
|
1075 |
+
const int batch_size,
|
1076 |
+
const int channels,
|
1077 |
+
const int height_im,
|
1078 |
+
const int width_im,
|
1079 |
+
const int height_col,
|
1080 |
+
const int width_col,
|
1081 |
+
const int kernel_h,
|
1082 |
+
const int kenerl_w,
|
1083 |
+
const int pad_h,
|
1084 |
+
const int pad_w,
|
1085 |
+
const int stride_h,
|
1086 |
+
const int stride_w,
|
1087 |
+
const int dilation_h,
|
1088 |
+
const int dilation_w,
|
1089 |
+
const int deformable_group,
|
1090 |
+
at::Tensor data_col) {
|
1091 |
+
// num_axes should be smaller than block size
|
1092 |
+
const int channel_per_deformable_group = channels / deformable_group;
|
1093 |
+
const int num_kernels = channels * batch_size * height_col * width_col;
|
1094 |
+
|
1095 |
+
at::cuda::CUDAGuard device_guard(data_im.device());
|
1096 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
1097 |
+
|
1098 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
1099 |
+
data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
|
1100 |
+
const scalar_t* data_im_ = data_im.data_ptr<scalar_t>();
|
1101 |
+
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
|
1102 |
+
const scalar_t* data_mask_ = data_mask.data_ptr<scalar_t>();
|
1103 |
+
scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
|
1104 |
+
|
1105 |
+
modulated_deformable_im2col_gpu_kernel<<<
|
1106 |
+
GET_BLOCKS(num_kernels),
|
1107 |
+
CUDA_NUM_THREADS,
|
1108 |
+
0,
|
1109 |
+
stream>>>(
|
1110 |
+
num_kernels,
|
1111 |
+
data_im_,
|
1112 |
+
data_offset_,
|
1113 |
+
data_mask_,
|
1114 |
+
height_im,
|
1115 |
+
width_im,
|
1116 |
+
kernel_h,
|
1117 |
+
kenerl_w,
|
1118 |
+
pad_h,
|
1119 |
+
pad_w,
|
1120 |
+
stride_h,
|
1121 |
+
stride_w,
|
1122 |
+
dilation_h,
|
1123 |
+
dilation_w,
|
1124 |
+
channel_per_deformable_group,
|
1125 |
+
batch_size,
|
1126 |
+
channels,
|
1127 |
+
deformable_group,
|
1128 |
+
height_col,
|
1129 |
+
width_col,
|
1130 |
+
data_col_);
|
1131 |
+
}));
|
1132 |
+
|
1133 |
+
cudaError_t err = cudaGetLastError();
|
1134 |
+
if (err != cudaSuccess) {
|
1135 |
+
printf(
|
1136 |
+
"error in modulated_deformable_im2col_cuda: %s\n",
|
1137 |
+
cudaGetErrorString(err));
|
1138 |
+
}
|
1139 |
+
}
|
1140 |
+
|
1141 |
+
void modulated_deformable_col2im_cuda(
|
1142 |
+
const at::Tensor data_col,
|
1143 |
+
const at::Tensor data_offset,
|
1144 |
+
const at::Tensor data_mask,
|
1145 |
+
const int batch_size,
|
1146 |
+
const int channels,
|
1147 |
+
const int height_im,
|
1148 |
+
const int width_im,
|
1149 |
+
const int height_col,
|
1150 |
+
const int width_col,
|
1151 |
+
const int kernel_h,
|
1152 |
+
const int kernel_w,
|
1153 |
+
const int pad_h,
|
1154 |
+
const int pad_w,
|
1155 |
+
const int stride_h,
|
1156 |
+
const int stride_w,
|
1157 |
+
const int dilation_h,
|
1158 |
+
const int dilation_w,
|
1159 |
+
const int deformable_group,
|
1160 |
+
at::Tensor grad_im) {
|
1161 |
+
const int channel_per_deformable_group = channels / deformable_group;
|
1162 |
+
const int num_kernels =
|
1163 |
+
channels * kernel_h * kernel_w * batch_size * height_col * width_col;
|
1164 |
+
|
1165 |
+
at::cuda::CUDAGuard device_guard(data_col.device());
|
1166 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
1167 |
+
|
1168 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
1169 |
+
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
|
1170 |
+
const scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
|
1171 |
+
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
|
1172 |
+
const scalar_t* data_mask_ = data_mask.data_ptr<scalar_t>();
|
1173 |
+
scalar_t* grad_im_ = grad_im.data_ptr<scalar_t>();
|
1174 |
+
|
1175 |
+
modulated_deformable_col2im_gpu_kernel<<<
|
1176 |
+
GET_BLOCKS(num_kernels),
|
1177 |
+
CUDA_NUM_THREADS,
|
1178 |
+
0,
|
1179 |
+
stream>>>(
|
1180 |
+
num_kernels,
|
1181 |
+
data_col_,
|
1182 |
+
data_offset_,
|
1183 |
+
data_mask_,
|
1184 |
+
channels,
|
1185 |
+
height_im,
|
1186 |
+
width_im,
|
1187 |
+
kernel_h,
|
1188 |
+
kernel_w,
|
1189 |
+
pad_h,
|
1190 |
+
pad_w,
|
1191 |
+
stride_h,
|
1192 |
+
stride_w,
|
1193 |
+
dilation_h,
|
1194 |
+
dilation_w,
|
1195 |
+
channel_per_deformable_group,
|
1196 |
+
batch_size,
|
1197 |
+
deformable_group,
|
1198 |
+
height_col,
|
1199 |
+
width_col,
|
1200 |
+
grad_im_);
|
1201 |
+
}));
|
1202 |
+
|
1203 |
+
cudaError_t err = cudaGetLastError();
|
1204 |
+
if (err != cudaSuccess) {
|
1205 |
+
printf(
|
1206 |
+
"error in modulated_deformable_col2im_cuda: %s\n",
|
1207 |
+
cudaGetErrorString(err));
|
1208 |
+
}
|
1209 |
+
}
|
1210 |
+
|
1211 |
+
void modulated_deformable_col2im_coord_cuda(
|
1212 |
+
const at::Tensor data_col,
|
1213 |
+
const at::Tensor data_im,
|
1214 |
+
const at::Tensor data_offset,
|
1215 |
+
const at::Tensor data_mask,
|
1216 |
+
const int batch_size,
|
1217 |
+
const int channels,
|
1218 |
+
const int height_im,
|
1219 |
+
const int width_im,
|
1220 |
+
const int height_col,
|
1221 |
+
const int width_col,
|
1222 |
+
const int kernel_h,
|
1223 |
+
const int kernel_w,
|
1224 |
+
const int pad_h,
|
1225 |
+
const int pad_w,
|
1226 |
+
const int stride_h,
|
1227 |
+
const int stride_w,
|
1228 |
+
const int dilation_h,
|
1229 |
+
const int dilation_w,
|
1230 |
+
const int deformable_group,
|
1231 |
+
at::Tensor grad_offset,
|
1232 |
+
at::Tensor grad_mask) {
|
1233 |
+
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h *
|
1234 |
+
kernel_w * deformable_group;
|
1235 |
+
const int channel_per_deformable_group =
|
1236 |
+
channels * kernel_h * kernel_w / deformable_group;
|
1237 |
+
|
1238 |
+
at::cuda::CUDAGuard device_guard(data_col.device());
|
1239 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
1240 |
+
|
1241 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
1242 |
+
data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
|
1243 |
+
const scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
|
1244 |
+
const scalar_t* data_im_ = data_im.data_ptr<scalar_t>();
|
1245 |
+
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
|
1246 |
+
const scalar_t* data_mask_ = data_mask.data_ptr<scalar_t>();
|
1247 |
+
scalar_t* grad_offset_ = grad_offset.data_ptr<scalar_t>();
|
1248 |
+
scalar_t* grad_mask_ = grad_mask.data_ptr<scalar_t>();
|
1249 |
+
|
1250 |
+
modulated_deformable_col2im_coord_gpu_kernel<<<
|
1251 |
+
GET_BLOCKS(num_kernels),
|
1252 |
+
CUDA_NUM_THREADS,
|
1253 |
+
0,
|
1254 |
+
stream>>>(
|
1255 |
+
num_kernels,
|
1256 |
+
data_col_,
|
1257 |
+
data_im_,
|
1258 |
+
data_offset_,
|
1259 |
+
data_mask_,
|
1260 |
+
channels,
|
1261 |
+
height_im,
|
1262 |
+
width_im,
|
1263 |
+
kernel_h,
|
1264 |
+
kernel_w,
|
1265 |
+
pad_h,
|
1266 |
+
pad_w,
|
1267 |
+
stride_h,
|
1268 |
+
stride_w,
|
1269 |
+
dilation_h,
|
1270 |
+
dilation_w,
|
1271 |
+
channel_per_deformable_group,
|
1272 |
+
batch_size,
|
1273 |
+
2 * kernel_h * kernel_w * deformable_group,
|
1274 |
+
deformable_group,
|
1275 |
+
height_col,
|
1276 |
+
width_col,
|
1277 |
+
grad_offset_,
|
1278 |
+
grad_mask_);
|
1279 |
+
}));
|
1280 |
+
cudaError_t err = cudaGetLastError();
|
1281 |
+
if (err != cudaSuccess) {
|
1282 |
+
printf(
|
1283 |
+
"error in modulated_deformable_col2im_coord_cuda: %s\n",
|
1284 |
+
cudaGetErrorString(err));
|
1285 |
+
}
|
1286 |
+
}
|
1287 |
+
|
1288 |
+
} // namespace detectron2
|
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/layers/csrc/nms_rotated/nms_rotated.h
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#pragma once
|
3 |
+
#include <torch/types.h>
|
4 |
+
|
5 |
+
namespace detectron2 {
|
6 |
+
|
7 |
+
at::Tensor nms_rotated_cpu(
|
8 |
+
const at::Tensor& dets,
|
9 |
+
const at::Tensor& scores,
|
10 |
+
const double iou_threshold);
|
11 |
+
|
12 |
+
#if defined(WITH_CUDA) || defined(WITH_HIP)
|
13 |
+
at::Tensor nms_rotated_cuda(
|
14 |
+
const at::Tensor& dets,
|
15 |
+
const at::Tensor& scores,
|
16 |
+
const double iou_threshold);
|
17 |
+
#endif
|
18 |
+
|
19 |
+
// Interface for Python
|
20 |
+
// inline is needed to prevent multiple function definitions when this header is
|
21 |
+
// included by different cpps
|
22 |
+
inline at::Tensor nms_rotated(
|
23 |
+
const at::Tensor& dets,
|
24 |
+
const at::Tensor& scores,
|
25 |
+
const double iou_threshold) {
|
26 |
+
assert(dets.device().is_cuda() == scores.device().is_cuda());
|
27 |
+
if (dets.device().is_cuda()) {
|
28 |
+
#if defined(WITH_CUDA) || defined(WITH_HIP)
|
29 |
+
return nms_rotated_cuda(
|
30 |
+
dets.contiguous(), scores.contiguous(), iou_threshold);
|
31 |
+
#else
|
32 |
+
AT_ERROR("Detectron2 is not compiled with GPU support!");
|
33 |
+
#endif
|
34 |
+
}
|
35 |
+
|
36 |
+
return nms_rotated_cpu(dets.contiguous(), scores.contiguous(), iou_threshold);
|
37 |
+
}
|
38 |
+
|
39 |
+
} // namespace detectron2
|