cc434d1df1b6943b39242aada22ff4706a40fad6c0adf32acd7723aef96ae719
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/modeling/roi_heads/roi_heads.py +877 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/modeling/roi_heads/rotated_fast_rcnn.py +271 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/modeling/sampling.py +54 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/modeling/test_time_augmentation.py +307 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/README.md +2 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/__init__.py +34 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/__init__.py +5 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/build_solver.py +27 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/config.py +28 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/loss.py +40 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/lr_scheduler.py +62 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/resnet.py +158 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/semantic_seg.py +348 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/solver/__init__.py +11 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/solver/build.py +310 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/solver/lr_scheduler.py +246 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/__init__.py +17 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/boxes.py +425 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/image_list.py +129 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/instances.py +194 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/keypoints.py +235 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/masks.py +534 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/rotated_boxes.py +505 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/__init__.py +15 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/base_tracker.py +64 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/bbox_iou_tracker.py +276 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/hungarian_tracker.py +171 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py +102 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/utils.py +40 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py +129 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/README.md +5 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/__init__.py +1 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/analysis.py +188 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/collect_env.py +246 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/colormap.py +158 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/comm.py +238 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/develop.py +59 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/env.py +170 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/events.py +534 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/file_io.py +39 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/logger.py +237 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/memory.py +84 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/registry.py +60 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/serialize.py +32 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/testing.py +478 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/tracing.py +71 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/video_visualizer.py +287 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/visualizer.py +1267 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/__init__.py +9 -0
- extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/config.py +239 -0
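The files below vendor detectron2's ROI-head implementation. As a minimal sketch (not part of this commit, and assuming the vendored copy keeps upstream detectron2's config API and package exports), a custom head registered in `ROI_HEADS_REGISTRY` is picked up by `build_roi_heads` through `cfg.MODEL.ROI_HEADS.NAME`:

```python
# Illustrative only: registry lookup pattern used by roi_heads.py below.
# "MyROIHeads" and the vendored import paths are assumptions, not part of the commit.
from annotator.oneformer.detectron2.config import get_cfg
from annotator.oneformer.detectron2.modeling.roi_heads import (
    ROI_HEADS_REGISTRY,
    StandardROIHeads,
    build_roi_heads,
)

@ROI_HEADS_REGISTRY.register()
class MyROIHeads(StandardROIHeads):  # hypothetical subclass for illustration
    pass

cfg = get_cfg()
cfg.MODEL.ROI_HEADS.NAME = "MyROIHeads"  # build_roi_heads resolves this name in the registry
# roi_heads = build_roi_heads(cfg, backbone_output_shape)  # backbone_output_shape: dict[str, ShapeSpec]
```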
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/modeling/roi_heads/roi_heads.py
ADDED
@@ -0,0 +1,877 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import inspect
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn

from annotator.oneformer.detectron2.config import configurable
from annotator.oneformer.detectron2.layers import ShapeSpec, nonzero_tuple
from annotator.oneformer.detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
from annotator.oneformer.detectron2.utils.events import get_event_storage
from annotator.oneformer.detectron2.utils.registry import Registry

from ..backbone.resnet import BottleneckBlock, ResNet
from ..matcher import Matcher
from ..poolers import ROIPooler
from ..proposal_generator.proposal_utils import add_ground_truth_to_proposals
from ..sampling import subsample_labels
from .box_head import build_box_head
from .fast_rcnn import FastRCNNOutputLayers
from .keypoint_head import build_keypoint_head
from .mask_head import build_mask_head

ROI_HEADS_REGISTRY = Registry("ROI_HEADS")
ROI_HEADS_REGISTRY.__doc__ = """
Registry for ROI heads in a generalized R-CNN model.
ROIHeads take feature maps and region proposals, and
perform per-region computation.

The registered object will be called with `obj(cfg, input_shape)`.
The call is expected to return an :class:`ROIHeads`.
"""

logger = logging.getLogger(__name__)


def build_roi_heads(cfg, input_shape):
    """
    Build ROIHeads defined by `cfg.MODEL.ROI_HEADS.NAME`.
    """
    name = cfg.MODEL.ROI_HEADS.NAME
    return ROI_HEADS_REGISTRY.get(name)(cfg, input_shape)


def select_foreground_proposals(
    proposals: List[Instances], bg_label: int
) -> Tuple[List[Instances], List[torch.Tensor]]:
    """
    Given a list of N Instances (for N images), each containing a `gt_classes` field,
    return a list of Instances that contain only instances with `gt_classes != -1 &&
    gt_classes != bg_label`.

    Args:
        proposals (list[Instances]): A list of N Instances, where N is the number of
            images in the batch.
        bg_label: label index of background class.

    Returns:
        list[Instances]: N Instances, each contains only the selected foreground instances.
        list[Tensor]: N boolean vector, correspond to the selection mask of
            each Instances object. True for selected instances.
    """
    assert isinstance(proposals, (list, tuple))
    assert isinstance(proposals[0], Instances)
    assert proposals[0].has("gt_classes")
    fg_proposals = []
    fg_selection_masks = []
    for proposals_per_image in proposals:
        gt_classes = proposals_per_image.gt_classes
        fg_selection_mask = (gt_classes != -1) & (gt_classes != bg_label)
        fg_idxs = fg_selection_mask.nonzero().squeeze(1)
        fg_proposals.append(proposals_per_image[fg_idxs])
        fg_selection_masks.append(fg_selection_mask)
    return fg_proposals, fg_selection_masks


def select_proposals_with_visible_keypoints(proposals: List[Instances]) -> List[Instances]:
    """
    Args:
        proposals (list[Instances]): a list of N Instances, where N is the
            number of images.

    Returns:
        proposals: only contains proposals with at least one visible keypoint.

    Note that this is still slightly different from Detectron.
    In Detectron, proposals for training keypoint head are re-sampled from
    all the proposals with IOU>threshold & >=1 visible keypoint.

    Here, the proposals are first sampled from all proposals with
    IOU>threshold, then proposals with no visible keypoint are filtered out.
    This strategy seems to make no difference on Detectron and is easier to implement.
    """
    ret = []
    all_num_fg = []
    for proposals_per_image in proposals:
        # If empty/unannotated image (hard negatives), skip filtering for train
        if len(proposals_per_image) == 0:
            ret.append(proposals_per_image)
            continue
        gt_keypoints = proposals_per_image.gt_keypoints.tensor
        # #fg x K x 3
        vis_mask = gt_keypoints[:, :, 2] >= 1
        xs, ys = gt_keypoints[:, :, 0], gt_keypoints[:, :, 1]
        proposal_boxes = proposals_per_image.proposal_boxes.tensor.unsqueeze(dim=1)  # #fg x 1 x 4
        kp_in_box = (
            (xs >= proposal_boxes[:, :, 0])
            & (xs <= proposal_boxes[:, :, 2])
            & (ys >= proposal_boxes[:, :, 1])
            & (ys <= proposal_boxes[:, :, 3])
        )
        selection = (kp_in_box & vis_mask).any(dim=1)
        selection_idxs = nonzero_tuple(selection)[0]
        all_num_fg.append(selection_idxs.numel())
        ret.append(proposals_per_image[selection_idxs])

    storage = get_event_storage()
    storage.put_scalar("keypoint_head/num_fg_samples", np.mean(all_num_fg))
    return ret


class ROIHeads(torch.nn.Module):
    """
    ROIHeads perform all per-region computation in an R-CNN.

    It typically contains logic to

    1. (in training only) match proposals with ground truth and sample them
    2. crop the regions and extract per-region features using proposals
    3. make per-region predictions with different heads

    It can have many variants, implemented as subclasses of this class.
    This base class contains the logic to match/sample proposals.
    But it is not necessary to inherit this class if the sampling logic is not needed.
    """

    @configurable
    def __init__(
        self,
        *,
        num_classes,
        batch_size_per_image,
        positive_fraction,
        proposal_matcher,
        proposal_append_gt=True,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            num_classes (int): number of foreground classes (i.e. background is not included)
            batch_size_per_image (int): number of proposals to sample for training
            positive_fraction (float): fraction of positive (foreground) proposals
                to sample for training.
            proposal_matcher (Matcher): matcher that matches proposals and ground truth
            proposal_append_gt (bool): whether to include ground truth as proposals as well
        """
        super().__init__()
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction
        self.num_classes = num_classes
        self.proposal_matcher = proposal_matcher
        self.proposal_append_gt = proposal_append_gt

    @classmethod
    def from_config(cls, cfg):
        return {
            "batch_size_per_image": cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE,
            "positive_fraction": cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION,
            "num_classes": cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            "proposal_append_gt": cfg.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT,
            # Matcher to assign box proposals to gt boxes
            "proposal_matcher": Matcher(
                cfg.MODEL.ROI_HEADS.IOU_THRESHOLDS,
                cfg.MODEL.ROI_HEADS.IOU_LABELS,
                allow_low_quality_matches=False,
            ),
        }

    def _sample_proposals(
        self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Based on the matching between N proposals and M groundtruth,
        sample the proposals and set their classification labels.

        Args:
            matched_idxs (Tensor): a vector of length N, each is the best-matched
                gt index in [0, M) for each proposal.
            matched_labels (Tensor): a vector of length N, the matcher's label
                (one of cfg.MODEL.ROI_HEADS.IOU_LABELS) for each proposal.
            gt_classes (Tensor): a vector of length M.

        Returns:
            Tensor: a vector of indices of sampled proposals. Each is in [0, N).
            Tensor: a vector of the same length, the classification label for
                each sampled proposal. Each sample is labeled as either a category in
                [0, num_classes) or the background (num_classes).
        """
        has_gt = gt_classes.numel() > 0
        # Get the corresponding GT for each proposal
        if has_gt:
            gt_classes = gt_classes[matched_idxs]
            # Label unmatched proposals (0 label from matcher) as background (label=num_classes)
            gt_classes[matched_labels == 0] = self.num_classes
            # Label ignore proposals (-1 label)
            gt_classes[matched_labels == -1] = -1
        else:
            gt_classes = torch.zeros_like(matched_idxs) + self.num_classes

        sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
            gt_classes, self.batch_size_per_image, self.positive_fraction, self.num_classes
        )

        sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
        return sampled_idxs, gt_classes[sampled_idxs]

    @torch.no_grad()
    def label_and_sample_proposals(
        self, proposals: List[Instances], targets: List[Instances]
    ) -> List[Instances]:
        """
        Prepare some proposals to be used to train the ROI heads.
        It performs box matching between `proposals` and `targets`, and assigns
        training labels to the proposals.
        It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth
        boxes, with a fraction of positives that is no larger than
        ``self.positive_fraction``.

        Args:
            See :meth:`ROIHeads.forward`

        Returns:
            list[Instances]:
                length `N` list of `Instances`s containing the proposals
                sampled for training. Each `Instances` has the following fields:

                - proposal_boxes: the proposal boxes
                - gt_boxes: the ground-truth box that the proposal is assigned to
                  (this is only meaningful if the proposal has a label > 0; if label = 0
                  then the ground-truth box is random)

                Other fields such as "gt_classes", "gt_masks", that's included in `targets`.
        """
        # Augment proposals with ground-truth boxes.
        # In the case of learned proposals (e.g., RPN), when training starts
        # the proposals will be low quality due to random initialization.
        # It's possible that none of these initial
        # proposals have high enough overlap with the gt objects to be used
        # as positive examples for the second stage components (box head,
        # cls head, mask head). Adding the gt boxes to the set of proposals
        # ensures that the second stage components will have some positive
        # examples from the start of training. For RPN, this augmentation improves
        # convergence and empirically improves box AP on COCO by about 0.5
        # points (under one tested configuration).
        if self.proposal_append_gt:
            proposals = add_ground_truth_to_proposals(targets, proposals)

        proposals_with_gt = []

        num_fg_samples = []
        num_bg_samples = []
        for proposals_per_image, targets_per_image in zip(proposals, targets):
            has_gt = len(targets_per_image) > 0
            match_quality_matrix = pairwise_iou(
                targets_per_image.gt_boxes, proposals_per_image.proposal_boxes
            )
            matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix)
            sampled_idxs, gt_classes = self._sample_proposals(
                matched_idxs, matched_labels, targets_per_image.gt_classes
            )

            # Set target attributes of the sampled proposals:
            proposals_per_image = proposals_per_image[sampled_idxs]
            proposals_per_image.gt_classes = gt_classes

            if has_gt:
                sampled_targets = matched_idxs[sampled_idxs]
                # We index all the attributes of targets that start with "gt_"
                # and have not been added to proposals yet (="gt_classes").
                # NOTE: here the indexing waste some compute, because heads
                # like masks, keypoints, etc, will filter the proposals again,
                # (by foreground/background, or number of keypoints in the image, etc)
                # so we essentially index the data twice.
                for (trg_name, trg_value) in targets_per_image.get_fields().items():
                    if trg_name.startswith("gt_") and not proposals_per_image.has(trg_name):
                        proposals_per_image.set(trg_name, trg_value[sampled_targets])
            # If no GT is given in the image, we don't know what a dummy gt value can be.
            # Therefore the returned proposals won't have any gt_* fields, except for a
            # gt_classes full of background label.

            num_bg_samples.append((gt_classes == self.num_classes).sum().item())
            num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1])
            proposals_with_gt.append(proposals_per_image)

        # Log the number of fg/bg samples that are selected for training ROI heads
        storage = get_event_storage()
        storage.put_scalar("roi_head/num_fg_samples", np.mean(num_fg_samples))
        storage.put_scalar("roi_head/num_bg_samples", np.mean(num_bg_samples))

        return proposals_with_gt

    def forward(
        self,
        images: ImageList,
        features: Dict[str, torch.Tensor],
        proposals: List[Instances],
        targets: Optional[List[Instances]] = None,
    ) -> Tuple[List[Instances], Dict[str, torch.Tensor]]:
        """
        Args:
            images (ImageList):
            features (dict[str,Tensor]): input data as a mapping from feature
                map name to tensor. Axis 0 represents the number of images `N` in
                the input data; axes 1-3 are channels, height, and width, which may
                vary between feature maps (e.g., if a feature pyramid is used).
            proposals (list[Instances]): length `N` list of `Instances`. The i-th
                `Instances` contains object proposals for the i-th input image,
                with fields "proposal_boxes" and "objectness_logits".
            targets (list[Instances], optional): length `N` list of `Instances`. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image. Specify `targets` during training only.
                It may have the following fields:

                - gt_boxes: the bounding box of each instance.
                - gt_classes: the label for each instance with a category ranging in [0, #class].
                - gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.
                - gt_keypoints: NxKx3, the groud-truth keypoints for each instance.

        Returns:
            list[Instances]: length `N` list of `Instances` containing the
            detected instances. Returned during inference only; may be [] during training.

            dict[str->Tensor]:
            mapping from a named loss to a tensor storing the loss. Used during training only.
        """
        raise NotImplementedError()


@ROI_HEADS_REGISTRY.register()
class Res5ROIHeads(ROIHeads):
    """
    The ROIHeads in a typical "C4" R-CNN model, where
    the box and mask head share the cropping and
    the per-region feature computation by a Res5 block.
    See :paper:`ResNet` Appendix A.
    """

    @configurable
    def __init__(
        self,
        *,
        in_features: List[str],
        pooler: ROIPooler,
        res5: nn.Module,
        box_predictor: nn.Module,
        mask_head: Optional[nn.Module] = None,
        **kwargs,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            in_features (list[str]): list of backbone feature map names to use for
                feature extraction
            pooler (ROIPooler): pooler to extra region features from backbone
            res5 (nn.Sequential): a CNN to compute per-region features, to be used by
                ``box_predictor`` and ``mask_head``. Typically this is a "res5"
                block from a ResNet.
            box_predictor (nn.Module): make box predictions from the feature.
                Should have the same interface as :class:`FastRCNNOutputLayers`.
            mask_head (nn.Module): transform features to make mask predictions
        """
        super().__init__(**kwargs)
        self.in_features = in_features
        self.pooler = pooler
        if isinstance(res5, (list, tuple)):
            res5 = nn.Sequential(*res5)
        self.res5 = res5
        self.box_predictor = box_predictor
        self.mask_on = mask_head is not None
        if self.mask_on:
            self.mask_head = mask_head

    @classmethod
    def from_config(cls, cfg, input_shape):
        # fmt: off
        ret = super().from_config(cfg)
        in_features = ret["in_features"] = cfg.MODEL.ROI_HEADS.IN_FEATURES
        pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
        pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
        pooler_scales = (1.0 / input_shape[in_features[0]].stride, )
        sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
        mask_on = cfg.MODEL.MASK_ON
        # fmt: on
        assert not cfg.MODEL.KEYPOINT_ON
        assert len(in_features) == 1

        ret["pooler"] = ROIPooler(
            output_size=pooler_resolution,
            scales=pooler_scales,
            sampling_ratio=sampling_ratio,
            pooler_type=pooler_type,
        )

        # Compatbility with old moco code. Might be useful.
        # See notes in StandardROIHeads.from_config
        if not inspect.ismethod(cls._build_res5_block):
            logger.warning(
                "The behavior of _build_res5_block may change. "
                "Please do not depend on private methods."
            )
            cls._build_res5_block = classmethod(cls._build_res5_block)

        ret["res5"], out_channels = cls._build_res5_block(cfg)
        ret["box_predictor"] = FastRCNNOutputLayers(
            cfg, ShapeSpec(channels=out_channels, height=1, width=1)
        )

        if mask_on:
            ret["mask_head"] = build_mask_head(
                cfg,
                ShapeSpec(channels=out_channels, width=pooler_resolution, height=pooler_resolution),
            )
        return ret

    @classmethod
    def _build_res5_block(cls, cfg):
        # fmt: off
        stage_channel_factor = 2 ** 3  # res5 is 8x res2
        num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
        width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
        bottleneck_channels = num_groups * width_per_group * stage_channel_factor
        out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor
        stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
        norm = cfg.MODEL.RESNETS.NORM
        assert not cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE[-1], \
            "Deformable conv is not yet supported in res5 head."
        # fmt: on

        blocks = ResNet.make_stage(
            BottleneckBlock,
            3,
            stride_per_block=[2, 1, 1],
            in_channels=out_channels // 2,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            norm=norm,
            stride_in_1x1=stride_in_1x1,
        )
        return nn.Sequential(*blocks), out_channels

    def _shared_roi_transform(self, features: List[torch.Tensor], boxes: List[Boxes]):
        x = self.pooler(features, boxes)
        return self.res5(x)

    def forward(
        self,
        images: ImageList,
        features: Dict[str, torch.Tensor],
        proposals: List[Instances],
        targets: Optional[List[Instances]] = None,
    ):
        """
        See :meth:`ROIHeads.forward`.
        """
        del images

        if self.training:
            assert targets
            proposals = self.label_and_sample_proposals(proposals, targets)
        del targets

        proposal_boxes = [x.proposal_boxes for x in proposals]
        box_features = self._shared_roi_transform(
            [features[f] for f in self.in_features], proposal_boxes
        )
        predictions = self.box_predictor(box_features.mean(dim=[2, 3]))

        if self.training:
            del features
            losses = self.box_predictor.losses(predictions, proposals)
            if self.mask_on:
                proposals, fg_selection_masks = select_foreground_proposals(
                    proposals, self.num_classes
                )
                # Since the ROI feature transform is shared between boxes and masks,
                # we don't need to recompute features. The mask loss is only defined
                # on foreground proposals, so we need to select out the foreground
                # features.
                mask_features = box_features[torch.cat(fg_selection_masks, dim=0)]
                del box_features
                losses.update(self.mask_head(mask_features, proposals))
            return [], losses
        else:
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            return pred_instances, {}

    def forward_with_given_boxes(
        self, features: Dict[str, torch.Tensor], instances: List[Instances]
    ) -> List[Instances]:
        """
        Use the given boxes in `instances` to produce other (non-box) per-ROI outputs.

        Args:
            features: same as in `forward()`
            instances (list[Instances]): instances to predict other outputs. Expect the keys
                "pred_boxes" and "pred_classes" to exist.

        Returns:
            instances (Instances):
                the same `Instances` object, with extra
                fields such as `pred_masks` or `pred_keypoints`.
        """
        assert not self.training
        assert instances[0].has("pred_boxes") and instances[0].has("pred_classes")

        if self.mask_on:
            feature_list = [features[f] for f in self.in_features]
            x = self._shared_roi_transform(feature_list, [x.pred_boxes for x in instances])
            return self.mask_head(x, instances)
        else:
            return instances


@ROI_HEADS_REGISTRY.register()
class StandardROIHeads(ROIHeads):
    """
    It's "standard" in a sense that there is no ROI transform sharing
    or feature sharing between tasks.
    Each head independently processes the input features by each head's
    own pooler and head.

    This class is used by most models, such as FPN and C5.
    To implement more models, you can subclass it and implement a different
    :meth:`forward()` or a head.
    """

    @configurable
    def __init__(
        self,
        *,
        box_in_features: List[str],
        box_pooler: ROIPooler,
        box_head: nn.Module,
        box_predictor: nn.Module,
        mask_in_features: Optional[List[str]] = None,
        mask_pooler: Optional[ROIPooler] = None,
        mask_head: Optional[nn.Module] = None,
        keypoint_in_features: Optional[List[str]] = None,
        keypoint_pooler: Optional[ROIPooler] = None,
        keypoint_head: Optional[nn.Module] = None,
        train_on_pred_boxes: bool = False,
        **kwargs,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            box_in_features (list[str]): list of feature names to use for the box head.
            box_pooler (ROIPooler): pooler to extra region features for box head
            box_head (nn.Module): transform features to make box predictions
            box_predictor (nn.Module): make box predictions from the feature.
                Should have the same interface as :class:`FastRCNNOutputLayers`.
            mask_in_features (list[str]): list of feature names to use for the mask
                pooler or mask head. None if not using mask head.
            mask_pooler (ROIPooler): pooler to extract region features from image features.
                The mask head will then take region features to make predictions.
                If None, the mask head will directly take the dict of image features
                defined by `mask_in_features`
            mask_head (nn.Module): transform features to make mask predictions
            keypoint_in_features, keypoint_pooler, keypoint_head: similar to ``mask_*``.
            train_on_pred_boxes (bool): whether to use proposal boxes or
                predicted boxes from the box head to train other heads.
        """
        super().__init__(**kwargs)
        # keep self.in_features for backward compatibility
        self.in_features = self.box_in_features = box_in_features
        self.box_pooler = box_pooler
        self.box_head = box_head
        self.box_predictor = box_predictor

        self.mask_on = mask_in_features is not None
        if self.mask_on:
            self.mask_in_features = mask_in_features
            self.mask_pooler = mask_pooler
            self.mask_head = mask_head

        self.keypoint_on = keypoint_in_features is not None
        if self.keypoint_on:
            self.keypoint_in_features = keypoint_in_features
            self.keypoint_pooler = keypoint_pooler
            self.keypoint_head = keypoint_head

        self.train_on_pred_boxes = train_on_pred_boxes

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg)
        ret["train_on_pred_boxes"] = cfg.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES
        # Subclasses that have not been updated to use from_config style construction
        # may have overridden _init_*_head methods. In this case, those overridden methods
        # will not be classmethods and we need to avoid trying to call them here.
        # We test for this with ismethod which only returns True for bound methods of cls.
        # Such subclasses will need to handle calling their overridden _init_*_head methods.
        if inspect.ismethod(cls._init_box_head):
            ret.update(cls._init_box_head(cfg, input_shape))
        if inspect.ismethod(cls._init_mask_head):
            ret.update(cls._init_mask_head(cfg, input_shape))
        if inspect.ismethod(cls._init_keypoint_head):
            ret.update(cls._init_keypoint_head(cfg, input_shape))
        return ret

    @classmethod
    def _init_box_head(cls, cfg, input_shape):
        # fmt: off
        in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
        pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
        pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features)
        sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
        pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
        # fmt: on

        # If StandardROIHeads is applied on multiple feature maps (as in FPN),
        # then we share the same predictors and therefore the channel counts must be the same
        in_channels = [input_shape[f].channels for f in in_features]
        # Check all channel counts are equal
        assert len(set(in_channels)) == 1, in_channels
        in_channels = in_channels[0]

        box_pooler = ROIPooler(
            output_size=pooler_resolution,
            scales=pooler_scales,
            sampling_ratio=sampling_ratio,
            pooler_type=pooler_type,
        )
        # Here we split "box head" and "box predictor", which is mainly due to historical reasons.
        # They are used together so the "box predictor" layers should be part of the "box head".
        # New subclasses of ROIHeads do not need "box predictor"s.
        box_head = build_box_head(
            cfg, ShapeSpec(channels=in_channels, height=pooler_resolution, width=pooler_resolution)
        )
        box_predictor = FastRCNNOutputLayers(cfg, box_head.output_shape)
        return {
            "box_in_features": in_features,
            "box_pooler": box_pooler,
            "box_head": box_head,
            "box_predictor": box_predictor,
        }

    @classmethod
    def _init_mask_head(cls, cfg, input_shape):
        if not cfg.MODEL.MASK_ON:
            return {}
        # fmt: off
        in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
        pooler_resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features)
        sampling_ratio = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO
        pooler_type = cfg.MODEL.ROI_MASK_HEAD.POOLER_TYPE
        # fmt: on

        in_channels = [input_shape[f].channels for f in in_features][0]

        ret = {"mask_in_features": in_features}
        ret["mask_pooler"] = (
            ROIPooler(
                output_size=pooler_resolution,
                scales=pooler_scales,
                sampling_ratio=sampling_ratio,
                pooler_type=pooler_type,
            )
            if pooler_type
            else None
        )
        if pooler_type:
            shape = ShapeSpec(
                channels=in_channels, width=pooler_resolution, height=pooler_resolution
            )
        else:
            shape = {f: input_shape[f] for f in in_features}
        ret["mask_head"] = build_mask_head(cfg, shape)
        return ret

    @classmethod
    def _init_keypoint_head(cls, cfg, input_shape):
        if not cfg.MODEL.KEYPOINT_ON:
            return {}
        # fmt: off
        in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
        pooler_resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION
        pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features)  # noqa
        sampling_ratio = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO
        pooler_type = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE
        # fmt: on

        in_channels = [input_shape[f].channels for f in in_features][0]

        ret = {"keypoint_in_features": in_features}
        ret["keypoint_pooler"] = (
            ROIPooler(
                output_size=pooler_resolution,
                scales=pooler_scales,
                sampling_ratio=sampling_ratio,
                pooler_type=pooler_type,
            )
            if pooler_type
            else None
        )
        if pooler_type:
            shape = ShapeSpec(
                channels=in_channels, width=pooler_resolution, height=pooler_resolution
            )
        else:
            shape = {f: input_shape[f] for f in in_features}
        ret["keypoint_head"] = build_keypoint_head(cfg, shape)
        return ret

    def forward(
        self,
        images: ImageList,
        features: Dict[str, torch.Tensor],
        proposals: List[Instances],
        targets: Optional[List[Instances]] = None,
    ) -> Tuple[List[Instances], Dict[str, torch.Tensor]]:
        """
        See :class:`ROIHeads.forward`.
        """
        del images
        if self.training:
            assert targets, "'targets' argument is required during training"
            proposals = self.label_and_sample_proposals(proposals, targets)
        del targets

        if self.training:
            losses = self._forward_box(features, proposals)
            # Usually the original proposals used by the box head are used by the mask, keypoint
            # heads. But when `self.train_on_pred_boxes is True`, proposals will contain boxes
            # predicted by the box head.
            losses.update(self._forward_mask(features, proposals))
            losses.update(self._forward_keypoint(features, proposals))
            return proposals, losses
        else:
            pred_instances = self._forward_box(features, proposals)
            # During inference cascaded prediction is used: the mask and keypoints heads are only
            # applied to the top scoring box detections.
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            return pred_instances, {}

    def forward_with_given_boxes(
        self, features: Dict[str, torch.Tensor], instances: List[Instances]
    ) -> List[Instances]:
        """
        Use the given boxes in `instances` to produce other (non-box) per-ROI outputs.

        This is useful for downstream tasks where a box is known, but need to obtain
        other attributes (outputs of other heads).
        Test-time augmentation also uses this.

        Args:
            features: same as in `forward()`
            instances (list[Instances]): instances to predict other outputs. Expect the keys
                "pred_boxes" and "pred_classes" to exist.

        Returns:
            list[Instances]:
                the same `Instances` objects, with extra
                fields such as `pred_masks` or `pred_keypoints`.
        """
        assert not self.training
        assert instances[0].has("pred_boxes") and instances[0].has("pred_classes")

        instances = self._forward_mask(features, instances)
        instances = self._forward_keypoint(features, instances)
        return instances

    def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instances]):
        """
        Forward logic of the box prediction branch. If `self.train_on_pred_boxes is True`,
        the function puts predicted boxes in the `proposal_boxes` field of `proposals` argument.

        Args:
            features (dict[str, Tensor]): mapping from feature map names to tensor.
                Same as in :meth:`ROIHeads.forward`.
            proposals (list[Instances]): the per-image object proposals with
                their matching ground truth.
                Each has fields "proposal_boxes", and "objectness_logits",
                "gt_classes", "gt_boxes".

        Returns:
            In training, a dict of losses.
            In inference, a list of `Instances`, the predicted instances.
        """
        features = [features[f] for f in self.box_in_features]
        box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])
        box_features = self.box_head(box_features)
        predictions = self.box_predictor(box_features)
        del box_features

        if self.training:
            losses = self.box_predictor.losses(predictions, proposals)
            # proposals is modified in-place below, so losses must be computed first.
            if self.train_on_pred_boxes:
                with torch.no_grad():
                    pred_boxes = self.box_predictor.predict_boxes_for_gt_classes(
                        predictions, proposals
                    )
                    for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes):
                        proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image)
            return losses
        else:
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            return pred_instances

    def _forward_mask(self, features: Dict[str, torch.Tensor], instances: List[Instances]):
        """
        Forward logic of the mask prediction branch.

        Args:
            features (dict[str, Tensor]): mapping from feature map names to tensor.
                Same as in :meth:`ROIHeads.forward`.
            instances (list[Instances]): the per-image instances to train/predict masks.
                In training, they can be the proposals.
                In inference, they can be the boxes predicted by R-CNN box head.

        Returns:
            In training, a dict of losses.
            In inference, update `instances` with new fields "pred_masks" and return it.
        """
        if not self.mask_on:
            return {} if self.training else instances

        if self.training:
            # head is only trained on positive proposals.
            instances, _ = select_foreground_proposals(instances, self.num_classes)

        if self.mask_pooler is not None:
            features = [features[f] for f in self.mask_in_features]
            boxes = [x.proposal_boxes if self.training else x.pred_boxes for x in instances]
            features = self.mask_pooler(features, boxes)
        else:
            features = {f: features[f] for f in self.mask_in_features}
        return self.mask_head(features, instances)

    def _forward_keypoint(self, features: Dict[str, torch.Tensor], instances: List[Instances]):
        """
        Forward logic of the keypoint prediction branch.

        Args:
            features (dict[str, Tensor]): mapping from feature map names to tensor.
                Same as in :meth:`ROIHeads.forward`.
            instances (list[Instances]): the per-image instances to train/predict keypoints.
                In training, they can be the proposals.
                In inference, they can be the boxes predicted by R-CNN box head.

        Returns:
            In training, a dict of losses.
            In inference, update `instances` with new fields "pred_keypoints" and return it.
        """
        if not self.keypoint_on:
            return {} if self.training else instances

        if self.training:
            # head is only trained on positive proposals with >=1 visible keypoints.
            instances, _ = select_foreground_proposals(instances, self.num_classes)
            instances = select_proposals_with_visible_keypoints(instances)

        if self.keypoint_pooler is not None:
            features = [features[f] for f in self.keypoint_in_features]
            boxes = [x.proposal_boxes if self.training else x.pred_boxes for x in instances]
            features = self.keypoint_pooler(features, boxes)
        else:
            features = {f: features[f] for f in self.keypoint_in_features}
        return self.keypoint_head(features, instances)
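As a rough usage sketch (not part of this commit; the tensor sizes, the "p3" feature name, and the `roi_heads` instance are made up for illustration), the `forward()` contract defined above takes an `ImageList`, a dict of backbone features, and per-image proposal `Instances`, and in eval mode returns predicted instances plus an empty loss dict:

```python
# Illustrative only: calling convention of the ROI heads defined above.
import torch
from annotator.oneformer.detectron2.structures import Boxes, ImageList, Instances

images = ImageList(torch.zeros(1, 3, 64, 64), [(64, 64)])            # N=1 image
features = {"p3": torch.zeros(1, 256, 8, 8)}                         # one feature map, stride 8 (assumed)
prop = Instances((64, 64))
prop.proposal_boxes = Boxes(torch.tensor([[4.0, 4.0, 32.0, 32.0]]))  # one proposal box
prop.objectness_logits = torch.tensor([1.0])

# roi_heads.eval()
# pred_instances, _ = roi_heads(images, features, [prop])  # loss dict is {} in eval mode
```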
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/modeling/roi_heads/rotated_fast_rcnn.py
ADDED
@@ -0,0 +1,271 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
import torch

from annotator.oneformer.detectron2.config import configurable
from annotator.oneformer.detectron2.layers import ShapeSpec, batched_nms_rotated
from annotator.oneformer.detectron2.structures import Instances, RotatedBoxes, pairwise_iou_rotated
from annotator.oneformer.detectron2.utils.events import get_event_storage

from ..box_regression import Box2BoxTransformRotated
from ..poolers import ROIPooler
from ..proposal_generator.proposal_utils import add_ground_truth_to_proposals
from .box_head import build_box_head
from .fast_rcnn import FastRCNNOutputLayers
from .roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads

logger = logging.getLogger(__name__)

"""
Shape shorthand in this module:

    N: number of images in the minibatch
    R: number of ROIs, combined over all images, in the minibatch
    Ri: number of ROIs in image i
    K: number of foreground classes. E.g.,there are 80 foreground classes in COCO.

Naming convention:

    deltas: refers to the 5-d (dx, dy, dw, dh, da) deltas that parameterize the box2box
    transform (see :class:`box_regression.Box2BoxTransformRotated`).

    pred_class_logits: predicted class scores in [-inf, +inf]; use
        softmax(pred_class_logits) to estimate P(class).

    gt_classes: ground-truth classification labels in [0, K], where [0, K) represent
        foreground object classes and K represents the background class.

    pred_proposal_deltas: predicted rotated box2box transform deltas for transforming proposals
        to detection box predictions.

    gt_proposal_deltas: ground-truth rotated box2box transform deltas
"""


def fast_rcnn_inference_rotated(
    boxes, scores, image_shapes, score_thresh, nms_thresh, topk_per_image
):
    """
    Call `fast_rcnn_inference_single_image_rotated` for all images.

    Args:
        boxes (list[Tensor]): A list of Tensors of predicted class-specific or class-agnostic
            boxes for each image. Element i has shape (Ri, K * 5) if doing
            class-specific regression, or (Ri, 5) if doing class-agnostic
            regression, where Ri is the number of predicted objects for image i.
            This is compatible with the output of :meth:`FastRCNNOutputLayers.predict_boxes`.
        scores (list[Tensor]): A list of Tensors of predicted class scores for each image.
            Element i has shape (Ri, K + 1), where Ri is the number of predicted objects
            for image i. Compatible with the output of :meth:`FastRCNNOutputLayers.predict_probs`.
        image_shapes (list[tuple]): A list of (width, height) tuples for each image in the batch.
        score_thresh (float): Only return detections with a confidence score exceeding this
            threshold.
        nms_thresh (float): The threshold to use for box non-maximum suppression. Value in [0, 1].
        topk_per_image (int): The number of top scoring detections to return. Set < 0 to return
            all detections.

    Returns:
        instances: (list[Instances]): A list of N instances, one for each image in the batch,
            that stores the topk most confidence detections.
        kept_indices: (list[Tensor]): A list of 1D tensor of length of N, each element indicates
            the corresponding boxes/scores index in [0, Ri) from the input, for image i.
    """
    result_per_image = [
        fast_rcnn_inference_single_image_rotated(
            boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image
        )
        for scores_per_image, boxes_per_image, image_shape in zip(scores, boxes, image_shapes)
    ]
    return [x[0] for x in result_per_image], [x[1] for x in result_per_image]


@torch.no_grad()
def fast_rcnn_inference_single_image_rotated(
    boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image
):
    """
    Single-image inference. Return rotated bounding-box detection results by thresholding
    on scores and applying rotated non-maximum suppression (Rotated NMS).

    Args:
        Same as `fast_rcnn_inference_rotated`, but with rotated boxes, scores, and image shapes
        per image.

    Returns:
        Same as `fast_rcnn_inference_rotated`, but for only one image.
    """
    valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
    if not valid_mask.all():
        boxes = boxes[valid_mask]
        scores = scores[valid_mask]

    B = 5  # box dimension
    scores = scores[:, :-1]
    num_bbox_reg_classes = boxes.shape[1] // B
    # Convert to Boxes to use the `clip` function ...
    boxes = RotatedBoxes(boxes.reshape(-1, B))
    boxes.clip(image_shape)
    boxes = boxes.tensor.view(-1, num_bbox_reg_classes, B)  # R x C x B
    # Filter results based on detection scores
    filter_mask = scores > score_thresh  # R x K
    # R' x 2. First column contains indices of the R predictions;
    # Second column contains indices of classes.
    filter_inds = filter_mask.nonzero()
    if num_bbox_reg_classes == 1:
        boxes = boxes[filter_inds[:, 0], 0]
    else:
        boxes = boxes[filter_mask]
    scores = scores[filter_mask]

    # Apply per-class Rotated NMS
    keep = batched_nms_rotated(boxes, scores, filter_inds[:, 1], nms_thresh)
    if topk_per_image >= 0:
        keep = keep[:topk_per_image]
    boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]

    result = Instances(image_shape)
    result.pred_boxes = RotatedBoxes(boxes)
    result.scores = scores
    result.pred_classes = filter_inds[:, 1]

    return result, filter_inds[:, 0]


class RotatedFastRCNNOutputLayers(FastRCNNOutputLayers):
    """
    Two linear layers for predicting Rotated Fast R-CNN outputs.
    """

    @classmethod
    def from_config(cls, cfg, input_shape):
        args = super().from_config(cfg, input_shape)
        args["box2box_transform"] = Box2BoxTransformRotated(
            weights=cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS
        )
        return args

    def inference(self, predictions, proposals):
        """
        Returns:
            list[Instances]: same as `fast_rcnn_inference_rotated`.
            list[Tensor]: same as `fast_rcnn_inference_rotated`.
        """
        boxes = self.predict_boxes(predictions, proposals)
        scores = self.predict_probs(predictions, proposals)
        image_shapes = [x.image_size for x in proposals]

        return fast_rcnn_inference_rotated(
            boxes,
            scores,
            image_shapes,
            self.test_score_thresh,
            self.test_nms_thresh,
            self.test_topk_per_image,
        )


@ROI_HEADS_REGISTRY.register()
class RROIHeads(StandardROIHeads):
    """
    This class is used by Rotated Fast R-CNN to detect rotated boxes.
    For now, it only supports box predictions but not mask or keypoints.
    """

    @configurable
    def __init__(self, **kwargs):
        """
        NOTE: this interface is experimental.
        """
        super().__init__(**kwargs)
        assert (
            not self.mask_on and not self.keypoint_on
        ), "Mask/Keypoints not supported in Rotated ROIHeads."
        assert not self.train_on_pred_boxes, "train_on_pred_boxes not implemented for RROIHeads!"

    @classmethod
    def _init_box_head(cls, cfg, input_shape):
        # fmt: off
        in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
        pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
        pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features)
        sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
        pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
        # fmt: on
        assert pooler_type in ["ROIAlignRotated"], pooler_type
        # assume all channel counts are equal
        in_channels = [input_shape[f].channels for f in in_features][0]

        box_pooler = ROIPooler(
            output_size=pooler_resolution,
            scales=pooler_scales,
            sampling_ratio=sampling_ratio,
            pooler_type=pooler_type,
        )
        box_head = build_box_head(
            cfg, ShapeSpec(channels=in_channels, height=pooler_resolution, width=pooler_resolution)
        )
        # This line is the only difference v.s. StandardROIHeads
        box_predictor = RotatedFastRCNNOutputLayers(cfg, box_head.output_shape)
        return {
            "box_in_features": in_features,
            "box_pooler": box_pooler,
            "box_head": box_head,
            "box_predictor": box_predictor,
        }

    @torch.no_grad()
    def label_and_sample_proposals(self, proposals, targets):
        """
        Prepare some proposals to be used to train the RROI heads.
        It performs box matching between `proposals` and `targets`, and assigns
        training labels to the proposals.
        It returns `self.batch_size_per_image` random samples from proposals and groundtruth boxes,
        with a fraction of positives that is no larger than `self.positive_sample_fraction.

        Args:
            See :meth:`StandardROIHeads.forward`

        Returns:
            list[Instances]: length `N` list of `Instances`s containing the proposals
                sampled for training. Each `Instances` has the following fields:
                - proposal_boxes: the rotated proposal boxes
                - gt_boxes: the ground-truth rotated boxes that the proposal is assigned to
                  (this is only meaningful if the proposal has a label > 0; if label = 0
                  then the ground-truth box is random)
                - gt_classes: the ground-truth classification lable for each proposal
        """
        if self.proposal_append_gt:
            proposals = add_ground_truth_to_proposals(targets, proposals)

        proposals_with_gt = []

        num_fg_samples = []
        num_bg_samples = []
        for proposals_per_image, targets_per_image in zip(proposals, targets):
            has_gt = len(targets_per_image) > 0
            match_quality_matrix = pairwise_iou_rotated(
                targets_per_image.gt_boxes, proposals_per_image.proposal_boxes
            )
            matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix)
            sampled_idxs, gt_classes = self._sample_proposals(
                matched_idxs, matched_labels, targets_per_image.gt_classes
|
253 |
+
)
|
254 |
+
|
255 |
+
proposals_per_image = proposals_per_image[sampled_idxs]
|
256 |
+
proposals_per_image.gt_classes = gt_classes
|
257 |
+
|
258 |
+
if has_gt:
|
259 |
+
sampled_targets = matched_idxs[sampled_idxs]
|
260 |
+
proposals_per_image.gt_boxes = targets_per_image.gt_boxes[sampled_targets]
|
261 |
+
|
262 |
+
num_bg_samples.append((gt_classes == self.num_classes).sum().item())
|
263 |
+
num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1])
|
264 |
+
proposals_with_gt.append(proposals_per_image)
|
265 |
+
|
266 |
+
# Log the number of fg/bg samples that are selected for training ROI heads
|
267 |
+
storage = get_event_storage()
|
268 |
+
storage.put_scalar("roi_head/num_fg_samples", np.mean(num_fg_samples))
|
269 |
+
storage.put_scalar("roi_head/num_bg_samples", np.mean(num_bg_samples))
|
270 |
+
|
271 |
+
return proposals_with_gt
|
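The single-image routine above is shape-driven: R proposals with C*5 rotated-box deltas and C+1 class scores go in, and at most `topk_per_image` thresholded, NMS-filtered detections come out. A minimal sketch of exercising it with dummy tensors (not part of the diff; it assumes the detectron2 modules imported at the top of this file resolve):

import torch

num_props, num_classes = 10, 3
dummy_boxes = torch.rand(num_props, num_classes * 5) * 100.0   # (cx, cy, w, h, angle) per class
dummy_scores = torch.rand(num_props, num_classes + 1)          # last column is background, stripped inside
instances, kept_rows = fast_rcnn_inference_single_image_rotated(
    dummy_boxes, dummy_scores, image_shape=(128, 128),
    score_thresh=0.05, nms_thresh=0.5, topk_per_image=100,
)
print(len(instances), instances.pred_boxes.tensor.shape)       # at most 100 detections, each a 5-d rotated box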
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/modeling/sampling.py
ADDED
@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import torch

from annotator.oneformer.detectron2.layers import nonzero_tuple

__all__ = ["subsample_labels"]


def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
    """
    Return `num_samples` (or fewer, if not enough found)
    random samples from `labels` which is a mixture of positives & negatives.
    It will try to return as many positives as possible without
    exceeding `positive_fraction * num_samples`, and then try to
    fill the remaining slots with negatives.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): The total number of labels with value >= 0 to return.
            Values that are not sampled will be filled with -1 (ignore).
        positive_fraction (float): The number of subsampled labels with values > 0
            is `min(num_positives, int(positive_fraction * num_samples))`. The number
            of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is filled with
            negatives. If there are also not enough negatives, then as many elements are
            sampled as is possible.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D vector of indices. The total length of both is `num_samples` or fewer.
    """
    positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
    negative = nonzero_tuple(labels == bg_label)[0]

    num_pos = int(num_samples * positive_fraction)
    # protect against not enough positive examples
    num_pos = min(positive.numel(), num_pos)
    num_neg = num_samples - num_pos
    # protect against not enough negative examples
    num_neg = min(negative.numel(), num_neg)

    # randomly select positive and negative examples
    perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx = positive[perm1]
    neg_idx = negative[perm2]
    return pos_idx, neg_idx
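A small usage sketch (not part of the original file) showing how subsample_labels splits a label vector into sampled foreground/background indices:

import torch

labels = torch.tensor([-1, 0, 2, 0, 1, 0, -1, 3])   # -1 = ignore, 0 = background here
pos_idx, neg_idx = subsample_labels(labels, num_samples=4, positive_fraction=0.5, bg_label=0)
# At most 2 positive indices (the entries labeled 1/2/3); the remaining slots are
# filled with background indices, so 2 negatives are returned as well.
print(pos_idx.tolist(), neg_idx.tolist())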
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/modeling/test_time_augmentation.py
ADDED
@@ -0,0 +1,307 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import numpy as np
from contextlib import contextmanager
from itertools import count
from typing import List
import torch
from fvcore.transforms import HFlipTransform, NoOpTransform
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from annotator.oneformer.detectron2.config import configurable
from annotator.oneformer.detectron2.data.detection_utils import read_image
from annotator.oneformer.detectron2.data.transforms import (
    RandomFlip,
    ResizeShortestEdge,
    ResizeTransform,
    apply_augmentations,
)
from annotator.oneformer.detectron2.structures import Boxes, Instances

from .meta_arch import GeneralizedRCNN
from .postprocessing import detector_postprocess
from .roi_heads.fast_rcnn import fast_rcnn_inference_single_image

__all__ = ["DatasetMapperTTA", "GeneralizedRCNNWithTTA"]


class DatasetMapperTTA:
    """
    Implement test-time augmentation for detection data.
    It is a callable which takes a dataset dict from a detection dataset,
    and returns a list of dataset dicts where the images
    are augmented from the input image by the transformations defined in the config.
    This is used for test-time augmentation.
    """

    @configurable
    def __init__(self, min_sizes: List[int], max_size: int, flip: bool):
        """
        Args:
            min_sizes: list of short-edge size to resize the image to
            max_size: maximum height or width of resized images
            flip: whether to apply flipping augmentation
        """
        self.min_sizes = min_sizes
        self.max_size = max_size
        self.flip = flip

    @classmethod
    def from_config(cls, cfg):
        return {
            "min_sizes": cfg.TEST.AUG.MIN_SIZES,
            "max_size": cfg.TEST.AUG.MAX_SIZE,
            "flip": cfg.TEST.AUG.FLIP,
        }

    def __call__(self, dataset_dict):
        """
        Args:
            dict: a dict in standard model input format. See tutorials for details.

        Returns:
            list[dict]:
                a list of dicts, which contain augmented version of the input image.
                The total number of dicts is ``len(min_sizes) * (2 if flip else 1)``.
                Each dict has field "transforms" which is a TransformList,
                containing the transforms that are used to generate this image.
        """
        numpy_image = dataset_dict["image"].permute(1, 2, 0).numpy()
        shape = numpy_image.shape
        orig_shape = (dataset_dict["height"], dataset_dict["width"])
        if shape[:2] != orig_shape:
            # It transforms the "original" image in the dataset to the input image
            pre_tfm = ResizeTransform(orig_shape[0], orig_shape[1], shape[0], shape[1])
        else:
            pre_tfm = NoOpTransform()

        # Create all combinations of augmentations to use
        aug_candidates = []  # each element is a list[Augmentation]
        for min_size in self.min_sizes:
            resize = ResizeShortestEdge(min_size, self.max_size)
            aug_candidates.append([resize])  # resize only
            if self.flip:
                flip = RandomFlip(prob=1.0)
                aug_candidates.append([resize, flip])  # resize + flip

        # Apply all the augmentations
        ret = []
        for aug in aug_candidates:
            new_image, tfms = apply_augmentations(aug, np.copy(numpy_image))
            torch_image = torch.from_numpy(np.ascontiguousarray(new_image.transpose(2, 0, 1)))

            dic = copy.deepcopy(dataset_dict)
            dic["transforms"] = pre_tfm + tfms
            dic["image"] = torch_image
            ret.append(dic)
        return ret


class GeneralizedRCNNWithTTA(nn.Module):
    """
    A GeneralizedRCNN with test-time augmentation enabled.
    Its :meth:`__call__` method has the same interface as :meth:`GeneralizedRCNN.forward`.
    """

    def __init__(self, cfg, model, tta_mapper=None, batch_size=3):
        """
        Args:
            cfg (CfgNode):
            model (GeneralizedRCNN): a GeneralizedRCNN to apply TTA on.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`.
            batch_size (int): batch the augmented images into this batch size for inference.
        """
        super().__init__()
        if isinstance(model, DistributedDataParallel):
            model = model.module
        assert isinstance(
            model, GeneralizedRCNN
        ), "TTA is only supported on GeneralizedRCNN. Got a model of type {}".format(type(model))
        self.cfg = cfg.clone()
        assert not self.cfg.MODEL.KEYPOINT_ON, "TTA for keypoint is not supported yet"
        assert (
            not self.cfg.MODEL.LOAD_PROPOSALS
        ), "TTA for pre-computed proposals is not supported yet"

        self.model = model

        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size

    @contextmanager
    def _turn_off_roi_heads(self, attrs):
        """
        Open a context where some heads in `model.roi_heads` are temporarily turned off.
        Args:
            attrs (list[str]): the attributes in `model.roi_heads` which can be used
                to turn off a specific head, e.g., "mask_on", "keypoint_on".
        """
        roi_heads = self.model.roi_heads
        old = {}
        for attr in attrs:
            try:
                old[attr] = getattr(roi_heads, attr)
            except AttributeError:
                # The head may not be implemented in certain ROIHeads
                pass

        if len(old.keys()) == 0:
            yield
        else:
            for attr in old.keys():
                setattr(roi_heads, attr, False)
            yield
            for attr in old.keys():
                setattr(roi_heads, attr, old[attr])

    def _batch_inference(self, batched_inputs, detected_instances=None):
        """
        Execute inference on a list of inputs,
        using batch size = self.batch_size, instead of the length of the list.

        Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference`
        """
        if detected_instances is None:
            detected_instances = [None] * len(batched_inputs)

        outputs = []
        inputs, instances = [], []
        for idx, input, instance in zip(count(), batched_inputs, detected_instances):
            inputs.append(input)
            instances.append(instance)
            if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1:
                outputs.extend(
                    self.model.inference(
                        inputs,
                        instances if instances[0] is not None else None,
                        do_postprocess=False,
                    )
                )
                inputs, instances = [], []
        return outputs

    def __call__(self, batched_inputs):
        """
        Same input/output format as :meth:`GeneralizedRCNN.forward`
        """

        def _maybe_read_image(dataset_dict):
            ret = copy.copy(dataset_dict)
            if "image" not in ret:
                image = read_image(ret.pop("file_name"), self.model.input_format)
                image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))  # CHW
                ret["image"] = image
            if "height" not in ret and "width" not in ret:
                ret["height"] = image.shape[1]
                ret["width"] = image.shape[2]
            return ret

        return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs]

    def _inference_one_image(self, input):
        """
        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor

        Returns:
            dict: one output dict
        """
        orig_shape = (input["height"], input["width"])
        augmented_inputs, tfms = self._get_augmented_inputs(input)
        # Detect boxes from all augmented versions
        with self._turn_off_roi_heads(["mask_on", "keypoint_on"]):
            # temporarily disable roi heads
            all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms)
        # merge all detected boxes to obtain final predictions for boxes
        merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape)

        if self.cfg.MODEL.MASK_ON:
            # Use the detected boxes to obtain masks
            augmented_instances = self._rescale_detected_boxes(
                augmented_inputs, merged_instances, tfms
            )
            # run forward on the detected boxes
            outputs = self._batch_inference(augmented_inputs, augmented_instances)
            # Delete now useless variables to avoid being out of memory
            del augmented_inputs, augmented_instances
            # average the predictions
            merged_instances.pred_masks = self._reduce_pred_masks(outputs, tfms)
            merged_instances = detector_postprocess(merged_instances, *orig_shape)
            return {"instances": merged_instances}
        else:
            return {"instances": merged_instances}

    def _get_augmented_inputs(self, input):
        augmented_inputs = self.tta_mapper(input)
        tfms = [x.pop("transforms") for x in augmented_inputs]
        return augmented_inputs, tfms

    def _get_augmented_boxes(self, augmented_inputs, tfms):
        # 1: forward with all augmented images
        outputs = self._batch_inference(augmented_inputs)
        # 2: union the results
        all_boxes = []
        all_scores = []
        all_classes = []
        for output, tfm in zip(outputs, tfms):
            # Need to inverse the transforms on boxes, to obtain results on original image
            pred_boxes = output.pred_boxes.tensor
            original_pred_boxes = tfm.inverse().apply_box(pred_boxes.cpu().numpy())
            all_boxes.append(torch.from_numpy(original_pred_boxes).to(pred_boxes.device))

            all_scores.extend(output.scores)
            all_classes.extend(output.pred_classes)
        all_boxes = torch.cat(all_boxes, dim=0)
        return all_boxes, all_scores, all_classes

    def _merge_detections(self, all_boxes, all_scores, all_classes, shape_hw):
        # select from the union of all results
        num_boxes = len(all_boxes)
        num_classes = self.cfg.MODEL.ROI_HEADS.NUM_CLASSES
        # +1 because fast_rcnn_inference expects background scores as well
        all_scores_2d = torch.zeros(num_boxes, num_classes + 1, device=all_boxes.device)
        for idx, cls, score in zip(count(), all_classes, all_scores):
            all_scores_2d[idx, cls] = score

        merged_instances, _ = fast_rcnn_inference_single_image(
            all_boxes,
            all_scores_2d,
            shape_hw,
            1e-8,
            self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST,
            self.cfg.TEST.DETECTIONS_PER_IMAGE,
        )

        return merged_instances

    def _rescale_detected_boxes(self, augmented_inputs, merged_instances, tfms):
        augmented_instances = []
        for input, tfm in zip(augmented_inputs, tfms):
            # Transform the target box to the augmented image's coordinate space
            pred_boxes = merged_instances.pred_boxes.tensor.cpu().numpy()
            pred_boxes = torch.from_numpy(tfm.apply_box(pred_boxes))

            aug_instances = Instances(
                image_size=input["image"].shape[1:3],
                pred_boxes=Boxes(pred_boxes),
                pred_classes=merged_instances.pred_classes,
                scores=merged_instances.scores,
            )
            augmented_instances.append(aug_instances)
        return augmented_instances

    def _reduce_pred_masks(self, outputs, tfms):
        # Should apply inverse transforms on masks.
        # We assume only resize & flip are used. pred_masks is a scale-invariant
        # representation, so we handle flip specially
        for output, tfm in zip(outputs, tfms):
            if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
                output.pred_masks = output.pred_masks.flip(dims=[3])
        all_pred_masks = torch.stack([o.pred_masks for o in outputs], dim=0)
        avg_pred_masks = torch.mean(all_pred_masks, dim=0)
        return avg_pred_masks
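A rough sketch (not part of the diff) of what DatasetMapperTTA produces for one dataset dict: len(min_sizes) * (2 if flip else 1) augmented copies, each carrying the TransformList that produced it. The constructor is called here with explicit arguments instead of a cfg, which the @configurable wrapper permits:

import torch

mapper = DatasetMapperTTA(min_sizes=[400, 600], max_size=1000, flip=True)
sample = {
    "image": torch.randint(0, 255, (3, 480, 640), dtype=torch.uint8),  # CHW uint8
    "height": 480,
    "width": 640,
}
augmented = mapper(sample)
print(len(augmented))               # 4 dicts: 2 short-edge sizes x (resize, resize + flip)
print(augmented[0]["transforms"])   # TransformList used to generate that copy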
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/README.md
ADDED
@@ -0,0 +1,2 @@
Projects live in the [`projects` directory](../../projects) under the root of this repository, but not here.
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/__init__.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import importlib.abc
import importlib.util
from pathlib import Path

__all__ = []

_PROJECTS = {
    "point_rend": "PointRend",
    "deeplab": "DeepLab",
    "panoptic_deeplab": "Panoptic-DeepLab",
}
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent / "projects"

if _PROJECT_ROOT.is_dir():
    # This is true only for in-place installation (pip install -e, setup.py develop),
    # where setup(package_dir=) does not work: https://github.com/pypa/setuptools/issues/230

    class _D2ProjectsFinder(importlib.abc.MetaPathFinder):
        def find_spec(self, name, path, target=None):
            if not name.startswith("detectron2.projects."):
                return
            project_name = name.split(".")[-1]
            project_dir = _PROJECTS.get(project_name)
            if not project_dir:
                return
            target_file = _PROJECT_ROOT / f"{project_dir}/{project_name}/__init__.py"
            if not target_file.is_file():
                return
            return importlib.util.spec_from_file_location(name, target_file)

    import sys

    sys.meta_path.append(_D2ProjectsFinder())
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .build_solver import build_lr_scheduler
from .config import add_deeplab_config
from .resnet import build_resnet_deeplab_backbone
from .semantic_seg import DeepLabV3Head, DeepLabV3PlusHead
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/build_solver.py
ADDED
@@ -0,0 +1,27 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import torch

from annotator.oneformer.detectron2.config import CfgNode
from annotator.oneformer.detectron2.solver import LRScheduler
from annotator.oneformer.detectron2.solver import build_lr_scheduler as build_d2_lr_scheduler

from .lr_scheduler import WarmupPolyLR


def build_lr_scheduler(cfg: CfgNode, optimizer: torch.optim.Optimizer) -> LRScheduler:
    """
    Build a LR scheduler from config.
    """
    name = cfg.SOLVER.LR_SCHEDULER_NAME
    if name == "WarmupPolyLR":
        return WarmupPolyLR(
            optimizer,
            cfg.SOLVER.MAX_ITER,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD,
            power=cfg.SOLVER.POLY_LR_POWER,
            constant_ending=cfg.SOLVER.POLY_LR_CONSTANT_ENDING,
        )
    else:
        return build_d2_lr_scheduler(cfg, optimizer)
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/config.py
ADDED
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.


def add_deeplab_config(cfg):
    """
    Add config for DeepLab.
    """
    # We retry random cropping until no single category in semantic segmentation GT occupies more
    # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
    cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
    # Used for `poly` learning rate schedule.
    cfg.SOLVER.POLY_LR_POWER = 0.9
    cfg.SOLVER.POLY_LR_CONSTANT_ENDING = 0.0
    # Loss type, choose from `cross_entropy`, `hard_pixel_mining`.
    cfg.MODEL.SEM_SEG_HEAD.LOSS_TYPE = "hard_pixel_mining"
    # DeepLab settings
    cfg.MODEL.SEM_SEG_HEAD.PROJECT_FEATURES = ["res2"]
    cfg.MODEL.SEM_SEG_HEAD.PROJECT_CHANNELS = [48]
    cfg.MODEL.SEM_SEG_HEAD.ASPP_CHANNELS = 256
    cfg.MODEL.SEM_SEG_HEAD.ASPP_DILATIONS = [6, 12, 18]
    cfg.MODEL.SEM_SEG_HEAD.ASPP_DROPOUT = 0.1
    cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV = False
    # Backbone new configs
    cfg.MODEL.RESNETS.RES4_DILATION = 1
    cfg.MODEL.RESNETS.RES5_MULTI_GRID = [1, 2, 4]
    # ResNet stem type from: `basic`, `deeplab`
    cfg.MODEL.RESNETS.STEM_TYPE = "deeplab"
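A usage sketch (not part of the diff) tying this config helper to the scheduler factory from build_solver.py above. It assumes the vendored detectron2 exposes get_cfg at the same path as upstream detectron2:

import torch
from annotator.oneformer.detectron2.config import get_cfg
from annotator.oneformer.detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler

cfg = get_cfg()
add_deeplab_config(cfg)
print(cfg.SOLVER.POLY_LR_POWER, cfg.MODEL.SEM_SEG_HEAD.LOSS_TYPE)   # 0.9 hard_pixel_mining

# The new SOLVER keys feed the WarmupPolyLR branch of build_lr_scheduler.
cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupPolyLR"
opt = torch.optim.SGD(torch.nn.Linear(8, 2).parameters(), lr=0.01)
print(type(build_lr_scheduler(cfg, opt)).__name__)                  # WarmupPolyLR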
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/loss.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
import torch.nn as nn


class DeepLabCE(nn.Module):
    """
    Hard pixel mining with cross entropy loss, for semantic segmentation.
    This is used in TensorFlow DeepLab frameworks.
    Paper: DeeperLab: Single-Shot Image Parser
    Reference: https://github.com/tensorflow/models/blob/bd488858d610e44df69da6f89277e9de8a03722c/research/deeplab/utils/train_utils.py#L33  # noqa
    Arguments:
        ignore_label: Integer, label to ignore.
        top_k_percent_pixels: Float, the value lies in [0.0, 1.0]. When its
            value < 1.0, only compute the loss for the top k percent pixels
            (e.g., the top 20% pixels). This is useful for hard pixel mining.
        weight: Tensor, a manual rescaling weight given to each class.
    """

    def __init__(self, ignore_label=-1, top_k_percent_pixels=1.0, weight=None):
        super(DeepLabCE, self).__init__()
        self.top_k_percent_pixels = top_k_percent_pixels
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(
            weight=weight, ignore_index=ignore_label, reduction="none"
        )

    def forward(self, logits, labels, weights=None):
        if weights is None:
            pixel_losses = self.criterion(logits, labels).contiguous().view(-1)
        else:
            # Apply per-pixel loss weights.
            pixel_losses = self.criterion(logits, labels) * weights
            pixel_losses = pixel_losses.contiguous().view(-1)
        if self.top_k_percent_pixels == 1.0:
            return pixel_losses.mean()

        top_k_pixels = int(self.top_k_percent_pixels * pixel_losses.numel())
        pixel_losses, _ = torch.topk(pixel_losses, top_k_pixels)
        return pixel_losses.mean()
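A tiny sketch (not part of the diff) of the hard-pixel-mining behaviour: with top_k_percent_pixels=0.2 only the 20% largest per-pixel losses are averaged, instead of all of them.

import torch

criterion = DeepLabCE(ignore_label=255, top_k_percent_pixels=0.2)
logits = torch.randn(2, 19, 64, 64)                 # N x C x H x W class scores
labels = torch.randint(0, 19, (2, 64, 64))          # N x H x W ground-truth classes
loss = criterion(logits, labels)
print(loss.item())                                  # mean of the hardest 20% of the 2*64*64 pixel losses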
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/lr_scheduler.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import math
from typing import List
import torch

from annotator.oneformer.detectron2.solver.lr_scheduler import LRScheduler, _get_warmup_factor_at_iter

# NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes
# only on epoch boundaries. We typically use iteration based schedules instead.
# As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean
# "iteration" instead.

# FIXME: ideally this would be achieved with a CombinedLRScheduler, separating
# MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it.


class WarmupPolyLR(LRScheduler):
    """
    Poly learning rate schedule used to train DeepLab.
    Paper: DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,
        Atrous Convolution, and Fully Connected CRFs.
    Reference: https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/utils/train_utils.py#L337  # noqa
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        max_iters: int,
        warmup_factor: float = 0.001,
        warmup_iters: int = 1000,
        warmup_method: str = "linear",
        last_epoch: int = -1,
        power: float = 0.9,
        constant_ending: float = 0.0,
    ):
        self.max_iters = max_iters
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        self.power = power
        self.constant_ending = constant_ending
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        if self.constant_ending > 0 and warmup_factor == 1.0:
            # Constant ending lr.
            if (
                math.pow((1.0 - self.last_epoch / self.max_iters), self.power)
                < self.constant_ending
            ):
                return [base_lr * self.constant_ending for base_lr in self.base_lrs]
        return [
            base_lr * warmup_factor * math.pow((1.0 - self.last_epoch / self.max_iters), self.power)
            for base_lr in self.base_lrs
        ]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()
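A minimal sketch (not part of the diff) of the resulting schedule: a linear warmup for warmup_iters steps, then lr = base_lr * (1 - iter/max_iters) ** power.

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = WarmupPolyLR(opt, max_iters=100, warmup_iters=10, warmup_factor=0.001, power=0.9)
lrs = []
for _ in range(100):
    lrs.append(opt.param_groups[0]["lr"])
    opt.step()
    sched.step()
print(lrs[0], lrs[10], lrs[50], lrs[99])
# ~1e-4 at the start of warmup, then 0.1 * (1 - t/100) ** 0.9 decaying towards 0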
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/resnet.py
ADDED
@@ -0,0 +1,158 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import fvcore.nn.weight_init as weight_init
import torch.nn.functional as F

from annotator.oneformer.detectron2.layers import CNNBlockBase, Conv2d, get_norm
from annotator.oneformer.detectron2.modeling import BACKBONE_REGISTRY
from annotator.oneformer.detectron2.modeling.backbone.resnet import (
    BasicStem,
    BottleneckBlock,
    DeformBottleneckBlock,
    ResNet,
)


class DeepLabStem(CNNBlockBase):
    """
    The DeepLab ResNet stem (layers before the first residual block).
    """

    def __init__(self, in_channels=3, out_channels=128, norm="BN"):
        """
        Args:
            norm (str or callable): norm after the first conv layer.
                See :func:`layers.get_norm` for supported format.
        """
        super().__init__(in_channels, out_channels, 4)
        self.in_channels = in_channels
        self.conv1 = Conv2d(
            in_channels,
            out_channels // 2,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
            norm=get_norm(norm, out_channels // 2),
        )
        self.conv2 = Conv2d(
            out_channels // 2,
            out_channels // 2,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            norm=get_norm(norm, out_channels // 2),
        )
        self.conv3 = Conv2d(
            out_channels // 2,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )
        weight_init.c2_msra_fill(self.conv1)
        weight_init.c2_msra_fill(self.conv2)
        weight_init.c2_msra_fill(self.conv3)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu_(x)
        x = self.conv2(x)
        x = F.relu_(x)
        x = self.conv3(x)
        x = F.relu_(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        return x


@BACKBONE_REGISTRY.register()
def build_resnet_deeplab_backbone(cfg, input_shape):
    """
    Create a ResNet instance from config.
    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    # need registration of new blocks/stems?
    norm = cfg.MODEL.RESNETS.NORM
    if cfg.MODEL.RESNETS.STEM_TYPE == "basic":
        stem = BasicStem(
            in_channels=input_shape.channels,
            out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
            norm=norm,
        )
    elif cfg.MODEL.RESNETS.STEM_TYPE == "deeplab":
        stem = DeepLabStem(
            in_channels=input_shape.channels,
            out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
            norm=norm,
        )
    else:
        raise ValueError("Unknown stem type: {}".format(cfg.MODEL.RESNETS.STEM_TYPE))

    # fmt: off
    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = cfg.MODEL.RESNETS.OUT_FEATURES
    depth = cfg.MODEL.RESNETS.DEPTH
    num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
    width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
    out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
    res4_dilation = cfg.MODEL.RESNETS.RES4_DILATION
    res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION
    deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
    deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED
    deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
    res5_multi_grid = cfg.MODEL.RESNETS.RES5_MULTI_GRID
    # fmt: on
    assert res4_dilation in {1, 2}, "res4_dilation cannot be {}.".format(res4_dilation)
    assert res5_dilation in {1, 2, 4}, "res5_dilation cannot be {}.".format(res5_dilation)
    if res4_dilation == 2:
        # Always dilate res5 if res4 is dilated.
        assert res5_dilation == 4

    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]

    stages = []

    # Avoid creating variables without gradients
    # It consumes extra memory and may cause allreduce to fail
    out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features]
    max_stage_idx = max(out_stage_idx)
    for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
        if stage_idx == 4:
            dilation = res4_dilation
        elif stage_idx == 5:
            dilation = res5_dilation
        else:
            dilation = 1
        first_stride = 1 if idx == 0 or dilation > 1 else 2
        stage_kargs = {
            "num_blocks": num_blocks_per_stage[idx],
            "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1),
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm": norm,
        }
        stage_kargs["bottleneck_channels"] = bottleneck_channels
        stage_kargs["stride_in_1x1"] = stride_in_1x1
        stage_kargs["dilation"] = dilation
        stage_kargs["num_groups"] = num_groups
        if deform_on_per_stage[idx]:
            stage_kargs["block_class"] = DeformBottleneckBlock
            stage_kargs["deform_modulated"] = deform_modulated
            stage_kargs["deform_num_groups"] = deform_num_groups
        else:
            stage_kargs["block_class"] = BottleneckBlock
        if stage_idx == 5:
            stage_kargs.pop("dilation")
            stage_kargs["dilation_per_block"] = [dilation * mg for mg in res5_multi_grid]
        blocks = ResNet.make_stage(**stage_kargs)
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2
        stages.append(blocks)
    return ResNet(stem, stages, out_features=out_features).freeze(freeze_at)
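A small build sketch (not part of the diff) for the backbone factory above. It assumes the vendored detectron2 exposes get_cfg and ShapeSpec at the same paths as upstream, and it reuses add_deeplab_config from this project to add the RES4_DILATION / RES5_MULTI_GRID / STEM_TYPE keys the factory reads:

from annotator.oneformer.detectron2.config import get_cfg
from annotator.oneformer.detectron2.layers import ShapeSpec
from annotator.oneformer.detectron2.projects.deeplab import add_deeplab_config

cfg = get_cfg()
add_deeplab_config(cfg)                              # STEM_TYPE defaults to "deeplab" afterwards
cfg.MODEL.RESNETS.DEPTH = 50
cfg.MODEL.RESNETS.RES5_DILATION = 2                  # dilate res5 instead of striding it
cfg.MODEL.RESNETS.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
backbone = build_resnet_deeplab_backbone(cfg, ShapeSpec(channels=3))
print({k: (v.channels, v.stride) for k, v in backbone.output_shape().items()})
# res2..res4 at strides 4/8/16; res5 stays at stride 16 because of the dilation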
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/projects/deeplab/semantic_seg.py
ADDED
@@ -0,0 +1,348 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Callable, Dict, List, Optional, Tuple, Union
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from annotator.oneformer.detectron2.config import configurable
from annotator.oneformer.detectron2.layers import ASPP, Conv2d, DepthwiseSeparableConv2d, ShapeSpec, get_norm
from annotator.oneformer.detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from .loss import DeepLabCE


@SEM_SEG_HEADS_REGISTRY.register()
class DeepLabV3PlusHead(nn.Module):
    """
    A semantic segmentation head described in :paper:`DeepLabV3+`.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        project_channels: List[int],
        aspp_dilations: List[int],
        aspp_dropout: float,
        decoder_channels: List[int],
        common_stride: int,
        norm: Union[str, Callable],
        train_size: Optional[Tuple],
        loss_weight: float = 1.0,
        loss_type: str = "cross_entropy",
        ignore_value: int = -1,
        num_classes: Optional[int] = None,
        use_depthwise_separable_conv: bool = False,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: shape of the input features. They will be ordered by stride
                and the last one (with largest stride) is used as the input to the
                decoder (i.e. the ASPP module); the rest are low-level features for
                the intermediate levels of decoder.
            project_channels (list[int]): a list of low-level feature channels.
                The length should be len(in_features) - 1.
            aspp_dilations (list(int)): a list of 3 dilations in ASPP.
            aspp_dropout (float): apply dropout on the output of ASPP.
            decoder_channels (list[int]): a list of output channels of each
                decoder stage. It should have the same length as "in_features"
                (each element in "in_features" corresponds to one decoder stage).
            common_stride (int): output stride of decoder.
            norm (str or callable): normalization for all conv layers.
            train_size (tuple): (height, width) of training images.
            loss_weight (float): loss weight.
            loss_type (str): type of loss function, 2 options:
                (1) "cross_entropy" is the standard cross entropy loss.
                (2) "hard_pixel_mining" is the loss in DeepLab that samples
                    top k% hardest pixels.
            ignore_value (int): category to be ignored during training.
            num_classes (int): number of classes, if set to None, the decoder
                will not construct a predictor.
            use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d
                in ASPP and decoder.
        """
        super().__init__()
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)

        # fmt: off
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        in_channels = [x[1].channels for x in input_shape]
        in_strides = [x[1].stride for x in input_shape]
        aspp_channels = decoder_channels[-1]
        self.ignore_value = ignore_value
        self.common_stride = common_stride  # output stride
        self.loss_weight = loss_weight
        self.loss_type = loss_type
        self.decoder_only = num_classes is None
        self.use_depthwise_separable_conv = use_depthwise_separable_conv
        # fmt: on

        assert (
            len(project_channels) == len(self.in_features) - 1
        ), "Expected {} project_channels, got {}".format(
            len(self.in_features) - 1, len(project_channels)
        )
        assert len(decoder_channels) == len(
            self.in_features
        ), "Expected {} decoder_channels, got {}".format(
            len(self.in_features), len(decoder_channels)
        )
        self.decoder = nn.ModuleDict()

        use_bias = norm == ""
        for idx, in_channel in enumerate(in_channels):
            decoder_stage = nn.ModuleDict()

            if idx == len(self.in_features) - 1:
                # ASPP module
                if train_size is not None:
                    train_h, train_w = train_size
                    encoder_stride = in_strides[-1]
                    if train_h % encoder_stride or train_w % encoder_stride:
                        raise ValueError("Crop size need to be divisible by encoder stride.")
                    pool_h = train_h // encoder_stride
                    pool_w = train_w // encoder_stride
                    pool_kernel_size = (pool_h, pool_w)
                else:
                    pool_kernel_size = None
                project_conv = ASPP(
                    in_channel,
                    aspp_channels,
                    aspp_dilations,
                    norm=norm,
                    activation=F.relu,
                    pool_kernel_size=pool_kernel_size,
                    dropout=aspp_dropout,
                    use_depthwise_separable_conv=use_depthwise_separable_conv,
                )
                fuse_conv = None
            else:
                project_conv = Conv2d(
                    in_channel,
                    project_channels[idx],
                    kernel_size=1,
                    bias=use_bias,
                    norm=get_norm(norm, project_channels[idx]),
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(project_conv)
                if use_depthwise_separable_conv:
                    # We use a single 5x5 DepthwiseSeparableConv2d to replace
                    # 2 3x3 Conv2d since they have the same receptive field,
                    # proposed in :paper:`Panoptic-DeepLab`.
                    fuse_conv = DepthwiseSeparableConv2d(
                        project_channels[idx] + decoder_channels[idx + 1],
                        decoder_channels[idx],
                        kernel_size=5,
                        padding=2,
                        norm1=norm,
                        activation1=F.relu,
                        norm2=norm,
                        activation2=F.relu,
                    )
                else:
                    fuse_conv = nn.Sequential(
                        Conv2d(
                            project_channels[idx] + decoder_channels[idx + 1],
                            decoder_channels[idx],
                            kernel_size=3,
                            padding=1,
                            bias=use_bias,
                            norm=get_norm(norm, decoder_channels[idx]),
                            activation=F.relu,
                        ),
                        Conv2d(
                            decoder_channels[idx],
                            decoder_channels[idx],
                            kernel_size=3,
                            padding=1,
                            bias=use_bias,
                            norm=get_norm(norm, decoder_channels[idx]),
                            activation=F.relu,
                        ),
                    )
                    weight_init.c2_xavier_fill(fuse_conv[0])
                    weight_init.c2_xavier_fill(fuse_conv[1])

            decoder_stage["project_conv"] = project_conv
            decoder_stage["fuse_conv"] = fuse_conv

            self.decoder[self.in_features[idx]] = decoder_stage

        if not self.decoder_only:
            self.predictor = Conv2d(
                decoder_channels[0], num_classes, kernel_size=1, stride=1, padding=0
            )
            nn.init.normal_(self.predictor.weight, 0, 0.001)
            nn.init.constant_(self.predictor.bias, 0)

        if self.loss_type == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(reduction="mean", ignore_index=self.ignore_value)
        elif self.loss_type == "hard_pixel_mining":
            self.loss = DeepLabCE(ignore_label=self.ignore_value, top_k_percent_pixels=0.2)
        else:
            raise ValueError("Unexpected loss type: %s" % self.loss_type)

    @classmethod
    def from_config(cls, cfg, input_shape):
        if cfg.INPUT.CROP.ENABLED:
            assert cfg.INPUT.CROP.TYPE == "absolute"
            train_size = cfg.INPUT.CROP.SIZE
        else:
            train_size = None
        decoder_channels = [cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM] * (
            len(cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES) - 1
        ) + [cfg.MODEL.SEM_SEG_HEAD.ASPP_CHANNELS]
        ret = dict(
            input_shape={
                k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
            },
            project_channels=cfg.MODEL.SEM_SEG_HEAD.PROJECT_CHANNELS,
            aspp_dilations=cfg.MODEL.SEM_SEG_HEAD.ASPP_DILATIONS,
            aspp_dropout=cfg.MODEL.SEM_SEG_HEAD.ASPP_DROPOUT,
            decoder_channels=decoder_channels,
            common_stride=cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE,
            norm=cfg.MODEL.SEM_SEG_HEAD.NORM,
            train_size=train_size,
            loss_weight=cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
            loss_type=cfg.MODEL.SEM_SEG_HEAD.LOSS_TYPE,
            ignore_value=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
            num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            use_depthwise_separable_conv=cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV,
        )
        return ret

    def forward(self, features, targets=None):
        """
        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        y = self.layers(features)
        if self.decoder_only:
            # Output from self.layers() only contains decoder feature.
            return y
        if self.training:
            return None, self.losses(y, targets)
        else:
            y = F.interpolate(
                y, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return y, {}

    def layers(self, features):
        # Reverse feature maps into top-down order (from low to high resolution)
        for f in self.in_features[::-1]:
            x = features[f]
            proj_x = self.decoder[f]["project_conv"](x)
            if self.decoder[f]["fuse_conv"] is None:
                # This is aspp module
                y = proj_x
            else:
                # Upsample y
                y = F.interpolate(y, size=proj_x.size()[2:], mode="bilinear", align_corners=False)
                y = torch.cat([proj_x, y], dim=1)
                y = self.decoder[f]["fuse_conv"](y)
        if not self.decoder_only:
            y = self.predictor(y)
        return y

    def losses(self, predictions, targets):
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        loss = self.loss(predictions, targets)
        losses = {"loss_sem_seg": loss * self.loss_weight}
        return losses


@SEM_SEG_HEADS_REGISTRY.register()
class DeepLabV3Head(nn.Module):
    """
    A semantic segmentation head described in :paper:`DeepLabV3`.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()

        # fmt: off
        self.in_features = cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
        in_channels = [input_shape[f].channels for f in self.in_features]
        aspp_channels = cfg.MODEL.SEM_SEG_HEAD.ASPP_CHANNELS
        aspp_dilations = cfg.MODEL.SEM_SEG_HEAD.ASPP_DILATIONS
        self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
        num_classes = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        conv_dims = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        self.common_stride = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE  # output stride
        norm = cfg.MODEL.SEM_SEG_HEAD.NORM
        self.loss_weight = cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT
        self.loss_type = cfg.MODEL.SEM_SEG_HEAD.LOSS_TYPE
        train_crop_size = cfg.INPUT.CROP.SIZE
        aspp_dropout = cfg.MODEL.SEM_SEG_HEAD.ASPP_DROPOUT
        use_depthwise_separable_conv = cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV
        # fmt: on

        assert len(self.in_features) == 1
        assert len(in_channels) == 1

        # ASPP module
        if cfg.INPUT.CROP.ENABLED:
            assert cfg.INPUT.CROP.TYPE == "absolute"
            train_crop_h, train_crop_w = train_crop_size
            if train_crop_h % self.common_stride or train_crop_w % self.common_stride:
                raise ValueError("Crop size need to be divisible by output stride.")
            pool_h = train_crop_h // self.common_stride
            pool_w = train_crop_w // self.common_stride
            pool_kernel_size = (pool_h, pool_w)
        else:
            pool_kernel_size = None
        self.aspp = ASPP(
            in_channels[0],
            aspp_channels,
            aspp_dilations,
            norm=norm,
            activation=F.relu,
            pool_kernel_size=pool_kernel_size,
            dropout=aspp_dropout,
            use_depthwise_separable_conv=use_depthwise_separable_conv,
        )

        self.predictor = Conv2d(conv_dims, num_classes, kernel_size=1, stride=1, padding=0)
        nn.init.normal_(self.predictor.weight, 0, 0.001)
        nn.init.constant_(self.predictor.bias, 0)

        if self.loss_type == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(reduction="mean", ignore_index=self.ignore_value)
        elif self.loss_type == "hard_pixel_mining":
            self.loss = DeepLabCE(ignore_label=self.ignore_value, top_k_percent_pixels=0.2)
        else:
            raise ValueError("Unexpected loss type: %s" % self.loss_type)

    def forward(self, features, targets=None):
        """
        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        x = features[self.in_features[0]]
        x = self.aspp(x)
        x = self.predictor(x)
        if self.training:
            return None, self.losses(x, targets)
        else:
            x = F.interpolate(
                x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return x, {}

    def losses(self, predictions, targets):
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        loss = self.loss(predictions, targets)
        losses = {"loss_sem_seg": loss * self.loss_weight}
        return losses
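A shape-level sketch (not part of the diff) of the DeepLabV3+ decoder above, built directly through the experimental constructor with two feature levels: "res5" goes through ASPP and "res2" is projected and fused while upsampling, giving a stride-4 prediction that forward() scales back up by common_stride. Feature names and channel counts here are illustrative assumptions:

import torch
from annotator.oneformer.detectron2.layers import ShapeSpec

head = DeepLabV3PlusHead(
    input_shape={"res2": ShapeSpec(channels=256, stride=4),
                 "res5": ShapeSpec(channels=2048, stride=16)},
    project_channels=[48],
    aspp_dilations=[6, 12, 18],
    aspp_dropout=0.1,
    decoder_channels=[256, 256],
    common_stride=4,
    norm="BN",
    train_size=None,
    num_classes=19,
)
head.eval()
features = {"res2": torch.randn(1, 256, 64, 64), "res5": torch.randn(1, 2048, 16, 16)}
logits, _ = head(features)
print(logits.shape)   # (1, 19, 256, 256): stride-4 prediction upsampled by common_stride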
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/solver/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .build import build_lr_scheduler, build_optimizer, get_default_optimizer_params
from .lr_scheduler import (
    LRMultiplier,
    LRScheduler,
    WarmupCosineLR,
    WarmupMultiStepLR,
    WarmupParamScheduler,
)

__all__ = [k for k in globals().keys() if not k.startswith("_")]
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/solver/build.py
ADDED
@@ -0,0 +1,310 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import itertools
import logging
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
import torch
from fvcore.common.param_scheduler import (
    CosineParamScheduler,
    MultiStepParamScheduler,
    StepWithFixedGammaParamScheduler,
)

from annotator.oneformer.detectron2.config import CfgNode
from annotator.oneformer.detectron2.utils.env import TORCH_VERSION

from .lr_scheduler import LRMultiplier, LRScheduler, WarmupParamScheduler

_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
_GradientClipper = Callable[[_GradientClipperInput], None]


class GradientClipType(Enum):
    VALUE = "value"
    NORM = "norm"


def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
    """
    Creates gradient clipping closure to clip by value or by norm,
    according to the provided config.
    """
    cfg = copy.deepcopy(cfg)

    def clip_grad_norm(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)

    def clip_grad_value(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)

    _GRADIENT_CLIP_TYPE_TO_CLIPPER = {
        GradientClipType.VALUE: clip_grad_value,
        GradientClipType.NORM: clip_grad_norm,
    }
    return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]


def _generate_optimizer_class_with_gradient_clipping(
    optimizer: Type[torch.optim.Optimizer],
    *,
    per_param_clipper: Optional[_GradientClipper] = None,
    global_clipper: Optional[_GradientClipper] = None,
) -> Type[torch.optim.Optimizer]:
    """
    Dynamically creates a new type that inherits the type of a given instance
    and overrides the `step` method to add gradient clipping
    """
    assert (
        per_param_clipper is None or global_clipper is None
    ), "Not allowed to use both per-parameter clipping and global clipping"

    def optimizer_wgc_step(self, closure=None):
        if per_param_clipper is not None:
            for group in self.param_groups:
                for p in group["params"]:
                    per_param_clipper(p)
        else:
            # global clipper for future use with detr
            # (https://github.com/facebookresearch/detr/pull/287)
            all_params = itertools.chain(*[g["params"] for g in self.param_groups])
            global_clipper(all_params)
        super(type(self), self).step(closure)

    OptimizerWithGradientClip = type(
        optimizer.__name__ + "WithGradientClip",
        (optimizer,),
        {"step": optimizer_wgc_step},
    )
    return OptimizerWithGradientClip


def maybe_add_gradient_clipping(
    cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
    """
    If gradient clipping is enabled through config options, wraps the existing
    optimizer type to become a new dynamically created class OptimizerWithGradientClip
    that inherits the given optimizer and overrides the `step` method to
    include gradient clipping.

    Args:
        cfg: CfgNode, configuration options
        optimizer: type. A subclass of torch.optim.Optimizer

    Return:
        type: either the input `optimizer` (if gradient clipping is disabled), or
            a subclass of it with gradient clipping included in the `step` method.
    """
    if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
        return optimizer
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer_type = type(optimizer)
    else:
|
105 |
+
assert issubclass(optimizer, torch.optim.Optimizer), optimizer
|
106 |
+
optimizer_type = optimizer
|
107 |
+
|
108 |
+
grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
|
109 |
+
OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
|
110 |
+
optimizer_type, per_param_clipper=grad_clipper
|
111 |
+
)
|
112 |
+
if isinstance(optimizer, torch.optim.Optimizer):
|
113 |
+
optimizer.__class__ = OptimizerWithGradientClip # a bit hacky, not recommended
|
114 |
+
return optimizer
|
115 |
+
else:
|
116 |
+
return OptimizerWithGradientClip
|
117 |
+
|
118 |
+
|
119 |
+
def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
|
120 |
+
"""
|
121 |
+
Build an optimizer from config.
|
122 |
+
"""
|
123 |
+
params = get_default_optimizer_params(
|
124 |
+
model,
|
125 |
+
base_lr=cfg.SOLVER.BASE_LR,
|
126 |
+
weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
|
127 |
+
bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
|
128 |
+
weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
|
129 |
+
)
|
130 |
+
sgd_args = {
|
131 |
+
"params": params,
|
132 |
+
"lr": cfg.SOLVER.BASE_LR,
|
133 |
+
"momentum": cfg.SOLVER.MOMENTUM,
|
134 |
+
"nesterov": cfg.SOLVER.NESTEROV,
|
135 |
+
"weight_decay": cfg.SOLVER.WEIGHT_DECAY,
|
136 |
+
}
|
137 |
+
if TORCH_VERSION >= (1, 12):
|
138 |
+
sgd_args["foreach"] = True
|
139 |
+
return maybe_add_gradient_clipping(cfg, torch.optim.SGD(**sgd_args))
|
140 |
+
|
141 |
+
|
142 |
+
def get_default_optimizer_params(
|
143 |
+
model: torch.nn.Module,
|
144 |
+
base_lr: Optional[float] = None,
|
145 |
+
weight_decay: Optional[float] = None,
|
146 |
+
weight_decay_norm: Optional[float] = None,
|
147 |
+
bias_lr_factor: Optional[float] = 1.0,
|
148 |
+
weight_decay_bias: Optional[float] = None,
|
149 |
+
lr_factor_func: Optional[Callable] = None,
|
150 |
+
overrides: Optional[Dict[str, Dict[str, float]]] = None,
|
151 |
+
) -> List[Dict[str, Any]]:
|
152 |
+
"""
|
153 |
+
Get default param list for optimizer, with support for a few types of
|
154 |
+
overrides. If no overrides needed, this is equivalent to `model.parameters()`.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
|
158 |
+
weight_decay: weight decay for every group by default. Can be omitted to use the one
|
159 |
+
in optimizer.
|
160 |
+
weight_decay_norm: override weight decay for params in normalization layers
|
161 |
+
bias_lr_factor: multiplier of lr for bias parameters.
|
162 |
+
weight_decay_bias: override weight decay for bias parameters.
|
163 |
+
lr_factor_func: function to calculate lr decay rate by mapping the parameter names to
|
164 |
+
corresponding lr decay rate. Note that setting this option requires
|
165 |
+
also setting ``base_lr``.
|
166 |
+
overrides: if not `None`, provides values for optimizer hyperparameters
|
167 |
+
(LR, weight decay) for module parameters with a given name; e.g.
|
168 |
+
``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
|
169 |
+
weight decay values for all module parameters named `embedding`.
|
170 |
+
|
171 |
+
For common detection models, ``weight_decay_norm`` is the only option
|
172 |
+
needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
|
173 |
+
from Detectron1 that are not found useful.
|
174 |
+
|
175 |
+
Example:
|
176 |
+
::
|
177 |
+
torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
|
178 |
+
lr=0.01, weight_decay=1e-4, momentum=0.9)
|
179 |
+
"""
|
180 |
+
if overrides is None:
|
181 |
+
overrides = {}
|
182 |
+
defaults = {}
|
183 |
+
if base_lr is not None:
|
184 |
+
defaults["lr"] = base_lr
|
185 |
+
if weight_decay is not None:
|
186 |
+
defaults["weight_decay"] = weight_decay
|
187 |
+
bias_overrides = {}
|
188 |
+
if bias_lr_factor is not None and bias_lr_factor != 1.0:
|
189 |
+
# NOTE: unlike Detectron v1, we now by default make bias hyperparameters
|
190 |
+
# exactly the same as regular weights.
|
191 |
+
if base_lr is None:
|
192 |
+
raise ValueError("bias_lr_factor requires base_lr")
|
193 |
+
bias_overrides["lr"] = base_lr * bias_lr_factor
|
194 |
+
if weight_decay_bias is not None:
|
195 |
+
bias_overrides["weight_decay"] = weight_decay_bias
|
196 |
+
if len(bias_overrides):
|
197 |
+
if "bias" in overrides:
|
198 |
+
raise ValueError("Conflicting overrides for 'bias'")
|
199 |
+
overrides["bias"] = bias_overrides
|
200 |
+
if lr_factor_func is not None:
|
201 |
+
if base_lr is None:
|
202 |
+
raise ValueError("lr_factor_func requires base_lr")
|
203 |
+
norm_module_types = (
|
204 |
+
torch.nn.BatchNorm1d,
|
205 |
+
torch.nn.BatchNorm2d,
|
206 |
+
torch.nn.BatchNorm3d,
|
207 |
+
torch.nn.SyncBatchNorm,
|
208 |
+
# NaiveSyncBatchNorm inherits from BatchNorm2d
|
209 |
+
torch.nn.GroupNorm,
|
210 |
+
torch.nn.InstanceNorm1d,
|
211 |
+
torch.nn.InstanceNorm2d,
|
212 |
+
torch.nn.InstanceNorm3d,
|
213 |
+
torch.nn.LayerNorm,
|
214 |
+
torch.nn.LocalResponseNorm,
|
215 |
+
)
|
216 |
+
params: List[Dict[str, Any]] = []
|
217 |
+
memo: Set[torch.nn.parameter.Parameter] = set()
|
218 |
+
for module_name, module in model.named_modules():
|
219 |
+
for module_param_name, value in module.named_parameters(recurse=False):
|
220 |
+
if not value.requires_grad:
|
221 |
+
continue
|
222 |
+
# Avoid duplicating parameters
|
223 |
+
if value in memo:
|
224 |
+
continue
|
225 |
+
memo.add(value)
|
226 |
+
|
227 |
+
hyperparams = copy.copy(defaults)
|
228 |
+
if isinstance(module, norm_module_types) and weight_decay_norm is not None:
|
229 |
+
hyperparams["weight_decay"] = weight_decay_norm
|
230 |
+
if lr_factor_func is not None:
|
231 |
+
hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")
|
232 |
+
|
233 |
+
hyperparams.update(overrides.get(module_param_name, {}))
|
234 |
+
params.append({"params": [value], **hyperparams})
|
235 |
+
return reduce_param_groups(params)
|
236 |
+
|
237 |
+
|
238 |
+
def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
239 |
+
# Transform parameter groups into per-parameter structure.
|
240 |
+
# Later items in `params` can overwrite parameters set in previous items.
|
241 |
+
ret = defaultdict(dict)
|
242 |
+
for item in params:
|
243 |
+
assert "params" in item
|
244 |
+
cur_params = {x: y for x, y in item.items() if x != "params"}
|
245 |
+
for param in item["params"]:
|
246 |
+
ret[param].update({"params": [param], **cur_params})
|
247 |
+
return list(ret.values())
|
248 |
+
|
249 |
+
|
250 |
+
def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
251 |
+
# Reorganize the parameter groups and merge duplicated groups.
|
252 |
+
# The number of parameter groups needs to be as small as possible in order
|
253 |
+
# to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
|
254 |
+
# of using a parameter_group per single parameter, we reorganize the
|
255 |
+
# parameter groups and merge duplicated groups. This approach speeds
|
256 |
+
# up multi-tensor optimizer significantly.
|
257 |
+
params = _expand_param_groups(params)
|
258 |
+
groups = defaultdict(list) # re-group all parameter groups by their hyperparams
|
259 |
+
for item in params:
|
260 |
+
cur_params = tuple((x, y) for x, y in item.items() if x != "params")
|
261 |
+
groups[cur_params].extend(item["params"])
|
262 |
+
ret = []
|
263 |
+
for param_keys, param_values in groups.items():
|
264 |
+
cur = {kv[0]: kv[1] for kv in param_keys}
|
265 |
+
cur["params"] = param_values
|
266 |
+
ret.append(cur)
|
267 |
+
return ret
|
268 |
+
|
269 |
+
|
270 |
+
def build_lr_scheduler(cfg: CfgNode, optimizer: torch.optim.Optimizer) -> LRScheduler:
|
271 |
+
"""
|
272 |
+
Build a LR scheduler from config.
|
273 |
+
"""
|
274 |
+
name = cfg.SOLVER.LR_SCHEDULER_NAME
|
275 |
+
|
276 |
+
if name == "WarmupMultiStepLR":
|
277 |
+
steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER]
|
278 |
+
if len(steps) != len(cfg.SOLVER.STEPS):
|
279 |
+
logger = logging.getLogger(__name__)
|
280 |
+
logger.warning(
|
281 |
+
"SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. "
|
282 |
+
"These values will be ignored."
|
283 |
+
)
|
284 |
+
sched = MultiStepParamScheduler(
|
285 |
+
values=[cfg.SOLVER.GAMMA**k for k in range(len(steps) + 1)],
|
286 |
+
milestones=steps,
|
287 |
+
num_updates=cfg.SOLVER.MAX_ITER,
|
288 |
+
)
|
289 |
+
elif name == "WarmupCosineLR":
|
290 |
+
end_value = cfg.SOLVER.BASE_LR_END / cfg.SOLVER.BASE_LR
|
291 |
+
assert end_value >= 0.0 and end_value <= 1.0, end_value
|
292 |
+
sched = CosineParamScheduler(1, end_value)
|
293 |
+
elif name == "WarmupStepWithFixedGammaLR":
|
294 |
+
sched = StepWithFixedGammaParamScheduler(
|
295 |
+
base_value=1.0,
|
296 |
+
gamma=cfg.SOLVER.GAMMA,
|
297 |
+
num_decays=cfg.SOLVER.NUM_DECAYS,
|
298 |
+
num_updates=cfg.SOLVER.MAX_ITER,
|
299 |
+
)
|
300 |
+
else:
|
301 |
+
raise ValueError("Unknown LR scheduler: {}".format(name))
|
302 |
+
|
303 |
+
sched = WarmupParamScheduler(
|
304 |
+
sched,
|
305 |
+
cfg.SOLVER.WARMUP_FACTOR,
|
306 |
+
min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0),
|
307 |
+
cfg.SOLVER.WARMUP_METHOD,
|
308 |
+
cfg.SOLVER.RESCALE_INTERVAL,
|
309 |
+
)
|
310 |
+
return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)
|
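# Illustrative sketch, not part of the diff above: using get_default_optimizer_params
# directly, following the example in its docstring. The tiny model here is a hypothetical
# stand-in; only the call signatures of the functions defined in this file are assumed.
import torch

_toy_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
_params = get_default_optimizer_params(_toy_model, weight_decay_norm=0.0)  # no decay on norm layers
_opt = torch.optim.SGD(_params, lr=0.01, weight_decay=1e-4, momentum=0.9)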
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/solver/lr_scheduler.py
ADDED
@@ -0,0 +1,246 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
from bisect import bisect_right
|
5 |
+
from typing import List
|
6 |
+
import torch
|
7 |
+
from fvcore.common.param_scheduler import (
|
8 |
+
CompositeParamScheduler,
|
9 |
+
ConstantParamScheduler,
|
10 |
+
LinearParamScheduler,
|
11 |
+
ParamScheduler,
|
12 |
+
)
|
13 |
+
|
14 |
+
try:
|
15 |
+
from torch.optim.lr_scheduler import LRScheduler
|
16 |
+
except ImportError:
|
17 |
+
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class WarmupParamScheduler(CompositeParamScheduler):
|
23 |
+
"""
|
24 |
+
Add an initial warmup stage to another scheduler.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
scheduler: ParamScheduler,
|
30 |
+
warmup_factor: float,
|
31 |
+
warmup_length: float,
|
32 |
+
warmup_method: str = "linear",
|
33 |
+
rescale_interval: bool = False,
|
34 |
+
):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
scheduler: warmup will be added at the beginning of this scheduler
|
38 |
+
warmup_factor: the factor w.r.t the initial value of ``scheduler``, e.g. 0.001
|
39 |
+
warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire
|
40 |
+
training, e.g. 0.01
|
41 |
+
warmup_method: one of "linear" or "constant"
|
42 |
+
rescale_interval: whether we will rescale the interval of the scheduler after
|
43 |
+
warmup
|
44 |
+
"""
|
45 |
+
end_value = scheduler(warmup_length) # the value to reach when warmup ends
|
46 |
+
start_value = warmup_factor * scheduler(0.0)
|
47 |
+
if warmup_method == "constant":
|
48 |
+
warmup = ConstantParamScheduler(start_value)
|
49 |
+
elif warmup_method == "linear":
|
50 |
+
warmup = LinearParamScheduler(start_value, end_value)
|
51 |
+
else:
|
52 |
+
raise ValueError("Unknown warmup method: {}".format(warmup_method))
|
53 |
+
super().__init__(
|
54 |
+
[warmup, scheduler],
|
55 |
+
interval_scaling=["rescaled", "rescaled" if rescale_interval else "fixed"],
|
56 |
+
lengths=[warmup_length, 1 - warmup_length],
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
class LRMultiplier(LRScheduler):
|
61 |
+
"""
|
62 |
+
A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the
|
63 |
+
learning rate of each param in the optimizer.
|
64 |
+
Every step, the learning rate of each parameter becomes its initial value
|
65 |
+
multiplied by the output of the given :class:`ParamScheduler`.
|
66 |
+
|
67 |
+
The absolute learning rate value of each parameter can be different.
|
68 |
+
This scheduler can be used as long as the relative scale among them does
|
69 |
+
not change during training.
|
70 |
+
|
71 |
+
Examples:
|
72 |
+
::
|
73 |
+
LRMultiplier(
|
74 |
+
opt,
|
75 |
+
WarmupParamScheduler(
|
76 |
+
MultiStepParamScheduler(
|
77 |
+
[1, 0.1, 0.01],
|
78 |
+
milestones=[60000, 80000],
|
79 |
+
num_updates=90000,
|
80 |
+
), 0.001, 100 / 90000
|
81 |
+
),
|
82 |
+
max_iter=90000
|
83 |
+
)
|
84 |
+
"""
|
85 |
+
|
86 |
+
# NOTES: in the most general case, every LR can use its own scheduler.
|
87 |
+
# Supporting this requires interaction with the optimizer when its parameter
|
88 |
+
# group is initialized. For example, classyvision implements its own optimizer
|
89 |
+
# that allows different schedulers for every parameter group.
|
90 |
+
# To avoid this complexity, we use this class to support the most common cases
|
91 |
+
# where the relative scale among all LRs stay unchanged during training. In this
|
92 |
+
# case we only need a total of one scheduler that defines the relative LR multiplier.
|
93 |
+
|
94 |
+
def __init__(
|
95 |
+
self,
|
96 |
+
optimizer: torch.optim.Optimizer,
|
97 |
+
multiplier: ParamScheduler,
|
98 |
+
max_iter: int,
|
99 |
+
last_iter: int = -1,
|
100 |
+
):
|
101 |
+
"""
|
102 |
+
Args:
|
103 |
+
optimizer, last_iter: See ``torch.optim.lr_scheduler.LRScheduler``.
|
104 |
+
``last_iter`` is the same as ``last_epoch``.
|
105 |
+
multiplier: a fvcore ParamScheduler that defines the multiplier on
|
106 |
+
every LR of the optimizer
|
107 |
+
max_iter: the total number of training iterations
|
108 |
+
"""
|
109 |
+
if not isinstance(multiplier, ParamScheduler):
|
110 |
+
raise ValueError(
|
111 |
+
"_LRMultiplier(multiplier=) must be an instance of fvcore "
|
112 |
+
f"ParamScheduler. Got {multiplier} instead."
|
113 |
+
)
|
114 |
+
self._multiplier = multiplier
|
115 |
+
self._max_iter = max_iter
|
116 |
+
super().__init__(optimizer, last_epoch=last_iter)
|
117 |
+
|
118 |
+
def state_dict(self):
|
119 |
+
# fvcore schedulers are stateless. Only keep pytorch scheduler states
|
120 |
+
return {"base_lrs": self.base_lrs, "last_epoch": self.last_epoch}
|
121 |
+
|
122 |
+
def get_lr(self) -> List[float]:
|
123 |
+
multiplier = self._multiplier(self.last_epoch / self._max_iter)
|
124 |
+
return [base_lr * multiplier for base_lr in self.base_lrs]
|
125 |
+
|
126 |
+
|
127 |
+
"""
|
128 |
+
Content below is no longer needed!
|
129 |
+
"""
|
130 |
+
|
131 |
+
# NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes
|
132 |
+
# only on epoch boundaries. We typically use iteration based schedules instead.
|
133 |
+
# As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean
|
134 |
+
# "iteration" instead.
|
135 |
+
|
136 |
+
# FIXME: ideally this would be achieved with a CombinedLRScheduler, separating
|
137 |
+
# MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it.
|
138 |
+
|
139 |
+
|
140 |
+
class WarmupMultiStepLR(LRScheduler):
|
141 |
+
def __init__(
|
142 |
+
self,
|
143 |
+
optimizer: torch.optim.Optimizer,
|
144 |
+
milestones: List[int],
|
145 |
+
gamma: float = 0.1,
|
146 |
+
warmup_factor: float = 0.001,
|
147 |
+
warmup_iters: int = 1000,
|
148 |
+
warmup_method: str = "linear",
|
149 |
+
last_epoch: int = -1,
|
150 |
+
):
|
151 |
+
logger.warning(
|
152 |
+
"WarmupMultiStepLR is deprecated! Use LRMultipilier with fvcore ParamScheduler instead!"
|
153 |
+
)
|
154 |
+
if not list(milestones) == sorted(milestones):
|
155 |
+
raise ValueError(
|
156 |
+
"Milestones should be a list of" " increasing integers. Got {}", milestones
|
157 |
+
)
|
158 |
+
self.milestones = milestones
|
159 |
+
self.gamma = gamma
|
160 |
+
self.warmup_factor = warmup_factor
|
161 |
+
self.warmup_iters = warmup_iters
|
162 |
+
self.warmup_method = warmup_method
|
163 |
+
super().__init__(optimizer, last_epoch)
|
164 |
+
|
165 |
+
def get_lr(self) -> List[float]:
|
166 |
+
warmup_factor = _get_warmup_factor_at_iter(
|
167 |
+
self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
|
168 |
+
)
|
169 |
+
return [
|
170 |
+
base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
|
171 |
+
for base_lr in self.base_lrs
|
172 |
+
]
|
173 |
+
|
174 |
+
def _compute_values(self) -> List[float]:
|
175 |
+
# The new interface
|
176 |
+
return self.get_lr()
|
177 |
+
|
178 |
+
|
179 |
+
class WarmupCosineLR(LRScheduler):
|
180 |
+
def __init__(
|
181 |
+
self,
|
182 |
+
optimizer: torch.optim.Optimizer,
|
183 |
+
max_iters: int,
|
184 |
+
warmup_factor: float = 0.001,
|
185 |
+
warmup_iters: int = 1000,
|
186 |
+
warmup_method: str = "linear",
|
187 |
+
last_epoch: int = -1,
|
188 |
+
):
|
189 |
+
logger.warning(
|
190 |
+
"WarmupCosineLR is deprecated! Use LRMultipilier with fvcore ParamScheduler instead!"
|
191 |
+
)
|
192 |
+
self.max_iters = max_iters
|
193 |
+
self.warmup_factor = warmup_factor
|
194 |
+
self.warmup_iters = warmup_iters
|
195 |
+
self.warmup_method = warmup_method
|
196 |
+
super().__init__(optimizer, last_epoch)
|
197 |
+
|
198 |
+
def get_lr(self) -> List[float]:
|
199 |
+
warmup_factor = _get_warmup_factor_at_iter(
|
200 |
+
self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
|
201 |
+
)
|
202 |
+
# Different definitions of half-cosine with warmup are possible. For
|
203 |
+
# simplicity we multiply the standard half-cosine schedule by the warmup
|
204 |
+
# factor. An alternative is to start the period of the cosine at warmup_iters
|
205 |
+
# instead of at 0. In the case that warmup_iters << max_iters the two are
|
206 |
+
# very close to each other.
|
207 |
+
return [
|
208 |
+
base_lr
|
209 |
+
* warmup_factor
|
210 |
+
* 0.5
|
211 |
+
* (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters))
|
212 |
+
for base_lr in self.base_lrs
|
213 |
+
]
|
214 |
+
|
215 |
+
def _compute_values(self) -> List[float]:
|
216 |
+
# The new interface
|
217 |
+
return self.get_lr()
|
218 |
+
|
219 |
+
|
220 |
+
def _get_warmup_factor_at_iter(
|
221 |
+
method: str, iter: int, warmup_iters: int, warmup_factor: float
|
222 |
+
) -> float:
|
223 |
+
"""
|
224 |
+
Return the learning rate warmup factor at a specific iteration.
|
225 |
+
See :paper:`ImageNet in 1h` for more details.
|
226 |
+
|
227 |
+
Args:
|
228 |
+
method (str): warmup method; either "constant" or "linear".
|
229 |
+
iter (int): iteration at which to calculate the warmup factor.
|
230 |
+
warmup_iters (int): the number of warmup iterations.
|
231 |
+
warmup_factor (float): the base warmup factor (the meaning changes according
|
232 |
+
to the method used).
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
float: the effective warmup factor at the given iteration.
|
236 |
+
"""
|
237 |
+
if iter >= warmup_iters:
|
238 |
+
return 1.0
|
239 |
+
|
240 |
+
if method == "constant":
|
241 |
+
return warmup_factor
|
242 |
+
elif method == "linear":
|
243 |
+
alpha = iter / warmup_iters
|
244 |
+
return warmup_factor * (1 - alpha) + alpha
|
245 |
+
else:
|
246 |
+
raise ValueError("Unknown warmup method: {}".format(method))
|
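# Illustrative sketch, not part of the diff above: combining LRMultiplier with a warmed-up
# multi-step schedule, mirroring the example in the LRMultiplier docstring. The optimizer
# and iteration counts are made up.
from fvcore.common.param_scheduler import MultiStepParamScheduler

_opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
_sched = LRMultiplier(
    _opt,
    WarmupParamScheduler(
        MultiStepParamScheduler([1, 0.1, 0.01], milestones=[60000, 80000], num_updates=90000),
        0.001,        # warmup_factor
        100 / 90000,  # warmup_length, as a fraction of training
    ),
    max_iter=90000,
)
_opt.step()
_sched.step()  # one scheduler step per training iteration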
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/__init__.py
ADDED
@@ -0,0 +1,17 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .boxes import Boxes, BoxMode, pairwise_iou, pairwise_ioa, pairwise_point_box_distance
from .image_list import ImageList

from .instances import Instances
from .keypoints import Keypoints, heatmaps_to_keypoints
from .masks import BitMasks, PolygonMasks, polygons_to_bitmask, ROIMasks
from .rotated_boxes import RotatedBoxes
from .rotated_boxes import pairwise_iou as pairwise_iou_rotated

__all__ = [k for k in globals().keys() if not k.startswith("_")]


from annotator.oneformer.detectron2.utils.env import fixup_module_metadata

fixup_module_metadata(__name__, globals(), __all__)
del fixup_module_metadata
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/boxes.py
ADDED
@@ -0,0 +1,425 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
from enum import IntEnum, unique
|
5 |
+
from typing import List, Tuple, Union
|
6 |
+
import torch
|
7 |
+
from torch import device
|
8 |
+
|
9 |
+
_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
|
10 |
+
|
11 |
+
|
12 |
+
@unique
|
13 |
+
class BoxMode(IntEnum):
|
14 |
+
"""
|
15 |
+
Enum of different ways to represent a box.
|
16 |
+
"""
|
17 |
+
|
18 |
+
XYXY_ABS = 0
|
19 |
+
"""
|
20 |
+
(x0, y0, x1, y1) in absolute floating points coordinates.
|
21 |
+
The coordinates in range [0, width or height].
|
22 |
+
"""
|
23 |
+
XYWH_ABS = 1
|
24 |
+
"""
|
25 |
+
(x0, y0, w, h) in absolute floating points coordinates.
|
26 |
+
"""
|
27 |
+
XYXY_REL = 2
|
28 |
+
"""
|
29 |
+
Not yet supported!
|
30 |
+
(x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
|
31 |
+
"""
|
32 |
+
XYWH_REL = 3
|
33 |
+
"""
|
34 |
+
Not yet supported!
|
35 |
+
(x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
|
36 |
+
"""
|
37 |
+
XYWHA_ABS = 4
|
38 |
+
"""
|
39 |
+
(xc, yc, w, h, a) in absolute floating points coordinates.
|
40 |
+
(xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
|
41 |
+
"""
|
42 |
+
|
43 |
+
@staticmethod
|
44 |
+
def convert(box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode") -> _RawBoxType:
|
45 |
+
"""
|
46 |
+
Args:
|
47 |
+
box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
|
48 |
+
from_mode, to_mode (BoxMode)
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
The converted box of the same type.
|
52 |
+
"""
|
53 |
+
if from_mode == to_mode:
|
54 |
+
return box
|
55 |
+
|
56 |
+
original_type = type(box)
|
57 |
+
is_numpy = isinstance(box, np.ndarray)
|
58 |
+
single_box = isinstance(box, (list, tuple))
|
59 |
+
if single_box:
|
60 |
+
assert len(box) == 4 or len(box) == 5, (
|
61 |
+
"BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
|
62 |
+
" where k == 4 or 5"
|
63 |
+
)
|
64 |
+
arr = torch.tensor(box)[None, :]
|
65 |
+
else:
|
66 |
+
# avoid modifying the input box
|
67 |
+
if is_numpy:
|
68 |
+
arr = torch.from_numpy(np.asarray(box)).clone()
|
69 |
+
else:
|
70 |
+
arr = box.clone()
|
71 |
+
|
72 |
+
assert to_mode not in [BoxMode.XYXY_REL, BoxMode.XYWH_REL] and from_mode not in [
|
73 |
+
BoxMode.XYXY_REL,
|
74 |
+
BoxMode.XYWH_REL,
|
75 |
+
], "Relative mode not yet supported!"
|
76 |
+
|
77 |
+
if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
|
78 |
+
assert (
|
79 |
+
arr.shape[-1] == 5
|
80 |
+
), "The last dimension of input shape must be 5 for XYWHA format"
|
81 |
+
original_dtype = arr.dtype
|
82 |
+
arr = arr.double()
|
83 |
+
|
84 |
+
w = arr[:, 2]
|
85 |
+
h = arr[:, 3]
|
86 |
+
a = arr[:, 4]
|
87 |
+
c = torch.abs(torch.cos(a * math.pi / 180.0))
|
88 |
+
s = torch.abs(torch.sin(a * math.pi / 180.0))
|
89 |
+
# This basically computes the horizontal bounding rectangle of the rotated box
|
90 |
+
new_w = c * w + s * h
|
91 |
+
new_h = c * h + s * w
|
92 |
+
|
93 |
+
# convert center to top-left corner
|
94 |
+
arr[:, 0] -= new_w / 2.0
|
95 |
+
arr[:, 1] -= new_h / 2.0
|
96 |
+
# bottom-right corner
|
97 |
+
arr[:, 2] = arr[:, 0] + new_w
|
98 |
+
arr[:, 3] = arr[:, 1] + new_h
|
99 |
+
|
100 |
+
arr = arr[:, :4].to(dtype=original_dtype)
|
101 |
+
elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
|
102 |
+
original_dtype = arr.dtype
|
103 |
+
arr = arr.double()
|
104 |
+
arr[:, 0] += arr[:, 2] / 2.0
|
105 |
+
arr[:, 1] += arr[:, 3] / 2.0
|
106 |
+
angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
|
107 |
+
arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)
|
108 |
+
else:
|
109 |
+
if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
|
110 |
+
arr[:, 2] += arr[:, 0]
|
111 |
+
arr[:, 3] += arr[:, 1]
|
112 |
+
elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
|
113 |
+
arr[:, 2] -= arr[:, 0]
|
114 |
+
arr[:, 3] -= arr[:, 1]
|
115 |
+
else:
|
116 |
+
raise NotImplementedError(
|
117 |
+
"Conversion from BoxMode {} to {} is not supported yet".format(
|
118 |
+
from_mode, to_mode
|
119 |
+
)
|
120 |
+
)
|
121 |
+
|
122 |
+
if single_box:
|
123 |
+
return original_type(arr.flatten().tolist())
|
124 |
+
if is_numpy:
|
125 |
+
return arr.numpy()
|
126 |
+
else:
|
127 |
+
return arr
|
128 |
+
|
129 |
+
|
130 |
+
class Boxes:
|
131 |
+
"""
|
132 |
+
This structure stores a list of boxes as a Nx4 torch.Tensor.
|
133 |
+
It supports some common methods about boxes
|
134 |
+
(`area`, `clip`, `nonempty`, etc),
|
135 |
+
and also behaves like a Tensor
|
136 |
+
(support indexing, `to(device)`, `.device`, and iteration over all boxes)
|
137 |
+
|
138 |
+
Attributes:
|
139 |
+
tensor (torch.Tensor): float matrix of Nx4. Each row is (x1, y1, x2, y2).
|
140 |
+
"""
|
141 |
+
|
142 |
+
def __init__(self, tensor: torch.Tensor):
|
143 |
+
"""
|
144 |
+
Args:
|
145 |
+
tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
|
146 |
+
"""
|
147 |
+
if not isinstance(tensor, torch.Tensor):
|
148 |
+
tensor = torch.as_tensor(tensor, dtype=torch.float32, device=torch.device("cpu"))
|
149 |
+
else:
|
150 |
+
tensor = tensor.to(torch.float32)
|
151 |
+
if tensor.numel() == 0:
|
152 |
+
# Use reshape, so we don't end up creating a new tensor that does not depend on
|
153 |
+
# the inputs (and consequently confuses jit)
|
154 |
+
tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32)
|
155 |
+
assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
|
156 |
+
|
157 |
+
self.tensor = tensor
|
158 |
+
|
159 |
+
def clone(self) -> "Boxes":
|
160 |
+
"""
|
161 |
+
Clone the Boxes.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
Boxes
|
165 |
+
"""
|
166 |
+
return Boxes(self.tensor.clone())
|
167 |
+
|
168 |
+
def to(self, device: torch.device):
|
169 |
+
# Boxes are assumed float32 and does not support to(dtype)
|
170 |
+
return Boxes(self.tensor.to(device=device))
|
171 |
+
|
172 |
+
def area(self) -> torch.Tensor:
|
173 |
+
"""
|
174 |
+
Computes the area of all the boxes.
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
torch.Tensor: a vector with areas of each box.
|
178 |
+
"""
|
179 |
+
box = self.tensor
|
180 |
+
area = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
|
181 |
+
return area
|
182 |
+
|
183 |
+
def clip(self, box_size: Tuple[int, int]) -> None:
|
184 |
+
"""
|
185 |
+
Clip (in place) the boxes by limiting x coordinates to the range [0, width]
|
186 |
+
and y coordinates to the range [0, height].
|
187 |
+
|
188 |
+
Args:
|
189 |
+
box_size (height, width): The clipping box's size.
|
190 |
+
"""
|
191 |
+
assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
|
192 |
+
h, w = box_size
|
193 |
+
x1 = self.tensor[:, 0].clamp(min=0, max=w)
|
194 |
+
y1 = self.tensor[:, 1].clamp(min=0, max=h)
|
195 |
+
x2 = self.tensor[:, 2].clamp(min=0, max=w)
|
196 |
+
y2 = self.tensor[:, 3].clamp(min=0, max=h)
|
197 |
+
self.tensor = torch.stack((x1, y1, x2, y2), dim=-1)
|
198 |
+
|
199 |
+
def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
|
200 |
+
"""
|
201 |
+
Find boxes that are non-empty.
|
202 |
+
A box is considered empty if either of its sides is no larger than the threshold.
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
Tensor:
|
206 |
+
a binary vector which represents whether each box is empty
|
207 |
+
(False) or non-empty (True).
|
208 |
+
"""
|
209 |
+
box = self.tensor
|
210 |
+
widths = box[:, 2] - box[:, 0]
|
211 |
+
heights = box[:, 3] - box[:, 1]
|
212 |
+
keep = (widths > threshold) & (heights > threshold)
|
213 |
+
return keep
|
214 |
+
|
215 |
+
def __getitem__(self, item) -> "Boxes":
|
216 |
+
"""
|
217 |
+
Args:
|
218 |
+
item: int, slice, or a BoolTensor
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
Boxes: Create a new :class:`Boxes` by indexing.
|
222 |
+
|
223 |
+
The following usage are allowed:
|
224 |
+
|
225 |
+
1. `new_boxes = boxes[3]`: return a `Boxes` which contains only one box.
|
226 |
+
2. `new_boxes = boxes[2:10]`: return a slice of boxes.
|
227 |
+
3. `new_boxes = boxes[vector]`, where vector is a torch.BoolTensor
|
228 |
+
with `length = len(boxes)`. Nonzero elements in the vector will be selected.
|
229 |
+
|
230 |
+
Note that the returned Boxes might share storage with this Boxes,
|
231 |
+
subject to Pytorch's indexing semantics.
|
232 |
+
"""
|
233 |
+
if isinstance(item, int):
|
234 |
+
return Boxes(self.tensor[item].view(1, -1))
|
235 |
+
b = self.tensor[item]
|
236 |
+
assert b.dim() == 2, "Indexing on Boxes with {} failed to return a matrix!".format(item)
|
237 |
+
return Boxes(b)
|
238 |
+
|
239 |
+
def __len__(self) -> int:
|
240 |
+
return self.tensor.shape[0]
|
241 |
+
|
242 |
+
def __repr__(self) -> str:
|
243 |
+
return "Boxes(" + str(self.tensor) + ")"
|
244 |
+
|
245 |
+
def inside_box(self, box_size: Tuple[int, int], boundary_threshold: int = 0) -> torch.Tensor:
|
246 |
+
"""
|
247 |
+
Args:
|
248 |
+
box_size (height, width): Size of the reference box.
|
249 |
+
boundary_threshold (int): Boxes that extend beyond the reference box
|
250 |
+
boundary by more than boundary_threshold are considered "outside".
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
a binary vector, indicating whether each box is inside the reference box.
|
254 |
+
"""
|
255 |
+
height, width = box_size
|
256 |
+
inds_inside = (
|
257 |
+
(self.tensor[..., 0] >= -boundary_threshold)
|
258 |
+
& (self.tensor[..., 1] >= -boundary_threshold)
|
259 |
+
& (self.tensor[..., 2] < width + boundary_threshold)
|
260 |
+
& (self.tensor[..., 3] < height + boundary_threshold)
|
261 |
+
)
|
262 |
+
return inds_inside
|
263 |
+
|
264 |
+
def get_centers(self) -> torch.Tensor:
|
265 |
+
"""
|
266 |
+
Returns:
|
267 |
+
The box centers in a Nx2 array of (x, y).
|
268 |
+
"""
|
269 |
+
return (self.tensor[:, :2] + self.tensor[:, 2:]) / 2
|
270 |
+
|
271 |
+
def scale(self, scale_x: float, scale_y: float) -> None:
|
272 |
+
"""
|
273 |
+
Scale the box with horizontal and vertical scaling factors
|
274 |
+
"""
|
275 |
+
self.tensor[:, 0::2] *= scale_x
|
276 |
+
self.tensor[:, 1::2] *= scale_y
|
277 |
+
|
278 |
+
@classmethod
|
279 |
+
def cat(cls, boxes_list: List["Boxes"]) -> "Boxes":
|
280 |
+
"""
|
281 |
+
Concatenates a list of Boxes into a single Boxes
|
282 |
+
|
283 |
+
Arguments:
|
284 |
+
boxes_list (list[Boxes])
|
285 |
+
|
286 |
+
Returns:
|
287 |
+
Boxes: the concatenated Boxes
|
288 |
+
"""
|
289 |
+
assert isinstance(boxes_list, (list, tuple))
|
290 |
+
if len(boxes_list) == 0:
|
291 |
+
return cls(torch.empty(0))
|
292 |
+
assert all([isinstance(box, Boxes) for box in boxes_list])
|
293 |
+
|
294 |
+
# use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input
|
295 |
+
cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
|
296 |
+
return cat_boxes
|
297 |
+
|
298 |
+
@property
|
299 |
+
def device(self) -> device:
|
300 |
+
return self.tensor.device
|
301 |
+
|
302 |
+
# type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
|
303 |
+
# https://github.com/pytorch/pytorch/issues/18627
|
304 |
+
@torch.jit.unused
|
305 |
+
def __iter__(self):
|
306 |
+
"""
|
307 |
+
Yield a box as a Tensor of shape (4,) at a time.
|
308 |
+
"""
|
309 |
+
yield from self.tensor
|
310 |
+
|
311 |
+
|
312 |
+
def pairwise_intersection(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
|
313 |
+
"""
|
314 |
+
Given two lists of boxes of size N and M,
|
315 |
+
compute the intersection area between __all__ N x M pairs of boxes.
|
316 |
+
The box order must be (xmin, ymin, xmax, ymax)
|
317 |
+
|
318 |
+
Args:
|
319 |
+
boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
|
320 |
+
|
321 |
+
Returns:
|
322 |
+
Tensor: intersection, sized [N,M].
|
323 |
+
"""
|
324 |
+
boxes1, boxes2 = boxes1.tensor, boxes2.tensor
|
325 |
+
width_height = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) - torch.max(
|
326 |
+
boxes1[:, None, :2], boxes2[:, :2]
|
327 |
+
) # [N,M,2]
|
328 |
+
|
329 |
+
width_height.clamp_(min=0) # [N,M,2]
|
330 |
+
intersection = width_height.prod(dim=2) # [N,M]
|
331 |
+
return intersection
|
332 |
+
|
333 |
+
|
334 |
+
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
|
335 |
+
# with slight modifications
|
336 |
+
def pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
|
337 |
+
"""
|
338 |
+
Given two lists of boxes of size N and M, compute the IoU
|
339 |
+
(intersection over union) between **all** N x M pairs of boxes.
|
340 |
+
The box order must be (xmin, ymin, xmax, ymax).
|
341 |
+
|
342 |
+
Args:
|
343 |
+
boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
|
344 |
+
|
345 |
+
Returns:
|
346 |
+
Tensor: IoU, sized [N,M].
|
347 |
+
"""
|
348 |
+
area1 = boxes1.area() # [N]
|
349 |
+
area2 = boxes2.area() # [M]
|
350 |
+
inter = pairwise_intersection(boxes1, boxes2)
|
351 |
+
|
352 |
+
# handle empty boxes
|
353 |
+
iou = torch.where(
|
354 |
+
inter > 0,
|
355 |
+
inter / (area1[:, None] + area2 - inter),
|
356 |
+
torch.zeros(1, dtype=inter.dtype, device=inter.device),
|
357 |
+
)
|
358 |
+
return iou
|
359 |
+
|
360 |
+
|
361 |
+
def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
|
362 |
+
"""
|
363 |
+
Similar to :func:`pairwise_iou` but computes the IoA (intersection over boxes2 area).
|
364 |
+
|
365 |
+
Args:
|
366 |
+
boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
|
367 |
+
|
368 |
+
Returns:
|
369 |
+
Tensor: IoA, sized [N,M].
|
370 |
+
"""
|
371 |
+
area2 = boxes2.area() # [M]
|
372 |
+
inter = pairwise_intersection(boxes1, boxes2)
|
373 |
+
|
374 |
+
# handle empty boxes
|
375 |
+
ioa = torch.where(
|
376 |
+
inter > 0, inter / area2, torch.zeros(1, dtype=inter.dtype, device=inter.device)
|
377 |
+
)
|
378 |
+
return ioa
|
379 |
+
|
380 |
+
|
381 |
+
def pairwise_point_box_distance(points: torch.Tensor, boxes: Boxes):
|
382 |
+
"""
|
383 |
+
Pairwise distance between N points and M boxes. The distance between a
|
384 |
+
point and a box is represented by the distance from the point to 4 edges
|
385 |
+
of the box. Distances are all positive when the point is inside the box.
|
386 |
+
|
387 |
+
Args:
|
388 |
+
points: Nx2 coordinates. Each row is (x, y)
|
389 |
+
boxes: M boxes
|
390 |
+
|
391 |
+
Returns:
|
392 |
+
Tensor: distances of size (N, M, 4). The 4 values are distances from
|
393 |
+
the point to the left, top, right, bottom of the box.
|
394 |
+
"""
|
395 |
+
x, y = points.unsqueeze(dim=2).unbind(dim=1) # (N, 1)
|
396 |
+
x0, y0, x1, y1 = boxes.tensor.unsqueeze(dim=0).unbind(dim=2) # (1, M)
|
397 |
+
return torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)
|
398 |
+
|
399 |
+
|
400 |
+
def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
|
401 |
+
"""
|
402 |
+
Compute pairwise intersection over union (IOU) of two sets of matched
|
403 |
+
boxes that have the same number of boxes.
|
404 |
+
Similar to :func:`pairwise_iou`, but computes only diagonal elements of the matrix.
|
405 |
+
|
406 |
+
Args:
|
407 |
+
boxes1 (Boxes): bounding boxes, sized [N,4].
|
408 |
+
boxes2 (Boxes): same length as boxes1
|
409 |
+
Returns:
|
410 |
+
Tensor: iou, sized [N].
|
411 |
+
"""
|
412 |
+
assert len(boxes1) == len(
|
413 |
+
boxes2
|
414 |
+
), "boxlists should have the same" "number of entries, got {}, {}".format(
|
415 |
+
len(boxes1), len(boxes2)
|
416 |
+
)
|
417 |
+
area1 = boxes1.area() # [N]
|
418 |
+
area2 = boxes2.area() # [N]
|
419 |
+
box1, box2 = boxes1.tensor, boxes2.tensor
|
420 |
+
lt = torch.max(box1[:, :2], box2[:, :2]) # [N,2]
|
421 |
+
rb = torch.min(box1[:, 2:], box2[:, 2:]) # [N,2]
|
422 |
+
wh = (rb - lt).clamp(min=0) # [N,2]
|
423 |
+
inter = wh[:, 0] * wh[:, 1] # [N]
|
424 |
+
iou = inter / (area1 + area2 - inter) # [N]
|
425 |
+
return iou
|
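# Illustrative sketch, not part of the diff above: converting a single box between modes
# and computing pairwise IoU with the classes defined in this file. The coordinates are
# made-up values.
_xyxy = BoxMode.convert([10.0, 10.0, 20.0, 30.0], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
# -> [10.0, 10.0, 30.0, 40.0]

_b1 = Boxes(torch.tensor([[0.0, 0.0, 10.0, 10.0]]))
_b2 = Boxes(torch.tensor([[5.0, 5.0, 15.0, 15.0], [20.0, 20.0, 30.0, 30.0]]))
_iou = pairwise_iou(_b1, _b2)  # shape (1, 2); the second box does not overlap, so its IoU is 0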
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/image_list.py
ADDED
@@ -0,0 +1,129 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from __future__ import division
|
3 |
+
from typing import Any, Dict, List, Optional, Tuple
|
4 |
+
import torch
|
5 |
+
from torch import device
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from annotator.oneformer.detectron2.layers.wrappers import move_device_like, shapes_to_tensor
|
9 |
+
|
10 |
+
|
11 |
+
class ImageList(object):
|
12 |
+
"""
|
13 |
+
Structure that holds a list of images (of possibly
|
14 |
+
varying sizes) as a single tensor.
|
15 |
+
This works by padding the images to the same size.
|
16 |
+
The original size of each image is stored in `image_sizes`.
|
17 |
+
|
18 |
+
Attributes:
|
19 |
+
image_sizes (list[tuple[int, int]]): each tuple is (h, w).
|
20 |
+
During tracing, it becomes list[Tensor] instead.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, tensor: torch.Tensor, image_sizes: List[Tuple[int, int]]):
|
24 |
+
"""
|
25 |
+
Arguments:
|
26 |
+
tensor (Tensor): of shape (N, H, W) or (N, C_1, ..., C_K, H, W) where K >= 1
|
27 |
+
image_sizes (list[tuple[int, int]]): Each tuple is (h, w). It can
|
28 |
+
be smaller than (H, W) due to padding.
|
29 |
+
"""
|
30 |
+
self.tensor = tensor
|
31 |
+
self.image_sizes = image_sizes
|
32 |
+
|
33 |
+
def __len__(self) -> int:
|
34 |
+
return len(self.image_sizes)
|
35 |
+
|
36 |
+
def __getitem__(self, idx) -> torch.Tensor:
|
37 |
+
"""
|
38 |
+
Access the individual image in its original size.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
idx: int or slice
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
Tensor: an image of shape (H, W) or (C_1, ..., C_K, H, W) where K >= 1
|
45 |
+
"""
|
46 |
+
size = self.image_sizes[idx]
|
47 |
+
return self.tensor[idx, ..., : size[0], : size[1]]
|
48 |
+
|
49 |
+
@torch.jit.unused
|
50 |
+
def to(self, *args: Any, **kwargs: Any) -> "ImageList":
|
51 |
+
cast_tensor = self.tensor.to(*args, **kwargs)
|
52 |
+
return ImageList(cast_tensor, self.image_sizes)
|
53 |
+
|
54 |
+
@property
|
55 |
+
def device(self) -> device:
|
56 |
+
return self.tensor.device
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def from_tensors(
|
60 |
+
tensors: List[torch.Tensor],
|
61 |
+
size_divisibility: int = 0,
|
62 |
+
pad_value: float = 0.0,
|
63 |
+
padding_constraints: Optional[Dict[str, int]] = None,
|
64 |
+
) -> "ImageList":
|
65 |
+
"""
|
66 |
+
Args:
|
67 |
+
tensors: a tuple or list of `torch.Tensor`, each of shape (Hi, Wi) or
|
68 |
+
(C_1, ..., C_K, Hi, Wi) where K >= 1. The Tensors will be padded
|
69 |
+
to the same shape with `pad_value`.
|
70 |
+
size_divisibility (int): If `size_divisibility > 0`, add padding to ensure
|
71 |
+
the common height and width is divisible by `size_divisibility`.
|
72 |
+
This depends on the model and many models need a divisibility of 32.
|
73 |
+
pad_value (float): value to pad.
|
74 |
+
padding_constraints (optional[Dict]): If given, it would follow the format as
|
75 |
+
{"size_divisibility": int, "square_size": int}, where `size_divisibility` will
|
76 |
+
overwrite the above one if presented and `square_size` indicates the
|
77 |
+
square padding size if `square_size` > 0.
|
78 |
+
Returns:
|
79 |
+
an `ImageList`.
|
80 |
+
"""
|
81 |
+
assert len(tensors) > 0
|
82 |
+
assert isinstance(tensors, (tuple, list))
|
83 |
+
for t in tensors:
|
84 |
+
assert isinstance(t, torch.Tensor), type(t)
|
85 |
+
assert t.shape[:-2] == tensors[0].shape[:-2], t.shape
|
86 |
+
|
87 |
+
image_sizes = [(im.shape[-2], im.shape[-1]) for im in tensors]
|
88 |
+
image_sizes_tensor = [shapes_to_tensor(x) for x in image_sizes]
|
89 |
+
max_size = torch.stack(image_sizes_tensor).max(0).values
|
90 |
+
|
91 |
+
if padding_constraints is not None:
|
92 |
+
square_size = padding_constraints.get("square_size", 0)
|
93 |
+
if square_size > 0:
|
94 |
+
# pad to square.
|
95 |
+
max_size[0] = max_size[1] = square_size
|
96 |
+
if "size_divisibility" in padding_constraints:
|
97 |
+
size_divisibility = padding_constraints["size_divisibility"]
|
98 |
+
if size_divisibility > 1:
|
99 |
+
stride = size_divisibility
|
100 |
+
# the last two dims are H,W, both subject to divisibility requirement
|
101 |
+
max_size = (max_size + (stride - 1)).div(stride, rounding_mode="floor") * stride
|
102 |
+
|
103 |
+
# handle weirdness of scripting and tracing ...
|
104 |
+
if torch.jit.is_scripting():
|
105 |
+
max_size: List[int] = max_size.to(dtype=torch.long).tolist()
|
106 |
+
else:
|
107 |
+
if torch.jit.is_tracing():
|
108 |
+
image_sizes = image_sizes_tensor
|
109 |
+
|
110 |
+
if len(tensors) == 1:
|
111 |
+
# This seems slightly (2%) faster.
|
112 |
+
# TODO: check whether it's faster for multiple images as well
|
113 |
+
image_size = image_sizes[0]
|
114 |
+
padding_size = [0, max_size[-1] - image_size[1], 0, max_size[-2] - image_size[0]]
|
115 |
+
batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0)
|
116 |
+
else:
|
117 |
+
# max_size can be a tensor in tracing mode, therefore convert to list
|
118 |
+
batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
|
119 |
+
device = (
|
120 |
+
None if torch.jit.is_scripting() else ("cpu" if torch.jit.is_tracing() else None)
|
121 |
+
)
|
122 |
+
batched_imgs = tensors[0].new_full(batch_shape, pad_value, device=device)
|
123 |
+
batched_imgs = move_device_like(batched_imgs, tensors[0])
|
124 |
+
for i, img in enumerate(tensors):
|
125 |
+
# Use `batched_imgs` directly instead of `img, pad_img = zip(tensors, batched_imgs)`
|
126 |
+
# Tracing mode cannot capture `copy_()` of temporary locals
|
127 |
+
batched_imgs[i, ..., : img.shape[-2], : img.shape[-1]].copy_(img)
|
128 |
+
|
129 |
+
return ImageList(batched_imgs.contiguous(), image_sizes)
|
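# Illustrative sketch, not part of the diff above: batching two differently sized images
# with ImageList.from_tensors. The divisibility of 32 is only an example value, as noted
# in the docstring.
_imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 768)]
_image_list = ImageList.from_tensors(_imgs, size_divisibility=32)
# _image_list.tensor.shape  -> torch.Size([2, 3, 512, 768])  (padded common shape)
# _image_list.image_sizes   -> [(480, 640), (512, 768)]      (original sizes kept)
# _image_list[0].shape      -> torch.Size([3, 480, 640])     (unpadded view of image 0)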
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/instances.py
ADDED
@@ -0,0 +1,194 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import itertools
|
3 |
+
import warnings
|
4 |
+
from typing import Any, Dict, List, Tuple, Union
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
class Instances:
|
9 |
+
"""
|
10 |
+
This class represents a list of instances in an image.
|
11 |
+
It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
|
12 |
+
All fields must have the same ``__len__`` which is the number of instances.
|
13 |
+
|
14 |
+
All other (non-field) attributes of this class are considered private:
|
15 |
+
they must start with '_' and are not modifiable by a user.
|
16 |
+
|
17 |
+
Some basic usage:
|
18 |
+
|
19 |
+
1. Set/get/check a field:
|
20 |
+
|
21 |
+
.. code-block:: python
|
22 |
+
|
23 |
+
instances.gt_boxes = Boxes(...)
|
24 |
+
print(instances.pred_masks) # a tensor of shape (N, H, W)
|
25 |
+
print('gt_masks' in instances)
|
26 |
+
|
27 |
+
2. ``len(instances)`` returns the number of instances
|
28 |
+
3. Indexing: ``instances[indices]`` will apply the indexing on all the fields
|
29 |
+
and returns a new :class:`Instances`.
|
30 |
+
Typically, ``indices`` is an integer vector of indices,
|
31 |
+
or a binary mask of length ``num_instances``
|
32 |
+
|
33 |
+
.. code-block:: python
|
34 |
+
|
35 |
+
category_3_detections = instances[instances.pred_classes == 3]
|
36 |
+
confident_detections = instances[instances.scores > 0.9]
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
|
40 |
+
"""
|
41 |
+
Args:
|
42 |
+
image_size (height, width): the spatial size of the image.
|
43 |
+
kwargs: fields to add to this `Instances`.
|
44 |
+
"""
|
45 |
+
self._image_size = image_size
|
46 |
+
self._fields: Dict[str, Any] = {}
|
47 |
+
for k, v in kwargs.items():
|
48 |
+
self.set(k, v)
|
49 |
+
|
50 |
+
@property
|
51 |
+
def image_size(self) -> Tuple[int, int]:
|
52 |
+
"""
|
53 |
+
Returns:
|
54 |
+
tuple: height, width
|
55 |
+
"""
|
56 |
+
return self._image_size
|
57 |
+
|
58 |
+
def __setattr__(self, name: str, val: Any) -> None:
|
59 |
+
if name.startswith("_"):
|
60 |
+
super().__setattr__(name, val)
|
61 |
+
else:
|
62 |
+
self.set(name, val)
|
63 |
+
|
64 |
+
def __getattr__(self, name: str) -> Any:
|
65 |
+
if name == "_fields" or name not in self._fields:
|
66 |
+
raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
|
67 |
+
return self._fields[name]
|
68 |
+
|
69 |
+
def set(self, name: str, value: Any) -> None:
|
70 |
+
"""
|
71 |
+
Set the field named `name` to `value`.
|
72 |
+
The length of `value` must be the number of instances,
|
73 |
+
and must agree with other existing fields in this object.
|
74 |
+
"""
|
75 |
+
with warnings.catch_warnings(record=True):
|
76 |
+
data_len = len(value)
|
77 |
+
if len(self._fields):
|
78 |
+
assert (
|
79 |
+
len(self) == data_len
|
80 |
+
), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
|
81 |
+
self._fields[name] = value
|
82 |
+
|
83 |
+
def has(self, name: str) -> bool:
|
84 |
+
"""
|
85 |
+
Returns:
|
86 |
+
bool: whether the field called `name` exists.
|
87 |
+
"""
|
88 |
+
return name in self._fields
|
89 |
+
|
90 |
+
def remove(self, name: str) -> None:
|
91 |
+
"""
|
92 |
+
Remove the field called `name`.
|
93 |
+
"""
|
94 |
+
del self._fields[name]
|
95 |
+
|
96 |
+
def get(self, name: str) -> Any:
|
97 |
+
"""
|
98 |
+
Returns the field called `name`.
|
99 |
+
"""
|
100 |
+
return self._fields[name]
|
101 |
+
|
102 |
+
def get_fields(self) -> Dict[str, Any]:
|
103 |
+
"""
|
104 |
+
Returns:
|
105 |
+
dict: a dict which maps names (str) to data of the fields
|
106 |
+
|
107 |
+
Modifying the returned dict will modify this instance.
|
108 |
+
"""
|
109 |
+
return self._fields
|
110 |
+
|
111 |
+
# Tensor-like methods
|
112 |
+
def to(self, *args: Any, **kwargs: Any) -> "Instances":
|
113 |
+
"""
|
114 |
+
Returns:
|
115 |
+
Instances: all fields are called with a `to(device)`, if the field has this method.
|
116 |
+
"""
|
117 |
+
ret = Instances(self._image_size)
|
118 |
+
for k, v in self._fields.items():
|
119 |
+
if hasattr(v, "to"):
|
120 |
+
v = v.to(*args, **kwargs)
|
121 |
+
ret.set(k, v)
|
122 |
+
return ret
|
123 |
+
|
124 |
+
def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances":
|
125 |
+
"""
|
126 |
+
Args:
|
127 |
+
item: an index-like object and will be used to index all the fields.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
If `item` is a string, return the data in the corresponding field.
|
131 |
+
Otherwise, returns an `Instances` where all fields are indexed by `item`.
|
132 |
+
"""
|
133 |
+
if type(item) == int:
|
134 |
+
if item >= len(self) or item < -len(self):
|
135 |
+
raise IndexError("Instances index out of range!")
|
136 |
+
else:
|
137 |
+
item = slice(item, None, len(self))
|
138 |
+
|
139 |
+
ret = Instances(self._image_size)
|
140 |
+
for k, v in self._fields.items():
|
141 |
+
ret.set(k, v[item])
|
142 |
+
return ret
|
143 |
+
|
144 |
+
def __len__(self) -> int:
|
145 |
+
for v in self._fields.values():
|
146 |
+
# use __len__ because len() has to be int and is not friendly to tracing
|
147 |
+
return v.__len__()
|
148 |
+
raise NotImplementedError("Empty Instances does not support __len__!")
|
149 |
+
|
150 |
+
def __iter__(self):
|
151 |
+
raise NotImplementedError("`Instances` object is not iterable!")
|
152 |
+
|
153 |
+
@staticmethod
|
154 |
+
def cat(instance_lists: List["Instances"]) -> "Instances":
|
155 |
+
"""
|
156 |
+
Args:
|
157 |
+
instance_lists (list[Instances])
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
Instances
|
161 |
+
"""
|
162 |
+
assert all(isinstance(i, Instances) for i in instance_lists)
|
163 |
+
assert len(instance_lists) > 0
|
164 |
+
if len(instance_lists) == 1:
|
165 |
+
return instance_lists[0]
|
166 |
+
|
167 |
+
image_size = instance_lists[0].image_size
|
168 |
+
if not isinstance(image_size, torch.Tensor): # could be a tensor in tracing
|
169 |
+
for i in instance_lists[1:]:
|
170 |
+
assert i.image_size == image_size
|
171 |
+
ret = Instances(image_size)
|
172 |
+
for k in instance_lists[0]._fields.keys():
|
173 |
+
values = [i.get(k) for i in instance_lists]
|
174 |
+
v0 = values[0]
|
175 |
+
if isinstance(v0, torch.Tensor):
|
176 |
+
values = torch.cat(values, dim=0)
|
177 |
+
elif isinstance(v0, list):
|
178 |
+
values = list(itertools.chain(*values))
|
179 |
+
elif hasattr(type(v0), "cat"):
|
180 |
+
values = type(v0).cat(values)
|
181 |
+
else:
|
182 |
+
raise ValueError("Unsupported type {} for concatenation".format(type(v0)))
|
183 |
+
ret.set(k, values)
|
184 |
+
return ret
|
185 |
+
|
186 |
+
def __str__(self) -> str:
|
187 |
+
s = self.__class__.__name__ + "("
|
188 |
+
s += "num_instances={}, ".format(len(self))
|
189 |
+
s += "image_height={}, ".format(self._image_size[0])
|
190 |
+
s += "image_width={}, ".format(self._image_size[1])
|
191 |
+
s += "fields=[{}])".format(", ".join((f"{k}: {v}" for k, v in self._fields.items())))
|
192 |
+
return s
|
193 |
+
|
194 |
+
__repr__ = __str__
|
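# Illustrative sketch, not part of the diff above: the field and indexing behaviour
# described in the class docstring. Field names and values are made up.
_inst = Instances((480, 640))
_inst.scores = torch.tensor([0.9, 0.3, 0.8])
_inst.pred_classes = torch.tensor([3, 1, 3])

_confident = _inst[_inst.scores > 0.5]        # keeps instances 0 and 2
_category_3 = _inst[_inst.pred_classes == 3]  # boolean-mask indexing, as documented
# len(_confident) == 2 and len(_category_3) == 2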
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/keypoints.py
ADDED
@@ -0,0 +1,235 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import Any, List, Tuple, Union
import torch
from torch.nn import functional as F


class Keypoints:
    """
    Stores keypoint **annotation** data. GT Instances have a `gt_keypoints` property
    containing the x,y location and visibility flag of each keypoint. This tensor has shape
    (N, K, 3) where N is the number of instances and K is the number of keypoints per instance.

    The visibility flag follows the COCO format and must be one of three integers:

    * v=0: not labeled (in which case x=y=0)
    * v=1: labeled but not visible
    * v=2: labeled and visible
    """

    def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
        """
        Arguments:
            keypoints: A Tensor, numpy array, or list of the x, y, and visibility of each keypoint.
                The shape should be (N, K, 3) where N is the number of
                instances, and K is the number of keypoints per instance.
        """
        device = keypoints.device if isinstance(keypoints, torch.Tensor) else torch.device("cpu")
        keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
        assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape
        self.tensor = keypoints

    def __len__(self) -> int:
        return self.tensor.size(0)

    def to(self, *args: Any, **kwargs: Any) -> "Keypoints":
        return type(self)(self.tensor.to(*args, **kwargs))

    @property
    def device(self) -> torch.device:
        return self.tensor.device

    def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor:
        """
        Convert keypoint annotations to a heatmap of one-hot labels for training,
        as described in :paper:`Mask R-CNN`.

        Arguments:
            boxes: Nx4 tensor, the boxes to draw the keypoints to

        Returns:
            heatmaps:
                A tensor of shape (N, K), each element is integer spatial label
                in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
            valid:
                A tensor of shape (N, K) containing whether each keypoint is in the roi or not.
        """
        return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size)

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints":
        """
        Create a new `Keypoints` by indexing on this `Keypoints`.

        The following usage is allowed:

        1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance.
        2. `new_kpts = kpts[2:10]`: return a slice of key points.
        3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor
           with `length = len(kpts)`. Nonzero elements in the vector will be selected.

        Note that the returned Keypoints might share storage with this Keypoints,
        subject to Pytorch's indexing semantics.
        """
        if isinstance(item, int):
            return Keypoints([self.tensor[item]])
        return Keypoints(self.tensor[item])

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.tensor))
        return s

    @staticmethod
    def cat(keypoints_list: List["Keypoints"]) -> "Keypoints":
        """
        Concatenates a list of Keypoints into a single Keypoints

        Arguments:
            keypoints_list (list[Keypoints])

        Returns:
            Keypoints: the concatenated Keypoints
        """
        assert isinstance(keypoints_list, (list, tuple))
        assert len(keypoints_list) > 0
        assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list)

        cat_kpts = type(keypoints_list[0])(
            torch.cat([kpts.tensor for kpts in keypoints_list], dim=0)
        )
        return cat_kpts


# TODO make this nicer, this is a direct translation from C2 (but removing the inner loop)
def _keypoints_to_heatmap(
    keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space.

    Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the
    closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the
    continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"):
    d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.

    Arguments:
        keypoints: tensor of keypoint locations of shape (N, K, 3).
        rois: Nx4 tensor of rois in xyxy format
        heatmap_size: integer side length of square heatmap.

    Returns:
        heatmaps: A tensor of shape (N, K) containing an integer spatial label
            in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
        valid: A tensor of shape (N, K) containing whether each keypoint is in
            the roi or not.
    """

    if rois.numel() == 0:
        return rois.new().long(), rois.new().long()
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])

    offset_x = offset_x[:, None]
    offset_y = offset_y[:, None]
    scale_x = scale_x[:, None]
    scale_y = scale_y[:, None]

    x = keypoints[..., 0]
    y = keypoints[..., 1]

    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    x = (x - offset_x) * scale_x
    x = x.floor().long()
    y = (y - offset_y) * scale_y
    y = y.floor().long()

    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1

    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
    vis = keypoints[..., 2] > 0
    valid = (valid_loc & vis).long()

    lin_ind = y * heatmap_size + x
    heatmaps = lin_ind * valid

    return heatmaps, valid


@torch.jit.script_if_tracing
def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
    """
    Extract predicted keypoint locations from heatmaps.

    Args:
        maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for
            each ROI and each keypoint.
        rois (Tensor): (#ROIs, 4). The box of each ROI.

    Returns:
        Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to
        (x, y, logit, score) for each keypoint.

    When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate,
    we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
    Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
    """

    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
    heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_rois, num_keypoints = maps.shape[:2]
    xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4)

    width_corrections = widths / widths_ceil
    height_corrections = heights / heights_ceil

    keypoints_idx = torch.arange(num_keypoints, device=maps.device)

    for i in range(num_rois):
        outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
        roi_map = F.interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False)

        # Although semantically equivalent, `reshape` is used instead of `squeeze` due
        # to limitation during ONNX export of `squeeze` in scripting mode
        roi_map = roi_map.reshape(roi_map.shape[1:])  # keypoints x H x W

        # softmax over the spatial region
        max_score, _ = roi_map.view(num_keypoints, -1).max(1)
        max_score = max_score.view(num_keypoints, 1, 1)
        tmp_full_resolution = (roi_map - max_score).exp_()
        tmp_pool_resolution = (maps[i] - max_score).exp_()
        # Produce scores over the region H x W, but normalize with POOL_H x POOL_W,
        # so that the scores of objects of different absolute sizes will be more comparable
        roi_map_scores = tmp_full_resolution / tmp_pool_resolution.sum((1, 2), keepdim=True)

        w = roi_map.shape[2]
        pos = roi_map.view(num_keypoints, -1).argmax(1)

        x_int = pos % w
        y_int = (pos - x_int) // w

        assert (
            roi_map_scores[keypoints_idx, y_int, x_int]
            == roi_map_scores.view(num_keypoints, -1).max(1)[0]
        ).all()

        x = (x_int.float() + 0.5) * width_corrections[i]
        y = (y_int.float() + 0.5) * height_corrections[i]

        xy_preds[i, :, 0] = x + offset_x[i]
        xy_preds[i, :, 1] = y + offset_y[i]
        xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int]
        xy_preds[i, :, 3] = roi_map_scores[keypoints_idx, y_int, x_int]

    return xy_preds
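A small shape-checking sketch for the encode/decode pair above (illustrative only, not part of the diff; the 17-keypoint COCO layout and the direct module import path are assumptions):

    import torch
    from annotator.oneformer.detectron2.structures.keypoints import Keypoints, heatmaps_to_keypoints

    kpts = Keypoints(torch.zeros(2, 17, 3))            # 2 instances, 17 keypoints, (x, y, v) each
    boxes = torch.tensor([[0.0, 0.0, 96.0, 96.0],
                          [8.0, 8.0, 72.0, 56.0]])
    heatmaps, valid = kpts.to_heatmap(boxes, heatmap_size=56)   # both shaped (2, 17)

    logits = torch.randn(2, 17, 56, 56)                 # per-ROI keypoint heatmaps from a model
    preds = heatmaps_to_keypoints(logits, boxes)        # (2, 17, 4): (x, y, logit, score)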
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/masks.py
ADDED
@@ -0,0 +1,534 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import itertools
import numpy as np
from typing import Any, Iterator, List, Union
import annotator.oneformer.pycocotools.mask as mask_util
import torch
from torch import device

from annotator.oneformer.detectron2.layers.roi_align import ROIAlign
from annotator.oneformer.detectron2.utils.memory import retry_if_cuda_oom

from .boxes import Boxes


def polygon_area(x, y):
    # Using the shoelace formula
    # https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
    return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))


def polygons_to_bitmask(polygons: List[np.ndarray], height: int, width: int) -> np.ndarray:
    """
    Args:
        polygons (list[ndarray]): each array has shape (Nx2,)
        height, width (int)

    Returns:
        ndarray: a bool mask of shape (height, width)
    """
    if len(polygons) == 0:
        # COCOAPI does not support empty polygons
        return np.zeros((height, width)).astype(bool)
    rles = mask_util.frPyObjects(polygons, height, width)
    rle = mask_util.merge(rles)
    return mask_util.decode(rle).astype(bool)


def rasterize_polygons_within_box(
    polygons: List[np.ndarray], box: np.ndarray, mask_size: int
) -> torch.Tensor:
    """
    Rasterize the polygons into a mask image and
    crop the mask content in the given box.
    The cropped mask is resized to (mask_size, mask_size).

    This function is used when generating training targets for mask head in Mask R-CNN.
    Given original ground-truth masks for an image, new ground-truth mask
    training targets in the size of `mask_size x mask_size`
    must be provided for each predicted box. This function will be called to
    produce such targets.

    Args:
        polygons (list[ndarray[float]]): a list of polygons, which represents an instance.
        box: 4-element numpy array
        mask_size (int):

    Returns:
        Tensor: BoolTensor of shape (mask_size, mask_size)
    """
    # 1. Shift the polygons w.r.t the boxes
    w, h = box[2] - box[0], box[3] - box[1]

    polygons = copy.deepcopy(polygons)
    for p in polygons:
        p[0::2] = p[0::2] - box[0]
        p[1::2] = p[1::2] - box[1]

    # 2. Rescale the polygons to the new box size
    # max() to avoid division by small number
    ratio_h = mask_size / max(h, 0.1)
    ratio_w = mask_size / max(w, 0.1)

    if ratio_h == ratio_w:
        for p in polygons:
            p *= ratio_h
    else:
        for p in polygons:
            p[0::2] *= ratio_w
            p[1::2] *= ratio_h

    # 3. Rasterize the polygons with coco api
    mask = polygons_to_bitmask(polygons, mask_size, mask_size)
    mask = torch.from_numpy(mask)
    return mask


class BitMasks:
    """
    This class stores the segmentation masks for all objects in one image, in
    the form of bitmaps.

    Attributes:
        tensor: bool Tensor of N,H,W, representing N instances in the image.
    """

    def __init__(self, tensor: Union[torch.Tensor, np.ndarray]):
        """
        Args:
            tensor: bool Tensor of N,H,W, representing N instances in the image.
        """
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.to(torch.bool)
        else:
            tensor = torch.as_tensor(tensor, dtype=torch.bool, device=torch.device("cpu"))
        assert tensor.dim() == 3, tensor.size()
        self.image_size = tensor.shape[1:]
        self.tensor = tensor

    @torch.jit.unused
    def to(self, *args: Any, **kwargs: Any) -> "BitMasks":
        return BitMasks(self.tensor.to(*args, **kwargs))

    @property
    def device(self) -> torch.device:
        return self.tensor.device

    @torch.jit.unused
    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks":
        """
        Returns:
            BitMasks: Create a new :class:`BitMasks` by indexing.

        The following usage is allowed:

        1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask.
        2. `new_masks = masks[2:10]`: return a slice of masks.
        3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
           with `length = len(masks)`. Nonzero elements in the vector will be selected.

        Note that the returned object might share storage with this object,
        subject to Pytorch's indexing semantics.
        """
        if isinstance(item, int):
            return BitMasks(self.tensor[item].unsqueeze(0))
        m = self.tensor[item]
        assert m.dim() == 3, "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
            item, m.shape
        )
        return BitMasks(m)

    @torch.jit.unused
    def __iter__(self) -> torch.Tensor:
        yield from self.tensor

    @torch.jit.unused
    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.tensor))
        return s

    def __len__(self) -> int:
        return self.tensor.shape[0]

    def nonempty(self) -> torch.Tensor:
        """
        Find masks that are non-empty.

        Returns:
            Tensor: a BoolTensor which represents
                whether each mask is empty (False) or non-empty (True).
        """
        return self.tensor.flatten(1).any(dim=1)

    @staticmethod
    def from_polygon_masks(
        polygon_masks: Union["PolygonMasks", List[List[np.ndarray]]], height: int, width: int
    ) -> "BitMasks":
        """
        Args:
            polygon_masks (list[list[ndarray]] or PolygonMasks)
            height, width (int)
        """
        if isinstance(polygon_masks, PolygonMasks):
            polygon_masks = polygon_masks.polygons
        masks = [polygons_to_bitmask(p, height, width) for p in polygon_masks]
        if len(masks):
            return BitMasks(torch.stack([torch.from_numpy(x) for x in masks]))
        else:
            return BitMasks(torch.empty(0, height, width, dtype=torch.bool))

    @staticmethod
    def from_roi_masks(roi_masks: "ROIMasks", height: int, width: int) -> "BitMasks":
        """
        Args:
            roi_masks:
            height, width (int):
        """
        return roi_masks.to_bitmasks(height, width)

    def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
        """
        Crop each bitmask by the given box, and resize results to (mask_size, mask_size).
        This can be used to prepare training targets for Mask R-CNN.
        It has less reconstruction error compared to rasterization with polygons.
        However we observe no difference in accuracy,
        but BitMasks requires more memory to store all the masks.

        Args:
            boxes (Tensor): Nx4 tensor storing the boxes for each mask
            mask_size (int): the size of the rasterized mask.

        Returns:
            Tensor:
                A bool tensor of shape (N, mask_size, mask_size), where
                N is the number of predicted boxes for this image.
        """
        assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
        device = self.tensor.device

        batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[:, None]
        rois = torch.cat([batch_inds, boxes], dim=1)  # Nx5

        bit_masks = self.tensor.to(dtype=torch.float32)
        rois = rois.to(device=device)
        output = (
            ROIAlign((mask_size, mask_size), 1.0, 0, aligned=True)
            .forward(bit_masks[:, None, :, :], rois)
            .squeeze(1)
        )
        output = output >= 0.5
        return output

    def get_bounding_boxes(self) -> Boxes:
        """
        Returns:
            Boxes: tight bounding boxes around bitmasks.
            If a mask is empty, its bounding box will be all zero.
        """
        boxes = torch.zeros(self.tensor.shape[0], 4, dtype=torch.float32)
        x_any = torch.any(self.tensor, dim=1)
        y_any = torch.any(self.tensor, dim=2)
        for idx in range(self.tensor.shape[0]):
            x = torch.where(x_any[idx, :])[0]
            y = torch.where(y_any[idx, :])[0]
            if len(x) > 0 and len(y) > 0:
                boxes[idx, :] = torch.as_tensor(
                    [x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=torch.float32
                )
        return Boxes(boxes)

    @staticmethod
    def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks":
        """
        Concatenates a list of BitMasks into a single BitMasks

        Arguments:
            bitmasks_list (list[BitMasks])

        Returns:
            BitMasks: the concatenated BitMasks
        """
        assert isinstance(bitmasks_list, (list, tuple))
        assert len(bitmasks_list) > 0
        assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list)

        cat_bitmasks = type(bitmasks_list[0])(torch.cat([bm.tensor for bm in bitmasks_list], dim=0))
        return cat_bitmasks


class PolygonMasks:
    """
    This class stores the segmentation masks for all objects in one image, in the form of polygons.

    Attributes:
        polygons: list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon.
    """

    def __init__(self, polygons: List[List[Union[torch.Tensor, np.ndarray]]]):
        """
        Arguments:
            polygons (list[list[np.ndarray]]): The first
                level of the list correspond to individual instances,
                the second level to all the polygons that compose the
                instance, and the third level to the polygon coordinates.
                The third level array should have the format of
                [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
        """
        if not isinstance(polygons, list):
            raise ValueError(
                "Cannot create PolygonMasks: Expect a list of list of polygons per image. "
                "Got '{}' instead.".format(type(polygons))
            )

        def _make_array(t: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
            # Use float64 for higher precision, because why not?
            # Always put polygons on CPU (self.to is a no-op) since they
            # are supposed to be small tensors.
            # May need to change this assumption if GPU placement becomes useful
            if isinstance(t, torch.Tensor):
                t = t.cpu().numpy()
            return np.asarray(t).astype("float64")

        def process_polygons(
            polygons_per_instance: List[Union[torch.Tensor, np.ndarray]]
        ) -> List[np.ndarray]:
            if not isinstance(polygons_per_instance, list):
                raise ValueError(
                    "Cannot create polygons: Expect a list of polygons per instance. "
                    "Got '{}' instead.".format(type(polygons_per_instance))
                )
            # transform each polygon to a numpy array
            polygons_per_instance = [_make_array(p) for p in polygons_per_instance]
            for polygon in polygons_per_instance:
                if len(polygon) % 2 != 0 or len(polygon) < 6:
                    raise ValueError(f"Cannot create a polygon from {len(polygon)} coordinates.")
            return polygons_per_instance

        self.polygons: List[List[np.ndarray]] = [
            process_polygons(polygons_per_instance) for polygons_per_instance in polygons
        ]

    def to(self, *args: Any, **kwargs: Any) -> "PolygonMasks":
        return self

    @property
    def device(self) -> torch.device:
        return torch.device("cpu")

    def get_bounding_boxes(self) -> Boxes:
        """
        Returns:
            Boxes: tight bounding boxes around polygon masks.
        """
        boxes = torch.zeros(len(self.polygons), 4, dtype=torch.float32)
        for idx, polygons_per_instance in enumerate(self.polygons):
            minxy = torch.as_tensor([float("inf"), float("inf")], dtype=torch.float32)
            maxxy = torch.zeros(2, dtype=torch.float32)
            for polygon in polygons_per_instance:
                coords = torch.from_numpy(polygon).view(-1, 2).to(dtype=torch.float32)
                minxy = torch.min(minxy, torch.min(coords, dim=0).values)
                maxxy = torch.max(maxxy, torch.max(coords, dim=0).values)
            boxes[idx, :2] = minxy
            boxes[idx, 2:] = maxxy
        return Boxes(boxes)

    def nonempty(self) -> torch.Tensor:
        """
        Find masks that are non-empty.

        Returns:
            Tensor:
                a BoolTensor which represents whether each mask is empty (False) or not (True).
        """
        keep = [1 if len(polygon) > 0 else 0 for polygon in self.polygons]
        return torch.from_numpy(np.asarray(keep, dtype=bool))

    def __getitem__(self, item: Union[int, slice, List[int], torch.BoolTensor]) -> "PolygonMasks":
        """
        Support indexing over the instances and return a `PolygonMasks` object.
        `item` can be:

        1. An integer. It will return an object with only one instance.
        2. A slice. It will return an object with the selected instances.
        3. A list[int]. It will return an object with the selected instances,
           corresponding to the indices in the list.
        4. A vector mask of type BoolTensor, whose length is num_instances.
           It will return an object with the instances whose mask is nonzero.
        """
        if isinstance(item, int):
            selected_polygons = [self.polygons[item]]
        elif isinstance(item, slice):
            selected_polygons = self.polygons[item]
        elif isinstance(item, list):
            selected_polygons = [self.polygons[i] for i in item]
        elif isinstance(item, torch.Tensor):
            # Polygons is a list, so we have to move the indices back to CPU.
            if item.dtype == torch.bool:
                assert item.dim() == 1, item.shape
                item = item.nonzero().squeeze(1).cpu().numpy().tolist()
            elif item.dtype in [torch.int32, torch.int64]:
                item = item.cpu().numpy().tolist()
            else:
                raise ValueError("Unsupported tensor dtype={} for indexing!".format(item.dtype))
            selected_polygons = [self.polygons[i] for i in item]
        return PolygonMasks(selected_polygons)

    def __iter__(self) -> Iterator[List[np.ndarray]]:
        """
        Yields:
            list[ndarray]: the polygons for one instance.
            Each Tensor is a float64 vector representing a polygon.
        """
        return iter(self.polygons)

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.polygons))
        return s

    def __len__(self) -> int:
        return len(self.polygons)

    def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
        """
        Crop each mask by the given box, and resize results to (mask_size, mask_size).
        This can be used to prepare training targets for Mask R-CNN.

        Args:
            boxes (Tensor): Nx4 tensor storing the boxes for each mask
            mask_size (int): the size of the rasterized mask.

        Returns:
            Tensor: A bool tensor of shape (N, mask_size, mask_size), where
            N is the number of predicted boxes for this image.
        """
        assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))

        device = boxes.device
        # Put boxes on the CPU, as the polygon representation is not efficient GPU-wise
        # (several small tensors for representing a single instance mask)
        boxes = boxes.to(torch.device("cpu"))

        results = [
            rasterize_polygons_within_box(poly, box.numpy(), mask_size)
            for poly, box in zip(self.polygons, boxes)
        ]
        """
        poly: list[list[float]], the polygons for one instance
        box: a tensor of shape (4,)
        """
        if len(results) == 0:
            return torch.empty(0, mask_size, mask_size, dtype=torch.bool, device=device)
        return torch.stack(results, dim=0).to(device=device)

    def area(self):
        """
        Computes area of the mask.
        Only works with Polygons, using the shoelace formula:
        https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates

        Returns:
            Tensor: a vector, area for each instance
        """

        area = []
        for polygons_per_instance in self.polygons:
            area_per_instance = 0
            for p in polygons_per_instance:
                area_per_instance += polygon_area(p[0::2], p[1::2])
            area.append(area_per_instance)

        return torch.tensor(area)

    @staticmethod
    def cat(polymasks_list: List["PolygonMasks"]) -> "PolygonMasks":
        """
        Concatenates a list of PolygonMasks into a single PolygonMasks

        Arguments:
            polymasks_list (list[PolygonMasks])

        Returns:
            PolygonMasks: the concatenated PolygonMasks
        """
        assert isinstance(polymasks_list, (list, tuple))
        assert len(polymasks_list) > 0
        assert all(isinstance(polymask, PolygonMasks) for polymask in polymasks_list)

        cat_polymasks = type(polymasks_list[0])(
            list(itertools.chain.from_iterable(pm.polygons for pm in polymasks_list))
        )
        return cat_polymasks


class ROIMasks:
    """
    Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given,
    full-image bitmask can be obtained by "pasting" the mask on the region defined
    by the corresponding ROI box.
    """

    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor: (N, M, M) mask tensor that defines the mask within each ROI.
        """
        if tensor.dim() != 3:
            raise ValueError("ROIMasks must take a masks of 3 dimension.")
        self.tensor = tensor

    def to(self, device: torch.device) -> "ROIMasks":
        return ROIMasks(self.tensor.to(device))

    @property
    def device(self) -> device:
        return self.tensor.device

    def __len__(self):
        return self.tensor.shape[0]

    def __getitem__(self, item) -> "ROIMasks":
        """
        Returns:
            ROIMasks: Create a new :class:`ROIMasks` by indexing.

        The following usage is allowed:

        1. `new_masks = masks[2:10]`: return a slice of masks.
        2. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
           with `length = len(masks)`. Nonzero elements in the vector will be selected.

        Note that the returned object might share storage with this object,
        subject to Pytorch's indexing semantics.
        """
        t = self.tensor[item]
        if t.dim() != 3:
            raise ValueError(
                f"Indexing on ROIMasks with {item} returns a tensor with shape {t.shape}!"
            )
        return ROIMasks(t)

    @torch.jit.unused
    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_instances={})".format(len(self.tensor))
        return s

    @torch.jit.unused
    def to_bitmasks(self, boxes: torch.Tensor, height, width, threshold=0.5):
        """
        Args: see documentation of :func:`paste_masks_in_image`.
        """
        from annotator.oneformer.detectron2.layers.mask_ops import paste_masks_in_image, _paste_masks_tensor_shape

        if torch.jit.is_tracing():
            if isinstance(height, torch.Tensor):
                paste_func = _paste_masks_tensor_shape
            else:
                paste_func = paste_masks_in_image
        else:
            paste_func = retry_if_cuda_oom(paste_masks_in_image)
        bitmasks = paste_func(self.tensor, boxes.tensor, (height, width), threshold=threshold)
        return BitMasks(bitmasks)
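A short sketch of how the mask representations above fit together (illustrative only, not part of the diff; it assumes the vendored pycocotools and the ROIAlign layer imported at the top of masks.py are importable):

    import numpy as np
    import torch
    from annotator.oneformer.detectron2.structures.masks import BitMasks, PolygonMasks

    # One instance made of a single triangle, in [x0, y0, x1, y1, ...] order.
    polys = PolygonMasks([[np.array([10.0, 10.0, 60.0, 10.0, 60.0, 40.0])]])
    bits = BitMasks.from_polygon_masks(polys, height=100, width=100)   # (1, 100, 100) bool
    assert bits.nonempty().all()

    boxes = polys.get_bounding_boxes().tensor                          # Nx4 xyxy boxes
    targets = bits.crop_and_resize(boxes, mask_size=28)                # (1, 28, 28) bool targets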
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/structures/rotated_boxes.py
ADDED
@@ -0,0 +1,505 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import math
|
3 |
+
from typing import List, Tuple
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from annotator.oneformer.detectron2.layers.rotated_boxes import pairwise_iou_rotated
|
7 |
+
|
8 |
+
from .boxes import Boxes
|
9 |
+
|
10 |
+
|
11 |
+
class RotatedBoxes(Boxes):
|
12 |
+
"""
|
13 |
+
This structure stores a list of rotated boxes as a Nx5 torch.Tensor.
|
14 |
+
It supports some common methods about boxes
|
15 |
+
(`area`, `clip`, `nonempty`, etc),
|
16 |
+
and also behaves like a Tensor
|
17 |
+
(support indexing, `to(device)`, `.device`, and iteration over all boxes)
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, tensor: torch.Tensor):
|
21 |
+
"""
|
22 |
+
Args:
|
23 |
+
tensor (Tensor[float]): a Nx5 matrix. Each row is
|
24 |
+
(x_center, y_center, width, height, angle),
|
25 |
+
in which angle is represented in degrees.
|
26 |
+
While there's no strict range restriction for it,
|
27 |
+
the recommended principal range is between [-180, 180) degrees.
|
28 |
+
|
29 |
+
Assume we have a horizontal box B = (x_center, y_center, width, height),
|
30 |
+
where width is along the x-axis and height is along the y-axis.
|
31 |
+
The rotated box B_rot (x_center, y_center, width, height, angle)
|
32 |
+
can be seen as:
|
33 |
+
|
34 |
+
1. When angle == 0:
|
35 |
+
B_rot == B
|
36 |
+
2. When angle > 0:
|
37 |
+
B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CCW;
|
38 |
+
3. When angle < 0:
|
39 |
+
B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CW.
|
40 |
+
|
41 |
+
Mathematically, since the right-handed coordinate system for image space
|
42 |
+
is (y, x), where y is top->down and x is left->right, the 4 vertices of the
|
43 |
+
rotated rectangle :math:`(yr_i, xr_i)` (i = 1, 2, 3, 4) can be obtained from
|
44 |
+
the vertices of the horizontal rectangle :math:`(y_i, x_i)` (i = 1, 2, 3, 4)
|
45 |
+
in the following way (:math:`\\theta = angle*\\pi/180` is the angle in radians,
|
46 |
+
:math:`(y_c, x_c)` is the center of the rectangle):
|
47 |
+
|
48 |
+
.. math::
|
49 |
+
|
50 |
+
yr_i = \\cos(\\theta) (y_i - y_c) - \\sin(\\theta) (x_i - x_c) + y_c,
|
51 |
+
|
52 |
+
xr_i = \\sin(\\theta) (y_i - y_c) + \\cos(\\theta) (x_i - x_c) + x_c,
|
53 |
+
|
54 |
+
which is the standard rigid-body rotation transformation.
|
55 |
+
|
56 |
+
Intuitively, the angle is
|
57 |
+
(1) the rotation angle from y-axis in image space
|
58 |
+
to the height vector (top->down in the box's local coordinate system)
|
59 |
+
of the box in CCW, and
|
60 |
+
(2) the rotation angle from x-axis in image space
|
61 |
+
to the width vector (left->right in the box's local coordinate system)
|
62 |
+
of the box in CCW.
|
63 |
+
|
64 |
+
More intuitively, consider the following horizontal box ABCD represented
|
65 |
+
in (x1, y1, x2, y2): (3, 2, 7, 4),
|
66 |
+
covering the [3, 7] x [2, 4] region of the continuous coordinate system
|
67 |
+
which looks like this:
|
68 |
+
|
69 |
+
.. code:: none
|
70 |
+
|
71 |
+
O--------> x
|
72 |
+
|
|
73 |
+
| A---B
|
74 |
+
| | |
|
75 |
+
| D---C
|
76 |
+
|
|
77 |
+
v y
|
78 |
+
|
79 |
+
Note that each capital letter represents one 0-dimensional geometric point
|
80 |
+
instead of a 'square pixel' here.
|
81 |
+
|
82 |
+
In the example above, using (x, y) to represent a point we have:
|
83 |
+
|
84 |
+
.. math::
|
85 |
+
|
86 |
+
O = (0, 0), A = (3, 2), B = (7, 2), C = (7, 4), D = (3, 4)
|
87 |
+
|
88 |
+
We name vector AB = vector DC as the width vector in box's local coordinate system, and
|
89 |
+
vector AD = vector BC as the height vector in box's local coordinate system. Initially,
|
90 |
+
when angle = 0 degree, they're aligned with the positive directions of x-axis and y-axis
|
91 |
+
in the image space, respectively.
|
92 |
+
|
93 |
+
For better illustration, we denote the center of the box as E,
|
94 |
+
|
95 |
+
.. code:: none
|
96 |
+
|
97 |
+
O--------> x
|
98 |
+
|
|
99 |
+
| A---B
|
100 |
+
| | E |
|
101 |
+
| D---C
|
102 |
+
|
|
103 |
+
v y
|
104 |
+
|
105 |
+
where the center E = ((3+7)/2, (2+4)/2) = (5, 3).
|
106 |
+
|
107 |
+
Also,
|
108 |
+
|
109 |
+
.. math::
|
110 |
+
|
111 |
+
width = |AB| = |CD| = 7 - 3 = 4,
|
112 |
+
height = |AD| = |BC| = 4 - 2 = 2.
|
113 |
+
|
114 |
+
Therefore, the corresponding representation for the same shape in rotated box in
|
115 |
+
(x_center, y_center, width, height, angle) format is:
|
116 |
+
|
117 |
+
(5, 3, 4, 2, 0),
|
118 |
+
|
119 |
+
Now, let's consider (5, 3, 4, 2, 90), which is rotated by 90 degrees
|
120 |
+
CCW (counter-clockwise) by definition. It looks like this:
|
121 |
+
|
122 |
+
.. code:: none
|
123 |
+
|
124 |
+
O--------> x
|
125 |
+
| B-C
|
126 |
+
| | |
|
127 |
+
| |E|
|
128 |
+
| | |
|
129 |
+
| A-D
|
130 |
+
v y
|
131 |
+
|
132 |
+
The center E is still located at the same point (5, 3), while the vertices
|
133 |
+
ABCD are rotated by 90 degrees CCW with regard to E:
|
134 |
+
A = (4, 5), B = (4, 1), C = (6, 1), D = (6, 5)
|
135 |
+
|
136 |
+
Here, 90 degrees can be seen as the CCW angle to rotate from y-axis to
|
137 |
+
vector AD or vector BC (the top->down height vector in box's local coordinate system),
|
138 |
+
or the CCW angle to rotate from x-axis to vector AB or vector DC (the left->right
|
139 |
+
width vector in box's local coordinate system).
|
140 |
+
|
141 |
+
.. math::
|
142 |
+
|
143 |
+
width = |AB| = |CD| = 5 - 1 = 4,
|
144 |
+
height = |AD| = |BC| = 6 - 4 = 2.
|
145 |
+
|
146 |
+
Next, how about (5, 3, 4, 2, -90), which is rotated by 90 degrees CW (clockwise)
|
147 |
+
by definition? It looks like this:
|
148 |
+
|
149 |
+
.. code:: none
|
150 |
+
|
151 |
+
O--------> x
|
152 |
+
| D-A
|
153 |
+
| | |
|
154 |
+
| |E|
|
155 |
+
| | |
|
156 |
+
| C-B
|
157 |
+
v y
|
158 |
+
|
159 |
+
The center E is still located at the same point (5, 3), while the vertices
|
160 |
+
ABCD are rotated by 90 degrees CW with regard to E:
|
161 |
+
A = (6, 1), B = (6, 5), C = (4, 5), D = (4, 1)
|
162 |
+
|
163 |
+
.. math::
|
164 |
+
|
165 |
+
width = |AB| = |CD| = 5 - 1 = 4,
|
166 |
+
height = |AD| = |BC| = 6 - 4 = 2.
|
167 |
+
|
168 |
+
This covers exactly the same region as (5, 3, 4, 2, 90) does, and their IoU
|
169 |
+
will be 1. However, these two will generate different RoI Pooling results and
|
170 |
+
should not be treated as an identical box.
|
171 |
+
|
172 |
+
On the other hand, it's easy to see that (X, Y, W, H, A) is identical to
|
173 |
+
(X, Y, W, H, A+360N), for any integer N. For example (5, 3, 4, 2, 270) would be
|
174 |
+
identical to (5, 3, 4, 2, -90), because rotating the shape 270 degrees CCW is
|
175 |
+
equivalent to rotating the same shape 90 degrees CW.
|
176 |
+
|
177 |
+
We could rotate further to get (5, 3, 4, 2, 180), or (5, 3, 4, 2, -180):
|
178 |
+
|
179 |
+
.. code:: none
|
180 |
+
|
181 |
+
O--------> x
|
182 |
+
|
|
183 |
+
| C---D
|
184 |
+
| | E |
|
185 |
+
| B---A
|
186 |
+
|
|
187 |
+
v y
|
188 |
+
|
189 |
+
.. math::
|
190 |
+
|
191 |
+
A = (7, 4), B = (3, 4), C = (3, 2), D = (7, 2),
|
192 |
+
|
193 |
+
width = |AB| = |CD| = 7 - 3 = 4,
|
194 |
+
height = |AD| = |BC| = 4 - 2 = 2.
|
195 |
+
|
196 |
+
Finally, this is a very inaccurate (heavily quantized) illustration of
|
197 |
+
how (5, 3, 4, 2, 60) looks like in case anyone wonders:
|
198 |
+
|
199 |
+
.. code:: none
|
200 |
+
|
201 |
+
O--------> x
|
202 |
+
| B\
|
203 |
+
| / C
|
204 |
+
| /E /
|
205 |
+
| A /
|
206 |
+
| `D
|
207 |
+
v y
|
208 |
+
|
209 |
+
It's still a rectangle with center of (5, 3), width of 4 and height of 2,
|
210 |
+
but its angle (and thus orientation) is somewhere between
|
211 |
+
(5, 3, 4, 2, 0) and (5, 3, 4, 2, 90).
|
212 |
+
"""
|
213 |
+
device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
|
214 |
+
tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
|
215 |
+
if tensor.numel() == 0:
|
216 |
+
# Use reshape, so we don't end up creating a new tensor that does not depend on
|
217 |
+
# the inputs (and consequently confuses jit)
|
218 |
+
tensor = tensor.reshape((0, 5)).to(dtype=torch.float32, device=device)
|
219 |
+
assert tensor.dim() == 2 and tensor.size(-1) == 5, tensor.size()
|
220 |
+
|
221 |
+
self.tensor = tensor
|
222 |
+
|
223 |
+
def clone(self) -> "RotatedBoxes":
|
224 |
+
"""
|
225 |
+
Clone the RotatedBoxes.
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
RotatedBoxes
|
229 |
+
"""
|
230 |
+
return RotatedBoxes(self.tensor.clone())
|
231 |
+
|
232 |
+
def to(self, device: torch.device):
|
233 |
+
# Boxes are assumed float32 and does not support to(dtype)
|
234 |
+
return RotatedBoxes(self.tensor.to(device=device))
|
235 |
+
|
236 |
+
def area(self) -> torch.Tensor:
|
237 |
+
"""
|
238 |
+
Computes the area of all the boxes.
|
239 |
+
|
240 |
+
Returns:
|
241 |
+
torch.Tensor: a vector with areas of each box.
|
242 |
+
"""
|
243 |
+
box = self.tensor
|
244 |
+
area = box[:, 2] * box[:, 3]
|
245 |
+
return area
|
246 |
+
|
247 |
+
# Avoid in-place operations so that we can torchscript; NOTE: this creates a new tensor
|
248 |
+
def normalize_angles(self) -> None:
|
249 |
+
"""
|
250 |
+
Restrict angles to the range of [-180, 180) degrees
|
251 |
+
"""
|
252 |
+
angle_tensor = (self.tensor[:, 4] + 180.0) % 360.0 - 180.0
|
253 |
+
self.tensor = torch.cat((self.tensor[:, :4], angle_tensor[:, None]), dim=1)
|
254 |
+
|
255 |
+
def clip(self, box_size: Tuple[int, int], clip_angle_threshold: float = 1.0) -> None:
|
256 |
+
"""
|
257 |
+
Clip (in place) the boxes by limiting x coordinates to the range [0, width]
|
258 |
+
and y coordinates to the range [0, height].
|
259 |
+
|
260 |
+
For RRPN:
|
261 |
+
Only clip boxes that are almost horizontal with a tolerance of
|
262 |
+
clip_angle_threshold to maintain backward compatibility.
|
263 |
+
|
264 |
+
Rotated boxes beyond this threshold are not clipped for two reasons:
|
265 |
+
|
266 |
+
1. There are potentially multiple ways to clip a rotated box to make it
|
267 |
+
fit within the image.
|
268 |
+
2. It's tricky to make the entire rectangular box fit within the image
|
269 |
+
and still be able to not leave out pixels of interest.
|
270 |
+
|
271 |
+
Therefore we rely on ops like RoIAlignRotated to safely handle this.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
box_size (height, width): The clipping box's size.
|
275 |
+
clip_angle_threshold:
|
276 |
+
Iff. abs(normalized(angle)) <= clip_angle_threshold (in degrees),
|
277 |
+
we do the clipping as horizontal boxes.
|
278 |
+
"""
|
279 |
+
h, w = box_size
|
280 |
+
|
281 |
+
# normalize angles to be within (-180, 180] degrees
|
282 |
+
self.normalize_angles()
|
283 |
+
|
284 |
+
idx = torch.where(torch.abs(self.tensor[:, 4]) <= clip_angle_threshold)[0]
|
285 |
+
|
286 |
+
# convert to (x1, y1, x2, y2)
|
287 |
+
x1 = self.tensor[idx, 0] - self.tensor[idx, 2] / 2.0
|
288 |
+
y1 = self.tensor[idx, 1] - self.tensor[idx, 3] / 2.0
|
289 |
+
x2 = self.tensor[idx, 0] + self.tensor[idx, 2] / 2.0
|
290 |
+
y2 = self.tensor[idx, 1] + self.tensor[idx, 3] / 2.0
|
291 |
+
|
292 |
+
# clip
|
293 |
+
x1.clamp_(min=0, max=w)
|
294 |
+
y1.clamp_(min=0, max=h)
|
295 |
+
x2.clamp_(min=0, max=w)
|
296 |
+
y2.clamp_(min=0, max=h)
|
297 |
+
|
298 |
+
# convert back to (xc, yc, w, h)
|
299 |
+
self.tensor[idx, 0] = (x1 + x2) / 2.0
|
300 |
+
self.tensor[idx, 1] = (y1 + y2) / 2.0
|
301 |
+
# make sure widths and heights do not increase due to numerical errors
|
302 |
+
self.tensor[idx, 2] = torch.min(self.tensor[idx, 2], x2 - x1)
|
303 |
+
self.tensor[idx, 3] = torch.min(self.tensor[idx, 3], y2 - y1)
|
304 |
+
|
305 |
+
def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
|
306 |
+
"""
|
307 |
+
Find boxes that are non-empty.
|
308 |
+
A box is considered empty, if either of its side is no larger than threshold.
|
309 |
+
|
310 |
+
Returns:
|
311 |
+
Tensor: a binary vector which represents
|
312 |
+
whether each box is empty (False) or non-empty (True).
|
313 |
+
"""
|
314 |
+
box = self.tensor
|
315 |
+
widths = box[:, 2]
|
316 |
+
heights = box[:, 3]
|
317 |
+
keep = (widths > threshold) & (heights > threshold)
|
318 |
+
return keep
|
319 |
+
|
320 |
+
def __getitem__(self, item) -> "RotatedBoxes":
|
321 |
+
"""
|
322 |
+
Returns:
|
323 |
+
RotatedBoxes: Create a new :class:`RotatedBoxes` by indexing.
|
324 |
+
|
325 |
+
The following usage are allowed:
|
326 |
+
|
327 |
+
1. `new_boxes = boxes[3]`: return a `RotatedBoxes` which contains only one box.
|
328 |
+
2. `new_boxes = boxes[2:10]`: return a slice of boxes.
|
329 |
+
3. `new_boxes = boxes[vector]`, where vector is a torch.ByteTensor
|
330 |
+
with `length = len(boxes)`. Nonzero elements in the vector will be selected.
|
331 |
+
|
332 |
+
Note that the returned RotatedBoxes might share storage with this RotatedBoxes,
|
333 |
+
subject to Pytorch's indexing semantics.
|
334 |
+
"""
|
335 |
+
if isinstance(item, int):
|
336 |
+
return RotatedBoxes(self.tensor[item].view(1, -1))
|
337 |
+
b = self.tensor[item]
|
338 |
+
assert b.dim() == 2, "Indexing on RotatedBoxes with {} failed to return a matrix!".format(
|
339 |
+
item
|
340 |
+
)
|
341 |
+
return RotatedBoxes(b)
|
342 |
+
|
343 |
+
def __len__(self) -> int:
|
344 |
+
return self.tensor.shape[0]
|
345 |
+
|
346 |
+
def __repr__(self) -> str:
|
347 |
+
return "RotatedBoxes(" + str(self.tensor) + ")"
|
348 |
+
|
349 |
+
def inside_box(self, box_size: Tuple[int, int], boundary_threshold: int = 0) -> torch.Tensor:
|
350 |
+
"""
|
351 |
+
Args:
|
352 |
+
box_size (height, width): Size of the reference box covering
|
353 |
+
[0, width] x [0, height]
|
354 |
+
boundary_threshold (int): Boxes that extend beyond the reference box
|
355 |
+
boundary by more than boundary_threshold are considered "outside".
|
356 |
+
|
357 |
+
For RRPN, it might not be necessary to call this function since it's common
|
358 |
+
for rotated box to extend to outside of the image boundaries
|
359 |
+
(the clip function only clips the near-horizontal boxes)
|
360 |
+
|
361 |
+
Returns:
|
362 |
+
a binary vector, indicating whether each box is inside the reference box.
|
363 |
+
"""
|
364 |
+
height, width = box_size
|
365 |
+
|
366 |
+
cnt_x = self.tensor[..., 0]
|
367 |
+
cnt_y = self.tensor[..., 1]
|
368 |
+
half_w = self.tensor[..., 2] / 2.0
|
369 |
+
half_h = self.tensor[..., 3] / 2.0
|
370 |
+
a = self.tensor[..., 4]
|
371 |
+
c = torch.abs(torch.cos(a * math.pi / 180.0))
|
372 |
+
s = torch.abs(torch.sin(a * math.pi / 180.0))
|
373 |
+
# This basically computes the horizontal bounding rectangle of the rotated box
|
374 |
+
max_rect_dx = c * half_w + s * half_h
|
375 |
+
max_rect_dy = c * half_h + s * half_w
|
376 |
+
|
377 |
+
inds_inside = (
|
378 |
+
(cnt_x - max_rect_dx >= -boundary_threshold)
|
379 |
+
& (cnt_y - max_rect_dy >= -boundary_threshold)
|
380 |
+
& (cnt_x + max_rect_dx < width + boundary_threshold)
|
381 |
+
& (cnt_y + max_rect_dy < height + boundary_threshold)
|
382 |
+
)
|
383 |
+
|
384 |
+
return inds_inside
|
385 |
+
|
386 |
+
def get_centers(self) -> torch.Tensor:
|
387 |
+
"""
|
388 |
+
Returns:
|
389 |
+
The box centers in a Nx2 array of (x, y).
|
390 |
+
"""
|
391 |
+
return self.tensor[:, :2]
|
392 |
+
|
393 |
+
def scale(self, scale_x: float, scale_y: float) -> None:
|
394 |
+
"""
|
395 |
+
Scale the rotated box with horizontal and vertical scaling factors
|
396 |
+
Note: when scale_factor_x != scale_factor_y,
|
397 |
+
the rotated box does not preserve the rectangular shape when the angle
|
398 |
+
is not a multiple of 90 degrees under resize transformation.
|
399 |
+
Instead, the shape is a parallelogram (that has skew)
|
400 |
+
Here we make an approximation by fitting a rotated rectangle to the parallelogram.
|
401 |
+
"""
|
402 |
+
self.tensor[:, 0] *= scale_x
|
403 |
+
self.tensor[:, 1] *= scale_y
|
404 |
+
theta = self.tensor[:, 4] * math.pi / 180.0
|
405 |
+
c = torch.cos(theta)
|
406 |
+
s = torch.sin(theta)
|
407 |
+
|
408 |
+
# In image space, y is top->down and x is left->right
|
409 |
+
# Consider the local coordintate system for the rotated box,
|
410 |
+
# where the box center is located at (0, 0), and the four vertices ABCD are
|
411 |
+
# A(-w / 2, -h / 2), B(w / 2, -h / 2), C(w / 2, h / 2), D(-w / 2, h / 2)
|
412 |
+
# the midpoint of the left edge AD of the rotated box E is:
|
413 |
+
# E = (A+D)/2 = (-w / 2, 0)
|
414 |
+
# the midpoint of the top edge AB of the rotated box F is:
|
415 |
+
# F(0, -h / 2)
|
416 |
+
# To get the old coordinates in the global system, apply the rotation transformation
|
417 |
+
# (Note: the right-handed coordinate system for image space is yOx):
|
418 |
+
# (old_x, old_y) = (s * y + c * x, c * y - s * x)
|
419 |
+
# E(old) = (s * 0 + c * (-w/2), c * 0 - s * (-w/2)) = (-c * w / 2, s * w / 2)
|
420 |
+
# F(old) = (s * (-h / 2) + c * 0, c * (-h / 2) - s * 0) = (-s * h / 2, -c * h / 2)
|
421 |
+
        # After applying the scaling factor (sfx, sfy):
        # E(new) = (-sfx * c * w / 2, sfy * s * w / 2)
        # F(new) = (-sfx * s * h / 2, -sfy * c * h / 2)
        # The new width after the scaling transformation becomes:

        # w(new) = |E(new) - O| * 2
        #        = sqrt[(sfx * c * w / 2)^2 + (sfy * s * w / 2)^2] * 2
        #        = sqrt[(sfx * c)^2 + (sfy * s)^2] * w
        # i.e., scale_factor_w = sqrt[(sfx * c)^2 + (sfy * s)^2]
        #
        # For example,
        # when angle = 0 or 180, |c| = 1, s = 0, scale_factor_w == scale_factor_x;
        # when |angle| = 90, c = 0, |s| = 1, scale_factor_w == scale_factor_y
        self.tensor[:, 2] *= torch.sqrt((scale_x * c) ** 2 + (scale_y * s) ** 2)

        # h(new) = |F(new) - O| * 2
        #        = sqrt[(sfx * s * h / 2)^2 + (sfy * c * h / 2)^2] * 2
        #        = sqrt[(sfx * s)^2 + (sfy * c)^2] * h
        # i.e., scale_factor_h = sqrt[(sfx * s)^2 + (sfy * c)^2]
        #
        # For example,
        # when angle = 0 or 180, |c| = 1, s = 0, scale_factor_h == scale_factor_y;
        # when |angle| = 90, c = 0, |s| = 1, scale_factor_h == scale_factor_x
        self.tensor[:, 3] *= torch.sqrt((scale_x * s) ** 2 + (scale_y * c) ** 2)

        # The angle is the rotation angle from y-axis in image space to the height
        # vector (top->down in the box's local coordinate system) of the box in CCW.
        #
        # angle(new) = angle_yOx(O - F(new))
        #            = angle_yOx( (sfx * s * h / 2, sfy * c * h / 2) )
        #            = atan2(sfx * s * h / 2, sfy * c * h / 2)
        #            = atan2(sfx * s, sfy * c)
        #
        # For example,
        # when sfx == sfy, angle(new) == atan2(s, c) == angle(old)
        self.tensor[:, 4] = torch.atan2(scale_x * s, scale_y * c) * 180 / math.pi

    @classmethod
    def cat(cls, boxes_list: List["RotatedBoxes"]) -> "RotatedBoxes":
        """
        Concatenates a list of RotatedBoxes into a single RotatedBoxes

        Arguments:
            boxes_list (list[RotatedBoxes])

        Returns:
            RotatedBoxes: the concatenated RotatedBoxes
        """
        assert isinstance(boxes_list, (list, tuple))
        if len(boxes_list) == 0:
            return cls(torch.empty(0))
        assert all([isinstance(box, RotatedBoxes) for box in boxes_list])

        # use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input
        cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
        return cat_boxes

    @property
    def device(self) -> torch.device:
        return self.tensor.device

    @torch.jit.unused
    def __iter__(self):
        """
        Yield a box as a Tensor of shape (5,) at a time.
        """
        yield from self.tensor


def pairwise_iou(boxes1: RotatedBoxes, boxes2: RotatedBoxes) -> torch.Tensor:
    """
    Given two lists of rotated boxes of size N and M,
    compute the IoU (intersection over union)
    between **all** N x M pairs of boxes.
    The box order must be (x_center, y_center, width, height, angle).

    Args:
        boxes1, boxes2 (RotatedBoxes):
            two `RotatedBoxes`. Contains N & M rotated boxes, respectively.

    Returns:
        Tensor: IoU, sized [N,M].
    """

    return pairwise_iou_rotated(boxes1.tensor, boxes2.tensor)
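For orientation, here is a minimal usage sketch (not part of the diff) exercising the scale() math and pairwise_iou() reconstructed above. The import path is assumed from the vendored layout used throughout this extension, and pairwise_iou requires the compiled pairwise_iou_rotated op to be available.

# Hedged sketch: two (cx, cy, w, h, angle) boxes, anisotropic scaling, pairwise IoU.
import torch
from annotator.oneformer.detectron2.structures.rotated_boxes import RotatedBoxes, pairwise_iou

boxes = RotatedBoxes(torch.tensor([
    [100.0, 100.0, 60.0, 20.0, 0.0],    # axis-aligned box
    [100.0, 100.0, 60.0, 20.0, 90.0],   # same box rotated by 90 degrees
]))
boxes.scale(scale_x=2.0, scale_y=1.0)   # per the derivation: width doubles for the 0-degree box,
                                        # height doubles for the 90-degree box
iou = pairwise_iou(boxes, boxes)        # (2, 2) tensor; the diagonal is 1.0
print(boxes.tensor, iou)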
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .base_tracker import (  # noqa
    BaseTracker,
    build_tracker_head,
    TRACKER_HEADS_REGISTRY,
)
from .bbox_iou_tracker import BBoxIOUTracker  # noqa
from .hungarian_tracker import BaseHungarianTracker  # noqa
from .iou_weighted_hungarian_bbox_iou_tracker import (  # noqa
    IOUWeightedHungarianBBoxIOUTracker,
)
from .utils import create_prediction_pairs  # noqa
from .vanilla_hungarian_bbox_iou_tracker import VanillaHungarianBBoxIOUTracker  # noqa

__all__ = [k for k in globals().keys() if not k.startswith("_")]
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/base_tracker.py
ADDED
@@ -0,0 +1,64 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
from annotator.oneformer.detectron2.config import configurable
from annotator.oneformer.detectron2.utils.registry import Registry

from ..config.config import CfgNode as CfgNode_
from ..structures import Instances

TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS")
TRACKER_HEADS_REGISTRY.__doc__ = """
Registry for tracking classes.
"""


class BaseTracker(object):
    """
    A parent class for all trackers
    """

    @configurable
    def __init__(self, **kwargs):
        self._prev_instances = None  # (D2) instances for previous frame
        self._matched_idx = set()  # indices in prev_instances found matching
        self._matched_ID = set()  # identities in prev_instances found matching
        self._untracked_prev_idx = set()  # indices in prev_instances not found matching
        self._id_count = 0  # used to assign new id

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        raise NotImplementedError("Calling BaseTracker::from_config")

    def update(self, predictions: Instances) -> Instances:
        """
        Args:
            predictions: D2 Instances for predictions of the current frame
        Return:
            D2 Instances for predictions of the current frame with ID assigned

            _prev_instances and instances will have the following fields:
              .pred_boxes      (shape=[N, 4])
              .scores          (shape=[N,])
              .pred_classes    (shape=[N,])
              .pred_keypoints  (shape=[N, M, 3], Optional)
              .pred_masks      (shape=List[2D_MASK], Optional)  2D_MASK: shape=[H, W]
              .ID              (shape=[N,])

            N: # of detected bboxes
            H and W: height and width of 2D mask
        """
        raise NotImplementedError("Calling BaseTracker::update")


def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
    """
    Build a tracker head from `cfg.TRACKER_HEADS.TRACKER_NAME`.

    Args:
        cfg: D2 CfgNode, config file with tracker information
    Return:
        tracker object
    """
    name = cfg.TRACKER_HEADS.TRACKER_NAME
    tracker_class = TRACKER_HEADS_REGISTRY.get(name)
    return tracker_class(cfg)
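A hedged sketch (not in the diff) of how the registry and build_tracker_head() above are meant to be used; DummyTracker and the cfg fields are illustrative only, and the CfgNode import path is assumed from the vendored config package.

from annotator.oneformer.detectron2.config import CfgNode
from annotator.oneformer.detectron2.tracking.base_tracker import (
    BaseTracker, TRACKER_HEADS_REGISTRY, build_tracker_head,
)

@TRACKER_HEADS_REGISTRY.register()
class DummyTracker(BaseTracker):
    @classmethod
    def from_config(cls, cfg):
        return {}  # no constructor arguments beyond BaseTracker's defaults

    def update(self, predictions):
        return predictions  # pass-through: no IDs assigned

# build_tracker_head() looks the class up by name and calls it with the cfg
cfg = CfgNode({"TRACKER_HEADS": CfgNode({"TRACKER_NAME": "DummyTracker"})})
tracker = build_tracker_head(cfg)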
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/bbox_iou_tracker.py
ADDED
@@ -0,0 +1,276 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import copy
import numpy as np
from typing import List
import torch

from annotator.oneformer.detectron2.config import configurable
from annotator.oneformer.detectron2.structures import Boxes, Instances
from annotator.oneformer.detectron2.structures.boxes import pairwise_iou

from ..config.config import CfgNode as CfgNode_
from .base_tracker import TRACKER_HEADS_REGISTRY, BaseTracker


@TRACKER_HEADS_REGISTRY.register()
class BBoxIOUTracker(BaseTracker):
    """
    A bounding box tracker to assign ID based on IoU between current and previous instances
    """

    @configurable
    def __init__(
        self,
        *,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        track_iou_threshold: float = 0.5,
        **kwargs,
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of ids allowed to be tracked
            max_lost_frame_count: maximum number of frames an id can lose tracking;
                exceed this number and the id is considered lost forever
            min_box_rel_dim: a percentage; a bbox smaller than this dimension is
                removed from tracking
            min_instance_period: an instance will be shown after this number of periods
                since its first showing up in the video
            track_iou_threshold: iou threshold; below this number a bbox pair is removed
                from tracking
        """
        super().__init__(**kwargs)
        self._video_height = video_height
        self._video_width = video_width
        self._max_num_instances = max_num_instances
        self._max_lost_frame_count = max_lost_frame_count
        self._min_box_rel_dim = min_box_rel_dim
        self._min_instance_period = min_instance_period
        self._track_iou_threshold = track_iou_threshold

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Old style initialization using CfgNode

        Args:
            cfg: D2 CfgNode, config file
        Return:
            dictionary storing arguments for __init__ method
        """
        assert "VIDEO_HEIGHT" in cfg.TRACKER_HEADS
        assert "VIDEO_WIDTH" in cfg.TRACKER_HEADS
        video_height = cfg.TRACKER_HEADS.get("VIDEO_HEIGHT")
        video_width = cfg.TRACKER_HEADS.get("VIDEO_WIDTH")
        max_num_instances = cfg.TRACKER_HEADS.get("MAX_NUM_INSTANCES", 200)
        max_lost_frame_count = cfg.TRACKER_HEADS.get("MAX_LOST_FRAME_COUNT", 0)
        min_box_rel_dim = cfg.TRACKER_HEADS.get("MIN_BOX_REL_DIM", 0.02)
        min_instance_period = cfg.TRACKER_HEADS.get("MIN_INSTANCE_PERIOD", 1)
        track_iou_threshold = cfg.TRACKER_HEADS.get("TRACK_IOU_THRESHOLD", 0.5)
        return {
            "_target_": "detectron2.tracking.bbox_iou_tracker.BBoxIOUTracker",
            "video_height": video_height,
            "video_width": video_width,
            "max_num_instances": max_num_instances,
            "max_lost_frame_count": max_lost_frame_count,
            "min_box_rel_dim": min_box_rel_dim,
            "min_instance_period": min_instance_period,
            "track_iou_threshold": track_iou_threshold,
        }

    def update(self, instances: Instances) -> Instances:
        """
        See BaseTracker description
        """
        instances = self._initialize_extra_fields(instances)
        if self._prev_instances is not None:
            # calculate IoU of all bbox pairs
            iou_all = pairwise_iou(
                boxes1=instances.pred_boxes,
                boxes2=self._prev_instances.pred_boxes,
            )
            # sort IoU in descending order
            bbox_pairs = self._create_prediction_pairs(instances, iou_all)
            # assign previous ID to current bbox if IoU > track_iou_threshold
            self._reset_fields()
            for bbox_pair in bbox_pairs:
                idx = bbox_pair["idx"]
                prev_id = bbox_pair["prev_id"]
                if (
                    idx in self._matched_idx
                    or prev_id in self._matched_ID
                    or bbox_pair["IoU"] < self._track_iou_threshold
                ):
                    continue
                instances.ID[idx] = prev_id
                instances.ID_period[idx] = bbox_pair["prev_period"] + 1
                instances.lost_frame_count[idx] = 0
                self._matched_idx.add(idx)
                self._matched_ID.add(prev_id)
                self._untracked_prev_idx.remove(bbox_pair["prev_idx"])
            instances = self._assign_new_id(instances)
            instances = self._merge_untracked_instances(instances)
        self._prev_instances = copy.deepcopy(instances)
        return instances

    def _create_prediction_pairs(self, instances: Instances, iou_all: np.ndarray) -> List:
        """
        For all instances in previous and current frames, create pairs. For each
        pair, store index of the instance in current frame predictions, index in
        previous predictions, ID in previous predictions, IoU of the bboxes in this
        pair, period in previous predictions.

        Args:
            instances: D2 Instances, for predictions of the current frame
            iou_all: IoU for all bboxes pairs
        Return:
            A list of IoU for all pairs
        """
        bbox_pairs = []
        for i in range(len(instances)):
            for j in range(len(self._prev_instances)):
                bbox_pairs.append(
                    {
                        "idx": i,
                        "prev_idx": j,
                        "prev_id": self._prev_instances.ID[j],
                        "IoU": iou_all[i, j],
                        "prev_period": self._prev_instances.ID_period[j],
                    }
                )
        return bbox_pairs

    def _initialize_extra_fields(self, instances: Instances) -> Instances:
        """
        If input instances don't have ID, ID_period, lost_frame_count fields,
        this method is used to initialize these fields.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with extra fields added
        """
        if not instances.has("ID"):
            instances.set("ID", [None] * len(instances))
        if not instances.has("ID_period"):
            instances.set("ID_period", [None] * len(instances))
        if not instances.has("lost_frame_count"):
            instances.set("lost_frame_count", [None] * len(instances))
        if self._prev_instances is None:
            instances.ID = list(range(len(instances)))
            self._id_count += len(instances)
            instances.ID_period = [1] * len(instances)
            instances.lost_frame_count = [0] * len(instances)
        return instances

    def _reset_fields(self):
        """
        Before each update call, reset fields first
        """
        self._matched_idx = set()
        self._matched_ID = set()
        self._untracked_prev_idx = set(range(len(self._prev_instances)))

    def _assign_new_id(self, instances: Instances) -> Instances:
        """
        For each untracked instance, assign a new id

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with new ID assigned
        """
        untracked_idx = set(range(len(instances))).difference(self._matched_idx)
        for idx in untracked_idx:
            instances.ID[idx] = self._id_count
            self._id_count += 1
            instances.ID_period[idx] = 1
            instances.lost_frame_count[idx] = 0
        return instances

    def _merge_untracked_instances(self, instances: Instances) -> Instances:
        """
        For untracked previous instances, under certain conditions, still keep them
        in tracking and merge with the current instances.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances merging current instances and instances from previous
            frame decided to keep tracking
        """
        untracked_instances = Instances(
            image_size=instances.image_size,
            pred_boxes=[],
            pred_classes=[],
            scores=[],
            ID=[],
            ID_period=[],
            lost_frame_count=[],
        )
        prev_bboxes = list(self._prev_instances.pred_boxes)
        prev_classes = list(self._prev_instances.pred_classes)
        prev_scores = list(self._prev_instances.scores)
        prev_ID_period = self._prev_instances.ID_period
        if instances.has("pred_masks"):
            untracked_instances.set("pred_masks", [])
            prev_masks = list(self._prev_instances.pred_masks)
        if instances.has("pred_keypoints"):
            untracked_instances.set("pred_keypoints", [])
            prev_keypoints = list(self._prev_instances.pred_keypoints)
        if instances.has("pred_keypoint_heatmaps"):
            untracked_instances.set("pred_keypoint_heatmaps", [])
            prev_keypoint_heatmaps = list(self._prev_instances.pred_keypoint_heatmaps)
        for idx in self._untracked_prev_idx:
            x_left, y_top, x_right, y_bot = prev_bboxes[idx]
            if (
                (1.0 * (x_right - x_left) / self._video_width < self._min_box_rel_dim)
                or (1.0 * (y_bot - y_top) / self._video_height < self._min_box_rel_dim)
                or self._prev_instances.lost_frame_count[idx] >= self._max_lost_frame_count
                or prev_ID_period[idx] <= self._min_instance_period
            ):
                continue
            untracked_instances.pred_boxes.append(list(prev_bboxes[idx].numpy()))
            untracked_instances.pred_classes.append(int(prev_classes[idx]))
            untracked_instances.scores.append(float(prev_scores[idx]))
            untracked_instances.ID.append(self._prev_instances.ID[idx])
            untracked_instances.ID_period.append(self._prev_instances.ID_period[idx])
            untracked_instances.lost_frame_count.append(
                self._prev_instances.lost_frame_count[idx] + 1
            )
            if instances.has("pred_masks"):
                untracked_instances.pred_masks.append(prev_masks[idx].numpy().astype(np.uint8))
            if instances.has("pred_keypoints"):
                untracked_instances.pred_keypoints.append(
                    prev_keypoints[idx].numpy().astype(np.uint8)
                )
            if instances.has("pred_keypoint_heatmaps"):
                untracked_instances.pred_keypoint_heatmaps.append(
                    prev_keypoint_heatmaps[idx].numpy().astype(np.float32)
                )
        untracked_instances.pred_boxes = Boxes(torch.FloatTensor(untracked_instances.pred_boxes))
        untracked_instances.pred_classes = torch.IntTensor(untracked_instances.pred_classes)
        untracked_instances.scores = torch.FloatTensor(untracked_instances.scores)
        if instances.has("pred_masks"):
            untracked_instances.pred_masks = torch.IntTensor(untracked_instances.pred_masks)
        if instances.has("pred_keypoints"):
            untracked_instances.pred_keypoints = torch.IntTensor(untracked_instances.pred_keypoints)
        if instances.has("pred_keypoint_heatmaps"):
            untracked_instances.pred_keypoint_heatmaps = torch.FloatTensor(
                untracked_instances.pred_keypoint_heatmaps
            )

        return Instances.cat(
            [
                instances,
                untracked_instances,
            ]
        )
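A hedged usage sketch (not in the diff): feeding two frames of detections through BBoxIOUTracker so that an overlapping box keeps its ID. The field names follow the update() contract documented in base_tracker.py; the box coordinates, scores, and frame size are made up.

import torch
from annotator.oneformer.detectron2.structures import Boxes, Instances
from annotator.oneformer.detectron2.tracking.bbox_iou_tracker import BBoxIOUTracker

tracker = BBoxIOUTracker(video_height=480, video_width=640)

def frame(boxes_xyxy):
    # build a minimal Instances object with the fields update() expects
    inst = Instances(image_size=(480, 640))
    inst.pred_boxes = Boxes(torch.tensor(boxes_xyxy, dtype=torch.float32))
    inst.scores = torch.tensor([0.9] * len(boxes_xyxy))
    inst.pred_classes = torch.tensor([0] * len(boxes_xyxy))
    return inst

out1 = tracker.update(frame([[10, 10, 100, 100]]))   # first frame: gets a fresh ID
out2 = tracker.update(frame([[12, 11, 102, 101]]))   # IoU > 0.5 with the previous box: keeps that ID
print(out1.ID, out2.ID)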
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/hungarian_tracker.py
ADDED
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import copy
import numpy as np
from typing import Dict
import torch
from scipy.optimize import linear_sum_assignment

from annotator.oneformer.detectron2.config import configurable
from annotator.oneformer.detectron2.structures import Boxes, Instances

from ..config.config import CfgNode as CfgNode_
from .base_tracker import BaseTracker


class BaseHungarianTracker(BaseTracker):
    """
    A base class for all Hungarian trackers
    """

    @configurable
    def __init__(
        self,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        **kwargs
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of ids allowed to be tracked
            max_lost_frame_count: maximum number of frames an id can lose tracking;
                exceed this number and the id is considered lost forever
            min_box_rel_dim: a percentage; a bbox smaller than this dimension is
                removed from tracking
            min_instance_period: an instance will be shown after this number of periods
                since its first showing up in the video
        """
        super().__init__(**kwargs)
        self._video_height = video_height
        self._video_width = video_width
        self._max_num_instances = max_num_instances
        self._max_lost_frame_count = max_lost_frame_count
        self._min_box_rel_dim = min_box_rel_dim
        self._min_instance_period = min_instance_period

    @classmethod
    def from_config(cls, cfg: CfgNode_) -> Dict:
        raise NotImplementedError("Calling HungarianTracker::from_config")

    def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray:
        raise NotImplementedError("Calling HungarianTracker::build_matrix")

    def update(self, instances: Instances) -> Instances:
        if instances.has("pred_keypoints"):
            raise NotImplementedError("Need to add support for keypoints")
        instances = self._initialize_extra_fields(instances)
        if self._prev_instances is not None:
            self._untracked_prev_idx = set(range(len(self._prev_instances)))
            cost_matrix = self.build_cost_matrix(instances, self._prev_instances)
            matched_idx, matched_prev_idx = linear_sum_assignment(cost_matrix)
            instances = self._process_matched_idx(instances, matched_idx, matched_prev_idx)
            instances = self._process_unmatched_idx(instances, matched_idx)
            instances = self._process_unmatched_prev_idx(instances, matched_prev_idx)
        self._prev_instances = copy.deepcopy(instances)
        return instances

    def _initialize_extra_fields(self, instances: Instances) -> Instances:
        """
        If input instances don't have ID, ID_period, lost_frame_count fields,
        this method is used to initialize these fields.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with extra fields added
        """
        if not instances.has("ID"):
            instances.set("ID", [None] * len(instances))
        if not instances.has("ID_period"):
            instances.set("ID_period", [None] * len(instances))
        if not instances.has("lost_frame_count"):
            instances.set("lost_frame_count", [None] * len(instances))
        if self._prev_instances is None:
            instances.ID = list(range(len(instances)))
            self._id_count += len(instances)
            instances.ID_period = [1] * len(instances)
            instances.lost_frame_count = [0] * len(instances)
        return instances

    def _process_matched_idx(
        self, instances: Instances, matched_idx: np.ndarray, matched_prev_idx: np.ndarray
    ) -> Instances:
        assert matched_idx.size == matched_prev_idx.size
        for i in range(matched_idx.size):
            instances.ID[matched_idx[i]] = self._prev_instances.ID[matched_prev_idx[i]]
            instances.ID_period[matched_idx[i]] = (
                self._prev_instances.ID_period[matched_prev_idx[i]] + 1
            )
            instances.lost_frame_count[matched_idx[i]] = 0
        return instances

    def _process_unmatched_idx(self, instances: Instances, matched_idx: np.ndarray) -> Instances:
        untracked_idx = set(range(len(instances))).difference(set(matched_idx))
        for idx in untracked_idx:
            instances.ID[idx] = self._id_count
            self._id_count += 1
            instances.ID_period[idx] = 1
            instances.lost_frame_count[idx] = 0
        return instances

    def _process_unmatched_prev_idx(
        self, instances: Instances, matched_prev_idx: np.ndarray
    ) -> Instances:
        untracked_instances = Instances(
            image_size=instances.image_size,
            pred_boxes=[],
            pred_masks=[],
            pred_classes=[],
            scores=[],
            ID=[],
            ID_period=[],
            lost_frame_count=[],
        )
        prev_bboxes = list(self._prev_instances.pred_boxes)
        prev_classes = list(self._prev_instances.pred_classes)
        prev_scores = list(self._prev_instances.scores)
        prev_ID_period = self._prev_instances.ID_period
        if instances.has("pred_masks"):
            prev_masks = list(self._prev_instances.pred_masks)
        untracked_prev_idx = set(range(len(self._prev_instances))).difference(set(matched_prev_idx))
        for idx in untracked_prev_idx:
            x_left, y_top, x_right, y_bot = prev_bboxes[idx]
            if (
                (1.0 * (x_right - x_left) / self._video_width < self._min_box_rel_dim)
                or (1.0 * (y_bot - y_top) / self._video_height < self._min_box_rel_dim)
                or self._prev_instances.lost_frame_count[idx] >= self._max_lost_frame_count
                or prev_ID_period[idx] <= self._min_instance_period
            ):
                continue
            untracked_instances.pred_boxes.append(list(prev_bboxes[idx].numpy()))
            untracked_instances.pred_classes.append(int(prev_classes[idx]))
            untracked_instances.scores.append(float(prev_scores[idx]))
            untracked_instances.ID.append(self._prev_instances.ID[idx])
            untracked_instances.ID_period.append(self._prev_instances.ID_period[idx])
            untracked_instances.lost_frame_count.append(
                self._prev_instances.lost_frame_count[idx] + 1
            )
            if instances.has("pred_masks"):
                untracked_instances.pred_masks.append(prev_masks[idx].numpy().astype(np.uint8))

        untracked_instances.pred_boxes = Boxes(torch.FloatTensor(untracked_instances.pred_boxes))
        untracked_instances.pred_classes = torch.IntTensor(untracked_instances.pred_classes)
        untracked_instances.scores = torch.FloatTensor(untracked_instances.scores)
        if instances.has("pred_masks"):
            untracked_instances.pred_masks = torch.IntTensor(untracked_instances.pred_masks)
        else:
            untracked_instances.remove("pred_masks")

        return Instances.cat(
            [
                instances,
                untracked_instances,
            ]
        )
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py
ADDED
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.

import numpy as np
from typing import List

from annotator.oneformer.detectron2.config import CfgNode as CfgNode_
from annotator.oneformer.detectron2.config import configurable

from .base_tracker import TRACKER_HEADS_REGISTRY
from .vanilla_hungarian_bbox_iou_tracker import VanillaHungarianBBoxIOUTracker


@TRACKER_HEADS_REGISTRY.register()
class IOUWeightedHungarianBBoxIOUTracker(VanillaHungarianBBoxIOUTracker):
    """
    A tracker using IoU as weight in the Hungarian algorithm, also known
    as the Munkres or Kuhn-Munkres algorithm
    """

    @configurable
    def __init__(
        self,
        *,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        track_iou_threshold: float = 0.5,
        **kwargs,
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of ids allowed to be tracked
            max_lost_frame_count: maximum number of frames an id can lose tracking;
                exceed this number and the id is considered lost forever
            min_box_rel_dim: a percentage; a bbox smaller than this dimension is
                removed from tracking
            min_instance_period: an instance will be shown after this number of periods
                since its first showing up in the video
            track_iou_threshold: iou threshold; below this number a bbox pair is removed
                from tracking
        """
        super().__init__(
            video_height=video_height,
            video_width=video_width,
            max_num_instances=max_num_instances,
            max_lost_frame_count=max_lost_frame_count,
            min_box_rel_dim=min_box_rel_dim,
            min_instance_period=min_instance_period,
            track_iou_threshold=track_iou_threshold,
        )

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Old style initialization using CfgNode

        Args:
            cfg: D2 CfgNode, config file
        Return:
            dictionary storing arguments for __init__ method
        """
        assert "VIDEO_HEIGHT" in cfg.TRACKER_HEADS
        assert "VIDEO_WIDTH" in cfg.TRACKER_HEADS
        video_height = cfg.TRACKER_HEADS.get("VIDEO_HEIGHT")
        video_width = cfg.TRACKER_HEADS.get("VIDEO_WIDTH")
        max_num_instances = cfg.TRACKER_HEADS.get("MAX_NUM_INSTANCES", 200)
        max_lost_frame_count = cfg.TRACKER_HEADS.get("MAX_LOST_FRAME_COUNT", 0)
        min_box_rel_dim = cfg.TRACKER_HEADS.get("MIN_BOX_REL_DIM", 0.02)
        min_instance_period = cfg.TRACKER_HEADS.get("MIN_INSTANCE_PERIOD", 1)
        track_iou_threshold = cfg.TRACKER_HEADS.get("TRACK_IOU_THRESHOLD", 0.5)
        return {
            "_target_": "detectron2.tracking.iou_weighted_hungarian_bbox_iou_tracker.IOUWeightedHungarianBBoxIOUTracker",  # noqa
            "video_height": video_height,
            "video_width": video_width,
            "max_num_instances": max_num_instances,
            "max_lost_frame_count": max_lost_frame_count,
            "min_box_rel_dim": min_box_rel_dim,
            "min_instance_period": min_instance_period,
            "track_iou_threshold": track_iou_threshold,
        }

    def assign_cost_matrix_values(self, cost_matrix: np.ndarray, bbox_pairs: List) -> np.ndarray:
        """
        Based on IoU for each pair of bboxes, assign the associated value in the cost matrix

        Args:
            cost_matrix: np.ndarray, initialized 2D array with target dimensions
            bbox_pairs: list of bbox pairs; each pair stores its iou value
        Return:
            np.ndarray, cost_matrix with assigned values
        """
        for pair in bbox_pairs:
            # assign (-1 * IoU) for above-threshold pairs; the algorithm will minimize cost
            cost_matrix[pair["idx"]][pair["prev_idx"]] = -1 * pair["IoU"]
        return cost_matrix
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/utils.py
ADDED
@@ -0,0 +1,40 @@
#!/usr/bin/env python3
import numpy as np
from typing import List

from annotator.oneformer.detectron2.structures import Instances


def create_prediction_pairs(
    instances: Instances,
    prev_instances: Instances,
    iou_all: np.ndarray,
    threshold: float = 0.5,
) -> List:
    """
    Args:
        instances: predictions from current frame
        prev_instances: predictions from previous frame
        iou_all: 2D numpy array containing iou for each bbox pair
        threshold: a pair of bboxes below this threshold is not considered valid
    Return:
        List of bbox pairs
    """
    bbox_pairs = []
    for i in range(len(instances)):
        for j in range(len(prev_instances)):
            if iou_all[i, j] < threshold:
                continue
            bbox_pairs.append(
                {
                    "idx": i,
                    "prev_idx": j,
                    "prev_id": prev_instances.ID[j],
                    "IoU": iou_all[i, j],
                    "prev_period": prev_instances.ID_period[j],
                }
            )
    return bbox_pairs


LARGE_COST_VALUE = 100000
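A hedged sketch (not in the diff) of how valid pairs and LARGE_COST_VALUE feed the Hungarian assignment that the trackers below perform; the IoU table and threshold are invented for illustration.

import numpy as np
from scipy.optimize import linear_sum_assignment

# Suppose 2 current boxes and 2 previous boxes with this IoU table:
iou_all = np.array([[0.8, 0.1],
                    [0.0, 0.6]])
cost = np.full(iou_all.shape, 100000)    # LARGE_COST_VALUE blocks below-threshold pairs
cost[iou_all >= 0.5] = -1                # vanilla variant: flat reward for valid pairs
rows, cols = linear_sum_assignment(cost) # rows=[0, 1], cols=[0, 1]
print(list(zip(rows, cols)))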
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py
ADDED
@@ -0,0 +1,129 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.

import numpy as np
from typing import List

from annotator.oneformer.detectron2.config import CfgNode as CfgNode_
from annotator.oneformer.detectron2.config import configurable
from annotator.oneformer.detectron2.structures import Instances
from annotator.oneformer.detectron2.structures.boxes import pairwise_iou
from annotator.oneformer.detectron2.tracking.utils import LARGE_COST_VALUE, create_prediction_pairs

from .base_tracker import TRACKER_HEADS_REGISTRY
from .hungarian_tracker import BaseHungarianTracker


@TRACKER_HEADS_REGISTRY.register()
class VanillaHungarianBBoxIOUTracker(BaseHungarianTracker):
    """
    Hungarian algorithm based tracker using bbox iou as the metric
    """

    @configurable
    def __init__(
        self,
        *,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        track_iou_threshold: float = 0.5,
        **kwargs,
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of ids allowed to be tracked
            max_lost_frame_count: maximum number of frames an id can lose tracking;
                exceed this number and the id is considered lost forever
            min_box_rel_dim: a percentage; a bbox smaller than this dimension is
                removed from tracking
            min_instance_period: an instance will be shown after this number of periods
                since its first showing up in the video
            track_iou_threshold: iou threshold; below this number a bbox pair is removed
                from tracking
        """
        super().__init__(
            video_height=video_height,
            video_width=video_width,
            max_num_instances=max_num_instances,
            max_lost_frame_count=max_lost_frame_count,
            min_box_rel_dim=min_box_rel_dim,
            min_instance_period=min_instance_period,
        )
        self._track_iou_threshold = track_iou_threshold

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Old style initialization using CfgNode

        Args:
            cfg: D2 CfgNode, config file
        Return:
            dictionary storing arguments for __init__ method
        """
        assert "VIDEO_HEIGHT" in cfg.TRACKER_HEADS
        assert "VIDEO_WIDTH" in cfg.TRACKER_HEADS
        video_height = cfg.TRACKER_HEADS.get("VIDEO_HEIGHT")
        video_width = cfg.TRACKER_HEADS.get("VIDEO_WIDTH")
        max_num_instances = cfg.TRACKER_HEADS.get("MAX_NUM_INSTANCES", 200)
        max_lost_frame_count = cfg.TRACKER_HEADS.get("MAX_LOST_FRAME_COUNT", 0)
        min_box_rel_dim = cfg.TRACKER_HEADS.get("MIN_BOX_REL_DIM", 0.02)
        min_instance_period = cfg.TRACKER_HEADS.get("MIN_INSTANCE_PERIOD", 1)
        track_iou_threshold = cfg.TRACKER_HEADS.get("TRACK_IOU_THRESHOLD", 0.5)
        return {
            "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker",  # noqa
            "video_height": video_height,
            "video_width": video_width,
            "max_num_instances": max_num_instances,
            "max_lost_frame_count": max_lost_frame_count,
            "min_box_rel_dim": min_box_rel_dim,
            "min_instance_period": min_instance_period,
            "track_iou_threshold": track_iou_threshold,
        }

    def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray:
        """
        Build the cost matrix for the assignment problem
        (https://en.wikipedia.org/wiki/Assignment_problem)

        Args:
            instances: D2 Instances, for current frame predictions
            prev_instances: D2 Instances, for previous frame predictions

        Return:
            the cost matrix in numpy array
        """
        assert instances is not None and prev_instances is not None
        # calculate IoU of all bbox pairs
        iou_all = pairwise_iou(
            boxes1=instances.pred_boxes,
            boxes2=self._prev_instances.pred_boxes,
        )
        bbox_pairs = create_prediction_pairs(
            instances, self._prev_instances, iou_all, threshold=self._track_iou_threshold
        )
        # assign a large cost value to make sure pairs below the IoU threshold won't be matched
        cost_matrix = np.full((len(instances), len(prev_instances)), LARGE_COST_VALUE)
        return self.assign_cost_matrix_values(cost_matrix, bbox_pairs)

    def assign_cost_matrix_values(self, cost_matrix: np.ndarray, bbox_pairs: List) -> np.ndarray:
        """
        Based on IoU for each pair of bboxes, assign the associated value in the cost matrix

        Args:
            cost_matrix: np.ndarray, initialized 2D array with target dimensions
            bbox_pairs: list of bbox pairs; each pair stores its iou value
        Return:
            np.ndarray, cost_matrix with assigned values
        """
        for pair in bbox_pairs:
            # assign -1 for above-threshold pairs; the algorithm will minimize cost
            cost_matrix[pair["idx"]][pair["prev_idx"]] = -1
        return cost_matrix
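A hedged sketch (not in the diff) of building the tracker above from a CfgNode, using the TRACKER_HEADS keys its from_config() reads; the values are illustrative and the CfgNode import path is assumed from the vendored config package.

from annotator.oneformer.detectron2.config import CfgNode
from annotator.oneformer.detectron2.tracking import build_tracker_head  # package import registers all trackers

cfg = CfgNode(
    {
        "TRACKER_HEADS": CfgNode(
            {
                "TRACKER_NAME": "VanillaHungarianBBoxIOUTracker",
                "VIDEO_HEIGHT": 480,
                "VIDEO_WIDTH": 640,
                "TRACK_IOU_THRESHOLD": 0.5,
            }
        )
    }
)
tracker = build_tracker_head(cfg)  # registry lookup, then from_config(cfg) fills the constructor args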
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/README.md
ADDED
@@ -0,0 +1,5 @@
# Utility functions

This folder contains utility functions that are not used in the
core library, but are useful for building models or training
code using the config system.
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates.
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/analysis.py
ADDED
@@ -0,0 +1,188 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# -*- coding: utf-8 -*-

import typing
from typing import Any, List
import fvcore
from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
from torch import nn

from annotator.oneformer.detectron2.export import TracingAdapter

__all__ = [
    "activation_count_operators",
    "flop_count_operators",
    "parameter_count_table",
    "parameter_count",
    "FlopCountAnalysis",
]

FLOPS_MODE = "flops"
ACTIVATIONS_MODE = "activations"


# Some extra ops to ignore from counting, including elementwise and reduction ops
_IGNORED_OPS = {
    "aten::add",
    "aten::add_",
    "aten::argmax",
    "aten::argsort",
    "aten::batch_norm",
    "aten::constant_pad_nd",
    "aten::div",
    "aten::div_",
    "aten::exp",
    "aten::log2",
    "aten::max_pool2d",
    "aten::meshgrid",
    "aten::mul",
    "aten::mul_",
    "aten::neg",
    "aten::nonzero_numpy",
    "aten::reciprocal",
    "aten::repeat_interleave",
    "aten::rsub",
    "aten::sigmoid",
    "aten::sigmoid_",
    "aten::softmax",
    "aten::sort",
    "aten::sqrt",
    "aten::sub",
    "torchvision::nms",  # TODO estimate flop for nms
}


class FlopCountAnalysis(fvcore.nn.FlopCountAnalysis):
    """
    Same as :class:`fvcore.nn.FlopCountAnalysis`, but supports detectron2 models.
    """

    def __init__(self, model, inputs):
        """
        Args:
            model (nn.Module):
            inputs (Any): inputs of the given model. Does not have to be tuple of tensors.
        """
        wrapper = TracingAdapter(model, inputs, allow_non_tensor=True)
        super().__init__(wrapper, wrapper.flattened_inputs)
        self.set_op_handle(**{k: None for k in _IGNORED_OPS})


def flop_count_operators(model: nn.Module, inputs: list) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level flops counting using jit.
    This is a wrapper of :func:`fvcore.nn.flop_count` and adds support for standard
    detection models in detectron2.
    Please use :class:`FlopCountAnalysis` for more advanced functionalities.

    Note:
        The function runs the input through the model to compute flops.
        The flops of a detection model are often input-dependent, for example,
        the flops of the box & mask head depend on the number of proposals &
        the number of detected objects.
        Therefore, flop counting using a single input may not accurately
        reflect the computation cost of a model. It's recommended to average
        across a number of inputs.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
            Only "image" key will be used.
        supported_ops (dict[str, Handle]): see documentation of :func:`fvcore.nn.flop_count`

    Returns:
        Counter: Gflop count per operator
    """
    old_train = model.training
    model.eval()
    ret = FlopCountAnalysis(model, inputs).by_operator()
    model.train(old_train)
    return {k: v / 1e9 for k, v in ret.items()}


def activation_count_operators(
    model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level activation counting using jit.
    This is a wrapper of fvcore.nn.activation_count that supports standard detection models
    in detectron2.

    Note:
        The function runs the input through the model to compute activations.
        The activations of a detection model are often input-dependent, for example,
        the activations of the box & mask head depend on the number of proposals &
        the number of detected objects.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
            Only "image" key will be used.

    Returns:
        Counter: activation count per operator
    """
    return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)


def _wrapper_count_operators(
    model: nn.Module, inputs: list, mode: str, **kwargs
) -> typing.DefaultDict[str, float]:
    # ignore some ops
    supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]
    inputs = [{"image": tensor_input}]  # remove other keys, in case there are any

    old_train = model.training
    if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
        model = model.module
    wrapper = TracingAdapter(model, inputs)
    wrapper.eval()
    if mode == FLOPS_MODE:
        ret = flop_count(wrapper, (tensor_input,), **kwargs)
    elif mode == ACTIVATIONS_MODE:
        ret = activation_count(wrapper, (tensor_input,), **kwargs)
    else:
        raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
    # compatible with change in fvcore
    if isinstance(ret, tuple):
        ret = ret[0]
    model.train(old_train)
    return ret


def find_unused_parameters(model: nn.Module, inputs: Any) -> List[str]:
    """
    Given a model, find parameters that do not contribute
    to the loss.

    Args:
        model: a model in training mode that returns losses
        inputs: argument or a tuple of arguments. Inputs of the model

    Returns:
        list[str]: the name of unused parameters
    """
    assert model.training
    for _, prm in model.named_parameters():
        prm.grad = None

    if isinstance(inputs, tuple):
        losses = model(*inputs)
    else:
        losses = model(inputs)

    if isinstance(losses, dict):
        losses = sum(losses.values())
    losses.backward()

    unused: List[str] = []
    for name, prm in model.named_parameters():
        if prm.grad is None:
            unused.append(name)
        prm.grad = None
    return unused
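A hedged sketch (not in the diff): the FlopCountAnalysis subclass above defers to fvcore after wrapping the model with TracingAdapter; on a plain nn.Module the fvcore parent class can be used directly, which is a reasonable way to sanity-check operator-level counts. The toy model and input shape are invented.

import torch
from torch import nn
from fvcore.nn import FlopCountAnalysis

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1),
)
flops = FlopCountAnalysis(model, torch.randn(1, 3, 64, 64))
print(flops.total())        # total count (fvcore counts one multiply-add as one flop)
print(flops.by_operator())  # e.g. Counter({'conv': ...}); unhandled ops only trigger warnings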
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/collect_env.py
ADDED
@@ -0,0 +1,246 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import importlib
import numpy as np
import os
import re
import subprocess
import sys
from collections import defaultdict
import PIL
import torch
import torchvision
from tabulate import tabulate

__all__ = ["collect_env_info"]


def collect_torch_env():
    try:
        import torch.__config__

        return torch.__config__.show()
    except ImportError:
        # compatible with older versions of pytorch
        from torch.utils.collect_env import get_pretty_env_info

        return get_pretty_env_info()


def get_env_module():
    var_name = "DETECTRON2_ENV_MODULE"
    return var_name, os.environ.get(var_name, "<not set>")


def detect_compute_compatibility(CUDA_HOME, so_file):
    try:
        cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
        if os.path.isfile(cuobjdump):
            output = subprocess.check_output(
                "'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True
            )
            output = output.decode("utf-8").strip().split("\n")
            arch = []
            for line in output:
                line = re.findall(r"\.sm_([0-9]*)\.", line)[0]
                arch.append(".".join(line))
            arch = sorted(set(arch))
            return ", ".join(arch)
        else:
            return so_file + "; cannot find cuobjdump"
    except Exception:
        # unhandled failure
        return so_file


def collect_env_info():
    has_gpu = torch.cuda.is_available()  # true for both CUDA & ROCM
    torch_version = torch.__version__

    # NOTE that CUDA_HOME/ROCM_HOME could be None even when CUDA runtime libs are functional
    from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

    has_rocm = False
    if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None):
        has_rocm = True
    has_cuda = has_gpu and (not has_rocm)

    data = []
    data.append(("sys.platform", sys.platform))  # check-template.yml depends on it
    data.append(("Python", sys.version.replace("\n", "")))
    data.append(("numpy", np.__version__))

    try:
        import annotator.oneformer.detectron2  # noqa

        data.append(
            ("detectron2", detectron2.__version__ + " @" + os.path.dirname(detectron2.__file__))
        )
    except ImportError:
        data.append(("detectron2", "failed to import"))
    except AttributeError:
        data.append(("detectron2", "imported a wrong installation"))

    try:
        import annotator.oneformer.detectron2._C as _C
    except ImportError as e:
        data.append(("detectron2._C", f"not built correctly: {e}"))

        # print system compilers when extension fails to build
        if sys.platform != "win32":  # don't know what to do for windows
            try:
                # this is how torch/utils/cpp_extensions.py choose compiler
                cxx = os.environ.get("CXX", "c++")
                cxx = subprocess.check_output("'{}' --version".format(cxx), shell=True)
                cxx = cxx.decode("utf-8").strip().split("\n")[0]
            except subprocess.SubprocessError:
                cxx = "Not found"
            data.append(("Compiler ($CXX)", cxx))

            if has_cuda and CUDA_HOME is not None:
                try:
                    nvcc = os.path.join(CUDA_HOME, "bin", "nvcc")
                    nvcc = subprocess.check_output("'{}' -V".format(nvcc), shell=True)
                    nvcc = nvcc.decode("utf-8").strip().split("\n")[-1]
                except subprocess.SubprocessError:
                    nvcc = "Not found"
                data.append(("CUDA compiler", nvcc))
        if has_cuda and sys.platform != "win32":
            try:
                so_file = importlib.util.find_spec("detectron2._C").origin
            except (ImportError, AttributeError):
                pass
            else:
                data.append(
                    ("detectron2 arch flags", detect_compute_compatibility(CUDA_HOME, so_file))
                )
    else:
        # print compilers that are used to build extension
        data.append(("Compiler", _C.get_compiler_version()))
        data.append(("CUDA compiler", _C.get_cuda_version()))  # cuda or hip
        if has_cuda and getattr(_C, "has_cuda", lambda: True)():
            data.append(
                ("detectron2 arch flags", detect_compute_compatibility(CUDA_HOME, _C.__file__))
            )

    data.append(get_env_module())
    data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__)))
    data.append(("PyTorch debug build", torch.version.debug))
    try:
        data.append(("torch._C._GLIBCXX_USE_CXX11_ABI", torch._C._GLIBCXX_USE_CXX11_ABI))
    except Exception:
        pass

    if not has_gpu:
        has_gpu_text = "No: torch.cuda.is_available() == False"
    else:
        has_gpu_text = "Yes"
    data.append(("GPU available", has_gpu_text))
    if has_gpu:
        devices = defaultdict(list)
        for k in range(torch.cuda.device_count()):
            cap = ".".join((str(x) for x in torch.cuda.get_device_capability(k)))
            name = torch.cuda.get_device_name(k) + f" (arch={cap})"
            devices[name].append(str(k))
        for name, devids in devices.items():
            data.append(("GPU " + ",".join(devids), name))

        if has_rocm:
            msg = " - invalid!" if not (ROCM_HOME and os.path.isdir(ROCM_HOME)) else ""
            data.append(("ROCM_HOME", str(ROCM_HOME) + msg))
        else:
            try:
                from torch.utils.collect_env import get_nvidia_driver_version, run as _run

                data.append(("Driver version", get_nvidia_driver_version(_run)))
            except Exception:
                pass
            msg = " - invalid!" if not (CUDA_HOME and os.path.isdir(CUDA_HOME)) else ""
            data.append(("CUDA_HOME", str(CUDA_HOME) + msg))

            cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
            if cuda_arch_list:
                data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
    data.append(("Pillow", PIL.__version__))

    try:
        data.append(
            (
                "torchvision",
                str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__),
            )
        )
        if has_cuda:
            try:
                torchvision_C = importlib.util.find_spec("torchvision._C").origin
                msg = detect_compute_compatibility(CUDA_HOME, torchvision_C)
                data.append(("torchvision arch flags", msg))
            except (ImportError, AttributeError):
                data.append(("torchvision._C", "Not found"))
    except AttributeError:
        data.append(("torchvision", "unknown"))

    try:
        import fvcore

        data.append(("fvcore", fvcore.__version__))
    except (ImportError, AttributeError):
        pass

    try:
        import iopath

        data.append(("iopath", iopath.__version__))
    except (ImportError, AttributeError):
        pass

    try:
        import cv2

        data.append(("cv2", cv2.__version__))
    except (ImportError, AttributeError):
        data.append(("cv2", "Not found"))
    env_str = tabulate(data) + "\n"
    env_str += collect_torch_env()
    return env_str


def test_nccl_ops():
    num_gpu = torch.cuda.device_count()
    if os.access("/tmp", os.W_OK):
        import torch.multiprocessing as mp

        dist_url = "file:///tmp/nccl_tmp_file"
        print("Testing NCCL connectivity ... this should not hang.")
        mp.spawn(_test_nccl_worker, nprocs=num_gpu, args=(num_gpu, dist_url), daemon=False)
        print("NCCL succeeded.")


def _test_nccl_worker(rank, num_gpu, dist_url):
    import torch.distributed as dist

    dist.init_process_group(backend="NCCL", init_method=dist_url, rank=rank, world_size=num_gpu)
    dist.barrier(device_ids=[rank])


if __name__ == "__main__":
    try:
        from annotator.oneformer.detectron2.utils.collect_env import collect_env_info as f

        print(f())
    except ImportError:
        print(collect_env_info())

    if torch.cuda.is_available():
        num_gpu = torch.cuda.device_count()
        for k in range(num_gpu):
            device = f"cuda:{k}"
            try:
                x = torch.tensor([1, 2.0], dtype=torch.float32)
                x = x.to(device)
            except Exception as e:
                print(
                    f"Unable to copy tensor to device={device}: {e}. "
                    "Your CUDA environment is broken."
                )
        if num_gpu > 1:
            test_nccl_ops()
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/colormap.py
ADDED
@@ -0,0 +1,158 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
"""
|
4 |
+
An awesome colormap for really neat visualizations.
|
5 |
+
Copied from Detectron, and removed gray colors.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import random
|
10 |
+
|
11 |
+
__all__ = ["colormap", "random_color", "random_colors"]
|
12 |
+
|
13 |
+
# fmt: off
|
14 |
+
# RGB:
|
15 |
+
_COLORS = np.array(
|
16 |
+
[
|
17 |
+
0.000, 0.447, 0.741,
|
18 |
+
0.850, 0.325, 0.098,
|
19 |
+
0.929, 0.694, 0.125,
|
20 |
+
0.494, 0.184, 0.556,
|
21 |
+
0.466, 0.674, 0.188,
|
22 |
+
0.301, 0.745, 0.933,
|
23 |
+
0.635, 0.078, 0.184,
|
24 |
+
0.300, 0.300, 0.300,
|
25 |
+
0.600, 0.600, 0.600,
|
26 |
+
1.000, 0.000, 0.000,
|
27 |
+
1.000, 0.500, 0.000,
|
28 |
+
0.749, 0.749, 0.000,
|
29 |
+
0.000, 1.000, 0.000,
|
30 |
+
0.000, 0.000, 1.000,
|
31 |
+
0.667, 0.000, 1.000,
|
32 |
+
0.333, 0.333, 0.000,
|
33 |
+
0.333, 0.667, 0.000,
|
34 |
+
0.333, 1.000, 0.000,
|
35 |
+
0.667, 0.333, 0.000,
|
36 |
+
0.667, 0.667, 0.000,
|
37 |
+
0.667, 1.000, 0.000,
|
38 |
+
1.000, 0.333, 0.000,
|
39 |
+
1.000, 0.667, 0.000,
|
40 |
+
1.000, 1.000, 0.000,
|
41 |
+
0.000, 0.333, 0.500,
|
42 |
+
0.000, 0.667, 0.500,
|
43 |
+
0.000, 1.000, 0.500,
|
44 |
+
0.333, 0.000, 0.500,
|
45 |
+
0.333, 0.333, 0.500,
|
46 |
+
0.333, 0.667, 0.500,
|
47 |
+
0.333, 1.000, 0.500,
|
48 |
+
0.667, 0.000, 0.500,
|
49 |
+
0.667, 0.333, 0.500,
|
50 |
+
0.667, 0.667, 0.500,
|
51 |
+
0.667, 1.000, 0.500,
|
52 |
+
1.000, 0.000, 0.500,
|
53 |
+
1.000, 0.333, 0.500,
|
54 |
+
1.000, 0.667, 0.500,
|
55 |
+
1.000, 1.000, 0.500,
|
56 |
+
0.000, 0.333, 1.000,
|
57 |
+
0.000, 0.667, 1.000,
|
58 |
+
0.000, 1.000, 1.000,
|
59 |
+
0.333, 0.000, 1.000,
|
60 |
+
0.333, 0.333, 1.000,
|
61 |
+
0.333, 0.667, 1.000,
|
62 |
+
0.333, 1.000, 1.000,
|
63 |
+
0.667, 0.000, 1.000,
|
64 |
+
0.667, 0.333, 1.000,
|
65 |
+
0.667, 0.667, 1.000,
|
66 |
+
0.667, 1.000, 1.000,
|
67 |
+
1.000, 0.000, 1.000,
|
68 |
+
1.000, 0.333, 1.000,
|
69 |
+
1.000, 0.667, 1.000,
|
70 |
+
0.333, 0.000, 0.000,
|
71 |
+
0.500, 0.000, 0.000,
|
72 |
+
0.667, 0.000, 0.000,
|
73 |
+
0.833, 0.000, 0.000,
|
74 |
+
1.000, 0.000, 0.000,
|
75 |
+
0.000, 0.167, 0.000,
|
76 |
+
0.000, 0.333, 0.000,
|
77 |
+
0.000, 0.500, 0.000,
|
78 |
+
0.000, 0.667, 0.000,
|
79 |
+
0.000, 0.833, 0.000,
|
80 |
+
0.000, 1.000, 0.000,
|
81 |
+
0.000, 0.000, 0.167,
|
82 |
+
0.000, 0.000, 0.333,
|
83 |
+
0.000, 0.000, 0.500,
|
84 |
+
0.000, 0.000, 0.667,
|
85 |
+
0.000, 0.000, 0.833,
|
86 |
+
0.000, 0.000, 1.000,
|
87 |
+
0.000, 0.000, 0.000,
|
88 |
+
0.143, 0.143, 0.143,
|
89 |
+
0.857, 0.857, 0.857,
|
90 |
+
1.000, 1.000, 1.000
|
91 |
+
]
|
92 |
+
).astype(np.float32).reshape(-1, 3)
|
93 |
+
# fmt: on
|
94 |
+
|
95 |
+
|
96 |
+
def colormap(rgb=False, maximum=255):
|
97 |
+
"""
|
98 |
+
Args:
|
99 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
100 |
+
maximum (int): either 255 or 1
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
|
104 |
+
"""
|
105 |
+
assert maximum in [255, 1], maximum
|
106 |
+
c = _COLORS * maximum
|
107 |
+
if not rgb:
|
108 |
+
c = c[:, ::-1]
|
109 |
+
return c
|
110 |
+
|
111 |
+
|
112 |
+
def random_color(rgb=False, maximum=255):
|
113 |
+
"""
|
114 |
+
Args:
|
115 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
116 |
+
maximum (int): either 255 or 1
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
ndarray: a vector of 3 numbers
|
120 |
+
"""
|
121 |
+
idx = np.random.randint(0, len(_COLORS))
|
122 |
+
ret = _COLORS[idx] * maximum
|
123 |
+
if not rgb:
|
124 |
+
ret = ret[::-1]
|
125 |
+
return ret
|
126 |
+
|
127 |
+
|
128 |
+
def random_colors(N, rgb=False, maximum=255):
|
129 |
+
"""
|
130 |
+
Args:
|
131 |
+
N (int): number of unique colors needed
|
132 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
133 |
+
maximum (int): either 255 or 1
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
ndarray: a list of random_color
|
137 |
+
"""
|
138 |
+
indices = random.sample(range(len(_COLORS)), N)
|
139 |
+
ret = [_COLORS[i] * maximum for i in indices]
|
140 |
+
if not rgb:
|
141 |
+
ret = [x[::-1] for x in ret]
|
142 |
+
return ret
|
143 |
+
|
144 |
+
|
145 |
+
if __name__ == "__main__":
|
146 |
+
import cv2
|
147 |
+
|
148 |
+
size = 100
|
149 |
+
H, W = 10, 10
|
150 |
+
canvas = np.random.rand(H * size, W * size, 3).astype("float32")
|
151 |
+
for h in range(H):
|
152 |
+
for w in range(W):
|
153 |
+
idx = h * W + w
|
154 |
+
if idx >= len(_COLORS):
|
155 |
+
break
|
156 |
+
canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
|
157 |
+
cv2.imshow("a", canvas)
|
158 |
+
cv2.waitKey(0)
|
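A minimal usage sketch for the palette helpers in colormap.py above. The import path assumes this extension's vendored tree, and the shapes and value ranges follow the docstrings (BGR by default, RGB when rgb=True):

from annotator.oneformer.detectron2.utils.colormap import colormap, random_color, random_colors

palette = colormap()                      # (N, 3) float32, BGR, values in [0, 255]
assert palette.shape[1] == 3

one = random_color(rgb=True, maximum=1)   # single RGB color with values in [0, 1]
five = random_colors(5)                   # list of 5 distinct BGR colors in [0, 255]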
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/comm.py
ADDED
@@ -0,0 +1,238 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
"""
|
3 |
+
This file contains primitives for multi-gpu communication.
|
4 |
+
This is useful when doing distributed training.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import functools
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
_LOCAL_PROCESS_GROUP = None
|
13 |
+
_MISSING_LOCAL_PG_ERROR = (
|
14 |
+
"Local process group is not yet created! Please use detectron2's `launch()` "
|
15 |
+
"to start processes and initialize pytorch process group. If you need to start "
|
16 |
+
"processes in other ways, please call comm.create_local_process_group("
|
17 |
+
"num_workers_per_machine) after calling torch.distributed.init_process_group()."
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
def get_world_size() -> int:
|
22 |
+
if not dist.is_available():
|
23 |
+
return 1
|
24 |
+
if not dist.is_initialized():
|
25 |
+
return 1
|
26 |
+
return dist.get_world_size()
|
27 |
+
|
28 |
+
|
29 |
+
def get_rank() -> int:
|
30 |
+
if not dist.is_available():
|
31 |
+
return 0
|
32 |
+
if not dist.is_initialized():
|
33 |
+
return 0
|
34 |
+
return dist.get_rank()
|
35 |
+
|
36 |
+
|
37 |
+
@functools.lru_cache()
|
38 |
+
def create_local_process_group(num_workers_per_machine: int) -> None:
|
39 |
+
"""
|
40 |
+
Create a process group that contains ranks within the same machine.
|
41 |
+
|
42 |
+
Detectron2's launch() in engine/launch.py will call this function. If you start
|
43 |
+
workers without launch(), you'll have to also call this. Otherwise utilities
|
44 |
+
like `get_local_rank()` will not work.
|
45 |
+
|
46 |
+
This function contains a barrier. All processes must call it together.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
num_workers_per_machine: the number of worker processes per machine. Typically
|
50 |
+
the number of GPUs.
|
51 |
+
"""
|
52 |
+
global _LOCAL_PROCESS_GROUP
|
53 |
+
assert _LOCAL_PROCESS_GROUP is None
|
54 |
+
assert get_world_size() % num_workers_per_machine == 0
|
55 |
+
num_machines = get_world_size() // num_workers_per_machine
|
56 |
+
machine_rank = get_rank() // num_workers_per_machine
|
57 |
+
for i in range(num_machines):
|
58 |
+
ranks_on_i = list(range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine))
|
59 |
+
pg = dist.new_group(ranks_on_i)
|
60 |
+
if i == machine_rank:
|
61 |
+
_LOCAL_PROCESS_GROUP = pg
|
62 |
+
|
63 |
+
|
64 |
+
def get_local_process_group():
|
65 |
+
"""
|
66 |
+
Returns:
|
67 |
+
A torch process group which only includes processes that are on the same
|
68 |
+
machine as the current process. This group can be useful for communication
|
69 |
+
within a machine, e.g. a per-machine SyncBN.
|
70 |
+
"""
|
71 |
+
assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR
|
72 |
+
return _LOCAL_PROCESS_GROUP
|
73 |
+
|
74 |
+
|
75 |
+
def get_local_rank() -> int:
|
76 |
+
"""
|
77 |
+
Returns:
|
78 |
+
The rank of the current process within the local (per-machine) process group.
|
79 |
+
"""
|
80 |
+
if not dist.is_available():
|
81 |
+
return 0
|
82 |
+
if not dist.is_initialized():
|
83 |
+
return 0
|
84 |
+
assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR
|
85 |
+
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
|
86 |
+
|
87 |
+
|
88 |
+
def get_local_size() -> int:
|
89 |
+
"""
|
90 |
+
Returns:
|
91 |
+
The size of the per-machine process group,
|
92 |
+
i.e. the number of processes per machine.
|
93 |
+
"""
|
94 |
+
if not dist.is_available():
|
95 |
+
return 1
|
96 |
+
if not dist.is_initialized():
|
97 |
+
return 1
|
98 |
+
assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR
|
99 |
+
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
|
100 |
+
|
101 |
+
|
102 |
+
def is_main_process() -> bool:
|
103 |
+
return get_rank() == 0
|
104 |
+
|
105 |
+
|
106 |
+
def synchronize():
|
107 |
+
"""
|
108 |
+
Helper function to synchronize (barrier) among all processes when
|
109 |
+
using distributed training
|
110 |
+
"""
|
111 |
+
if not dist.is_available():
|
112 |
+
return
|
113 |
+
if not dist.is_initialized():
|
114 |
+
return
|
115 |
+
world_size = dist.get_world_size()
|
116 |
+
if world_size == 1:
|
117 |
+
return
|
118 |
+
if dist.get_backend() == dist.Backend.NCCL:
|
119 |
+
# This argument is needed to avoid warnings.
|
120 |
+
# It's valid only for NCCL backend.
|
121 |
+
dist.barrier(device_ids=[torch.cuda.current_device()])
|
122 |
+
else:
|
123 |
+
dist.barrier()
|
124 |
+
|
125 |
+
|
126 |
+
@functools.lru_cache()
|
127 |
+
def _get_global_gloo_group():
|
128 |
+
"""
|
129 |
+
Return a process group based on gloo backend, containing all the ranks
|
130 |
+
The result is cached.
|
131 |
+
"""
|
132 |
+
if dist.get_backend() == "nccl":
|
133 |
+
return dist.new_group(backend="gloo")
|
134 |
+
else:
|
135 |
+
return dist.group.WORLD
|
136 |
+
|
137 |
+
|
138 |
+
def all_gather(data, group=None):
|
139 |
+
"""
|
140 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors).
|
141 |
+
|
142 |
+
Args:
|
143 |
+
data: any picklable object
|
144 |
+
group: a torch process group. By default, will use a group which
|
145 |
+
contains all ranks on gloo backend.
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
list[data]: list of data gathered from each rank
|
149 |
+
"""
|
150 |
+
if get_world_size() == 1:
|
151 |
+
return [data]
|
152 |
+
if group is None:
|
153 |
+
group = _get_global_gloo_group() # use CPU group by default, to reduce GPU RAM usage.
|
154 |
+
world_size = dist.get_world_size(group)
|
155 |
+
if world_size == 1:
|
156 |
+
return [data]
|
157 |
+
|
158 |
+
output = [None for _ in range(world_size)]
|
159 |
+
dist.all_gather_object(output, data, group=group)
|
160 |
+
return output
|
161 |
+
|
162 |
+
|
163 |
+
def gather(data, dst=0, group=None):
|
164 |
+
"""
|
165 |
+
Run gather on arbitrary picklable data (not necessarily tensors).
|
166 |
+
|
167 |
+
Args:
|
168 |
+
data: any picklable object
|
169 |
+
dst (int): destination rank
|
170 |
+
group: a torch process group. By default, will use a group which
|
171 |
+
contains all ranks on gloo backend.
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
list[data]: on dst, a list of data gathered from each rank. Otherwise,
|
175 |
+
an empty list.
|
176 |
+
"""
|
177 |
+
if get_world_size() == 1:
|
178 |
+
return [data]
|
179 |
+
if group is None:
|
180 |
+
group = _get_global_gloo_group()
|
181 |
+
world_size = dist.get_world_size(group=group)
|
182 |
+
if world_size == 1:
|
183 |
+
return [data]
|
184 |
+
rank = dist.get_rank(group=group)
|
185 |
+
|
186 |
+
if rank == dst:
|
187 |
+
output = [None for _ in range(world_size)]
|
188 |
+
dist.gather_object(data, output, dst=dst, group=group)
|
189 |
+
return output
|
190 |
+
else:
|
191 |
+
dist.gather_object(data, None, dst=dst, group=group)
|
192 |
+
return []
|
193 |
+
|
194 |
+
|
195 |
+
def shared_random_seed():
|
196 |
+
"""
|
197 |
+
Returns:
|
198 |
+
int: a random number that is the same across all workers.
|
199 |
+
If workers need a shared RNG, they can use this shared seed to
|
200 |
+
create one.
|
201 |
+
|
202 |
+
All workers must call this function, otherwise it will deadlock.
|
203 |
+
"""
|
204 |
+
ints = np.random.randint(2**31)
|
205 |
+
all_ints = all_gather(ints)
|
206 |
+
return all_ints[0]
|
207 |
+
|
208 |
+
|
209 |
+
def reduce_dict(input_dict, average=True):
|
210 |
+
"""
|
211 |
+
Reduce the values in the dictionary from all processes so that process with rank
|
212 |
+
0 has the reduced results.
|
213 |
+
|
214 |
+
Args:
|
215 |
+
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
|
216 |
+
average (bool): whether to do average or sum
|
217 |
+
|
218 |
+
Returns:
|
219 |
+
a dict with the same keys as input_dict, after reduction.
|
220 |
+
"""
|
221 |
+
world_size = get_world_size()
|
222 |
+
if world_size < 2:
|
223 |
+
return input_dict
|
224 |
+
with torch.no_grad():
|
225 |
+
names = []
|
226 |
+
values = []
|
227 |
+
# sort the keys so that they are consistent across processes
|
228 |
+
for k in sorted(input_dict.keys()):
|
229 |
+
names.append(k)
|
230 |
+
values.append(input_dict[k])
|
231 |
+
values = torch.stack(values, dim=0)
|
232 |
+
dist.reduce(values, dst=0)
|
233 |
+
if dist.get_rank() == 0 and average:
|
234 |
+
# only main process gets accumulated, so only divide by
|
235 |
+
# world_size in this case
|
236 |
+
values /= world_size
|
237 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
238 |
+
return reduced_dict
|
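A minimal sketch of how the communication helpers above are typically used in a training loop. The values below are placeholders; per its docstring, reduce_dict expects scalar CUDA tensors, and both helpers fall back to single-process behavior when torch.distributed is not initialized:

import torch
from annotator.oneformer.detectron2.utils import comm

# Gather any picklable per-rank value onto every rank (placeholder value shown).
per_rank_values = comm.all_gather(1234)

if torch.cuda.is_available() and comm.get_world_size() > 1:
    # Placeholder scalar losses; in practice these come from the model.
    loss_dict = {
        "loss_cls": torch.tensor(0.5, device="cuda"),
        "loss_box": torch.tensor(0.2, device="cuda"),
    }
    reduced = comm.reduce_dict(loss_dict)  # averaged values, meaningful on rank 0
    if comm.is_main_process():
        total_loss = sum(reduced.values())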
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/develop.py
ADDED
@@ -0,0 +1,59 @@
# Copyright (c) Facebook, Inc. and its affiliates.
"""Utilities for developers only.
These are not visible to users (not automatically imported), and should not
appear in docs."""
# adapted from https://github.com/tensorpack/tensorpack/blob/master/tensorpack/utils/develop.py


def create_dummy_class(klass, dependency, message=""):
    """
    When a dependency of a class is not available, create a dummy class which throws ImportError
    when used.

    Args:
        klass (str): name of the class.
        dependency (str): name of the dependency.
        message: extra message to print
    Returns:
        class: a class object
    """
    err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
    if message:
        err = err + " " + message

    class _DummyMetaClass(type):
        # throw error on class attribute access
        def __getattr__(_, __):  # noqa: B902
            raise ImportError(err)

    class _Dummy(object, metaclass=_DummyMetaClass):
        # throw error on constructor
        def __init__(self, *args, **kwargs):
            raise ImportError(err)

    return _Dummy


def create_dummy_func(func, dependency, message=""):
    """
    When a dependency of a function is not available, create a dummy function which throws
    ImportError when used.

    Args:
        func (str): name of the function.
        dependency (str or list[str]): name(s) of the dependency.
        message: extra message to print
    Returns:
        function: a function object
    """
    err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
    if message:
        err = err + " " + message

    if isinstance(dependency, (list, tuple)):
        dependency = ",".join(dependency)

    def _dummy(*args, **kwargs):
        raise ImportError(err)

    return _dummy
|
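A short sketch of the intended use of these dummy factories when an optional dependency may be missing; some_optional_lib and the names below are hypothetical:

from annotator.oneformer.detectron2.utils.develop import create_dummy_class, create_dummy_func

try:
    from some_optional_lib import FancyOp, fancy_fn  # hypothetical optional dependency
except ImportError:
    # Importing this module still succeeds; using FancyOp or fancy_fn later raises a
    # clear ImportError that names the missing dependency.
    FancyOp = create_dummy_class("FancyOp", "some_optional_lib")
    fancy_fn = create_dummy_func("fancy_fn", "some_optional_lib")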
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/env.py
ADDED
@@ -0,0 +1,170 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import importlib
|
3 |
+
import importlib.util
|
4 |
+
import logging
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import sys
|
9 |
+
from datetime import datetime
|
10 |
+
import torch
|
11 |
+
|
12 |
+
__all__ = ["seed_all_rng"]
|
13 |
+
|
14 |
+
|
15 |
+
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
|
16 |
+
"""
|
17 |
+
PyTorch version as a tuple of 2 ints. Useful for comparison.
|
18 |
+
"""
|
19 |
+
|
20 |
+
|
21 |
+
DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py
|
22 |
+
"""
|
23 |
+
Whether we're building documentation.
|
24 |
+
"""
|
25 |
+
|
26 |
+
|
27 |
+
def seed_all_rng(seed=None):
|
28 |
+
"""
|
29 |
+
Set the random seed for the RNG in torch, numpy and python.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
seed (int): if None, will use a strong random seed.
|
33 |
+
"""
|
34 |
+
if seed is None:
|
35 |
+
seed = (
|
36 |
+
os.getpid()
|
37 |
+
+ int(datetime.now().strftime("%S%f"))
|
38 |
+
+ int.from_bytes(os.urandom(2), "big")
|
39 |
+
)
|
40 |
+
logger = logging.getLogger(__name__)
|
41 |
+
logger.info("Using a generated random seed {}".format(seed))
|
42 |
+
np.random.seed(seed)
|
43 |
+
torch.manual_seed(seed)
|
44 |
+
random.seed(seed)
|
45 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
46 |
+
|
47 |
+
|
48 |
+
# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
|
49 |
+
def _import_file(module_name, file_path, make_importable=False):
|
50 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
51 |
+
module = importlib.util.module_from_spec(spec)
|
52 |
+
spec.loader.exec_module(module)
|
53 |
+
if make_importable:
|
54 |
+
sys.modules[module_name] = module
|
55 |
+
return module
|
56 |
+
|
57 |
+
|
58 |
+
def _configure_libraries():
|
59 |
+
"""
|
60 |
+
Configurations for some libraries.
|
61 |
+
"""
|
62 |
+
# An environment option to disable `import cv2` globally,
|
63 |
+
# in case it leads to negative performance impact
|
64 |
+
disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False))
|
65 |
+
if disable_cv2:
|
66 |
+
sys.modules["cv2"] = None
|
67 |
+
else:
|
68 |
+
# Disable opencl in opencv since its interaction with cuda often has negative effects
|
69 |
+
# This envvar is supported after OpenCV 3.4.0
|
70 |
+
os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
|
71 |
+
try:
|
72 |
+
import cv2
|
73 |
+
|
74 |
+
if int(cv2.__version__.split(".")[0]) >= 3:
|
75 |
+
cv2.ocl.setUseOpenCL(False)
|
76 |
+
except ModuleNotFoundError:
|
77 |
+
# Other types of ImportError, if happened, should not be ignored.
|
78 |
+
# Because a failed opencv import could mess up address space
|
79 |
+
# https://github.com/skvark/opencv-python/issues/381
|
80 |
+
pass
|
81 |
+
|
82 |
+
def get_version(module, digit=2):
|
83 |
+
return tuple(map(int, module.__version__.split(".")[:digit]))
|
84 |
+
|
85 |
+
# fmt: off
|
86 |
+
assert get_version(torch) >= (1, 4), "Requires torch>=1.4"
|
87 |
+
import fvcore
|
88 |
+
assert get_version(fvcore, 3) >= (0, 1, 2), "Requires fvcore>=0.1.2"
|
89 |
+
import yaml
|
90 |
+
assert get_version(yaml) >= (5, 1), "Requires pyyaml>=5.1"
|
91 |
+
# fmt: on
|
92 |
+
|
93 |
+
|
94 |
+
_ENV_SETUP_DONE = False
|
95 |
+
|
96 |
+
|
97 |
+
def setup_environment():
|
98 |
+
"""Perform environment setup work. The default setup is a no-op, but this
|
99 |
+
function allows the user to specify a Python source file or a module in
|
100 |
+
the $DETECTRON2_ENV_MODULE environment variable, that performs
|
101 |
+
custom setup work that may be necessary to their computing environment.
|
102 |
+
"""
|
103 |
+
global _ENV_SETUP_DONE
|
104 |
+
if _ENV_SETUP_DONE:
|
105 |
+
return
|
106 |
+
_ENV_SETUP_DONE = True
|
107 |
+
|
108 |
+
_configure_libraries()
|
109 |
+
|
110 |
+
custom_module_path = os.environ.get("DETECTRON2_ENV_MODULE")
|
111 |
+
|
112 |
+
if custom_module_path:
|
113 |
+
setup_custom_environment(custom_module_path)
|
114 |
+
else:
|
115 |
+
# The default setup is a no-op
|
116 |
+
pass
|
117 |
+
|
118 |
+
|
119 |
+
def setup_custom_environment(custom_module):
|
120 |
+
"""
|
121 |
+
Load custom environment setup by importing a Python source file or a
|
122 |
+
module, and run the setup function.
|
123 |
+
"""
|
124 |
+
if custom_module.endswith(".py"):
|
125 |
+
module = _import_file("detectron2.utils.env.custom_module", custom_module)
|
126 |
+
else:
|
127 |
+
module = importlib.import_module(custom_module)
|
128 |
+
assert hasattr(module, "setup_environment") and callable(module.setup_environment), (
|
129 |
+
"Custom environment module defined in {} does not have the "
|
130 |
+
"required callable attribute 'setup_environment'."
|
131 |
+
).format(custom_module)
|
132 |
+
module.setup_environment()
|
133 |
+
|
134 |
+
|
135 |
+
def fixup_module_metadata(module_name, namespace, keys=None):
|
136 |
+
"""
|
137 |
+
Fix the __qualname__ of module members to be their exported api name, so
|
138 |
+
when they are referenced in docs, sphinx can find them. Reference:
|
139 |
+
https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241
|
140 |
+
"""
|
141 |
+
if not DOC_BUILDING:
|
142 |
+
return
|
143 |
+
seen_ids = set()
|
144 |
+
|
145 |
+
def fix_one(qualname, name, obj):
|
146 |
+
# avoid infinite recursion (relevant when using
|
147 |
+
# typing.Generic, for example)
|
148 |
+
if id(obj) in seen_ids:
|
149 |
+
return
|
150 |
+
seen_ids.add(id(obj))
|
151 |
+
|
152 |
+
mod = getattr(obj, "__module__", None)
|
153 |
+
if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")):
|
154 |
+
obj.__module__ = module_name
|
155 |
+
# Modules, unlike everything else in Python, put fully-qualitied
|
156 |
+
# names into their __name__ attribute. We check for "." to avoid
|
157 |
+
# rewriting these.
|
158 |
+
if hasattr(obj, "__name__") and "." not in obj.__name__:
|
159 |
+
obj.__name__ = name
|
160 |
+
obj.__qualname__ = qualname
|
161 |
+
if isinstance(obj, type):
|
162 |
+
for attr_name, attr_value in obj.__dict__.items():
|
163 |
+
fix_one(objname + "." + attr_name, attr_name, attr_value)
|
164 |
+
|
165 |
+
if keys is None:
|
166 |
+
keys = namespace.keys()
|
167 |
+
for objname in keys:
|
168 |
+
if not objname.startswith("_"):
|
169 |
+
obj = namespace[objname]
|
170 |
+
fix_one(objname, objname, obj)
|
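A minimal sketch of using the environment helpers above; the custom setup module path is hypothetical and only needed when $DETECTRON2_ENV_MODULE should point at extra setup code:

import os
from annotator.oneformer.detectron2.utils.env import seed_all_rng, setup_environment

seed_all_rng(42)   # deterministic seed for torch, numpy, and random
seed_all_rng()     # or: a strong seed derived from pid, time, and os.urandom

# Optionally point DETECTRON2_ENV_MODULE at a file or module exposing setup_environment().
# os.environ["DETECTRON2_ENV_MODULE"] = "/path/to/my_env_setup.py"  # hypothetical path
setup_environment()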
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/events.py
ADDED
@@ -0,0 +1,534 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
from collections import defaultdict
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from typing import Optional
|
10 |
+
import torch
|
11 |
+
from fvcore.common.history_buffer import HistoryBuffer
|
12 |
+
|
13 |
+
from annotator.oneformer.detectron2.utils.file_io import PathManager
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
"get_event_storage",
|
17 |
+
"JSONWriter",
|
18 |
+
"TensorboardXWriter",
|
19 |
+
"CommonMetricPrinter",
|
20 |
+
"EventStorage",
|
21 |
+
]
|
22 |
+
|
23 |
+
_CURRENT_STORAGE_STACK = []
|
24 |
+
|
25 |
+
|
26 |
+
def get_event_storage():
|
27 |
+
"""
|
28 |
+
Returns:
|
29 |
+
The :class:`EventStorage` object that's currently being used.
|
30 |
+
Throws an error if no :class:`EventStorage` is currently enabled.
|
31 |
+
"""
|
32 |
+
assert len(
|
33 |
+
_CURRENT_STORAGE_STACK
|
34 |
+
), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
|
35 |
+
return _CURRENT_STORAGE_STACK[-1]
|
36 |
+
|
37 |
+
|
38 |
+
class EventWriter:
|
39 |
+
"""
|
40 |
+
Base class for writers that obtain events from :class:`EventStorage` and process them.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def write(self):
|
44 |
+
raise NotImplementedError
|
45 |
+
|
46 |
+
def close(self):
|
47 |
+
pass
|
48 |
+
|
49 |
+
|
50 |
+
class JSONWriter(EventWriter):
|
51 |
+
"""
|
52 |
+
Write scalars to a json file.
|
53 |
+
|
54 |
+
It saves scalars as one json per line (instead of a big json) for easy parsing.
|
55 |
+
|
56 |
+
Examples parsing such a json file:
|
57 |
+
::
|
58 |
+
$ cat metrics.json | jq -s '.[0:2]'
|
59 |
+
[
|
60 |
+
{
|
61 |
+
"data_time": 0.008433341979980469,
|
62 |
+
"iteration": 19,
|
63 |
+
"loss": 1.9228371381759644,
|
64 |
+
"loss_box_reg": 0.050025828182697296,
|
65 |
+
"loss_classifier": 0.5316952466964722,
|
66 |
+
"loss_mask": 0.7236229181289673,
|
67 |
+
"loss_rpn_box": 0.0856662318110466,
|
68 |
+
"loss_rpn_cls": 0.48198649287223816,
|
69 |
+
"lr": 0.007173333333333333,
|
70 |
+
"time": 0.25401854515075684
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"data_time": 0.007216215133666992,
|
74 |
+
"iteration": 39,
|
75 |
+
"loss": 1.282649278640747,
|
76 |
+
"loss_box_reg": 0.06222952902317047,
|
77 |
+
"loss_classifier": 0.30682939291000366,
|
78 |
+
"loss_mask": 0.6970193982124329,
|
79 |
+
"loss_rpn_box": 0.038663312792778015,
|
80 |
+
"loss_rpn_cls": 0.1471673548221588,
|
81 |
+
"lr": 0.007706666666666667,
|
82 |
+
"time": 0.2490077018737793
|
83 |
+
}
|
84 |
+
]
|
85 |
+
|
86 |
+
$ cat metrics.json | jq '.loss_mask'
|
87 |
+
0.7126231789588928
|
88 |
+
0.689423680305481
|
89 |
+
0.6776131987571716
|
90 |
+
...
|
91 |
+
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(self, json_file, window_size=20):
|
95 |
+
"""
|
96 |
+
Args:
|
97 |
+
json_file (str): path to the json file. New data will be appended if the file exists.
|
98 |
+
window_size (int): the window size of median smoothing for the scalars whose
|
99 |
+
`smoothing_hint` are True.
|
100 |
+
"""
|
101 |
+
self._file_handle = PathManager.open(json_file, "a")
|
102 |
+
self._window_size = window_size
|
103 |
+
self._last_write = -1
|
104 |
+
|
105 |
+
def write(self):
|
106 |
+
storage = get_event_storage()
|
107 |
+
to_save = defaultdict(dict)
|
108 |
+
|
109 |
+
for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
|
110 |
+
# keep scalars that have not been written
|
111 |
+
if iter <= self._last_write:
|
112 |
+
continue
|
113 |
+
to_save[iter][k] = v
|
114 |
+
if len(to_save):
|
115 |
+
all_iters = sorted(to_save.keys())
|
116 |
+
self._last_write = max(all_iters)
|
117 |
+
|
118 |
+
for itr, scalars_per_iter in to_save.items():
|
119 |
+
scalars_per_iter["iteration"] = itr
|
120 |
+
self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
|
121 |
+
self._file_handle.flush()
|
122 |
+
try:
|
123 |
+
os.fsync(self._file_handle.fileno())
|
124 |
+
except AttributeError:
|
125 |
+
pass
|
126 |
+
|
127 |
+
def close(self):
|
128 |
+
self._file_handle.close()
|
129 |
+
|
130 |
+
|
131 |
+
class TensorboardXWriter(EventWriter):
|
132 |
+
"""
|
133 |
+
Write all scalars to a tensorboard file.
|
134 |
+
"""
|
135 |
+
|
136 |
+
def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
|
137 |
+
"""
|
138 |
+
Args:
|
139 |
+
log_dir (str): the directory to save the output events
|
140 |
+
window_size (int): the scalars will be median-smoothed by this window size
|
141 |
+
|
142 |
+
kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
|
143 |
+
"""
|
144 |
+
self._window_size = window_size
|
145 |
+
from torch.utils.tensorboard import SummaryWriter
|
146 |
+
|
147 |
+
self._writer = SummaryWriter(log_dir, **kwargs)
|
148 |
+
self._last_write = -1
|
149 |
+
|
150 |
+
def write(self):
|
151 |
+
storage = get_event_storage()
|
152 |
+
new_last_write = self._last_write
|
153 |
+
for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
|
154 |
+
if iter > self._last_write:
|
155 |
+
self._writer.add_scalar(k, v, iter)
|
156 |
+
new_last_write = max(new_last_write, iter)
|
157 |
+
self._last_write = new_last_write
|
158 |
+
|
159 |
+
# storage.put_{image,histogram} is only meant to be used by
|
160 |
+
# tensorboard writer. So we access its internal fields directly from here.
|
161 |
+
if len(storage._vis_data) >= 1:
|
162 |
+
for img_name, img, step_num in storage._vis_data:
|
163 |
+
self._writer.add_image(img_name, img, step_num)
|
164 |
+
# Storage stores all image data and rely on this writer to clear them.
|
165 |
+
# As a result it assumes only one writer will use its image data.
|
166 |
+
# An alternative design is to let storage store limited recent
|
167 |
+
# data (e.g. only the most recent image) that all writers can access.
|
168 |
+
# In that case a writer may not see all image data if its period is long.
|
169 |
+
storage.clear_images()
|
170 |
+
|
171 |
+
if len(storage._histograms) >= 1:
|
172 |
+
for params in storage._histograms:
|
173 |
+
self._writer.add_histogram_raw(**params)
|
174 |
+
storage.clear_histograms()
|
175 |
+
|
176 |
+
def close(self):
|
177 |
+
if hasattr(self, "_writer"): # doesn't exist when the code fails at import
|
178 |
+
self._writer.close()
|
179 |
+
|
180 |
+
|
181 |
+
class CommonMetricPrinter(EventWriter):
|
182 |
+
"""
|
183 |
+
Print **common** metrics to the terminal, including
|
184 |
+
iteration time, ETA, memory, all losses, and the learning rate.
|
185 |
+
It also applies smoothing using a window of 20 elements.
|
186 |
+
|
187 |
+
It's meant to print common metrics in common ways.
|
188 |
+
To print something in more customized ways, please implement a similar printer by yourself.
|
189 |
+
"""
|
190 |
+
|
191 |
+
def __init__(self, max_iter: Optional[int] = None, window_size: int = 20):
|
192 |
+
"""
|
193 |
+
Args:
|
194 |
+
max_iter: the maximum number of iterations to train.
|
195 |
+
Used to compute ETA. If not given, ETA will not be printed.
|
196 |
+
window_size (int): the losses will be median-smoothed by this window size
|
197 |
+
"""
|
198 |
+
self.logger = logging.getLogger(__name__)
|
199 |
+
self._max_iter = max_iter
|
200 |
+
self._window_size = window_size
|
201 |
+
self._last_write = None # (step, time) of last call to write(). Used to compute ETA
|
202 |
+
|
203 |
+
def _get_eta(self, storage) -> Optional[str]:
|
204 |
+
if self._max_iter is None:
|
205 |
+
return ""
|
206 |
+
iteration = storage.iter
|
207 |
+
try:
|
208 |
+
eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
|
209 |
+
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
|
210 |
+
return str(datetime.timedelta(seconds=int(eta_seconds)))
|
211 |
+
except KeyError:
|
212 |
+
# estimate eta on our own - more noisy
|
213 |
+
eta_string = None
|
214 |
+
if self._last_write is not None:
|
215 |
+
estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
|
216 |
+
iteration - self._last_write[0]
|
217 |
+
)
|
218 |
+
eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
|
219 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
220 |
+
self._last_write = (iteration, time.perf_counter())
|
221 |
+
return eta_string
|
222 |
+
|
223 |
+
def write(self):
|
224 |
+
storage = get_event_storage()
|
225 |
+
iteration = storage.iter
|
226 |
+
if iteration == self._max_iter:
|
227 |
+
# This hook only reports training progress (loss, ETA, etc) but not other data,
|
228 |
+
# therefore do not write anything after training succeeds, even if this method
|
229 |
+
# is called.
|
230 |
+
return
|
231 |
+
|
232 |
+
try:
|
233 |
+
avg_data_time = storage.history("data_time").avg(
|
234 |
+
storage.count_samples("data_time", self._window_size)
|
235 |
+
)
|
236 |
+
last_data_time = storage.history("data_time").latest()
|
237 |
+
except KeyError:
|
238 |
+
# they may not exist in the first few iterations (due to warmup)
|
239 |
+
# or when SimpleTrainer is not used
|
240 |
+
avg_data_time = None
|
241 |
+
last_data_time = None
|
242 |
+
try:
|
243 |
+
avg_iter_time = storage.history("time").global_avg()
|
244 |
+
last_iter_time = storage.history("time").latest()
|
245 |
+
except KeyError:
|
246 |
+
avg_iter_time = None
|
247 |
+
last_iter_time = None
|
248 |
+
try:
|
249 |
+
lr = "{:.5g}".format(storage.history("lr").latest())
|
250 |
+
except KeyError:
|
251 |
+
lr = "N/A"
|
252 |
+
|
253 |
+
eta_string = self._get_eta(storage)
|
254 |
+
|
255 |
+
if torch.cuda.is_available():
|
256 |
+
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
|
257 |
+
else:
|
258 |
+
max_mem_mb = None
|
259 |
+
|
260 |
+
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
|
261 |
+
self.logger.info(
|
262 |
+
str.format(
|
263 |
+
" {eta}iter: {iter} {losses} {non_losses} {avg_time}{last_time}"
|
264 |
+
+ "{avg_data_time}{last_data_time} lr: {lr} {memory}",
|
265 |
+
eta=f"eta: {eta_string} " if eta_string else "",
|
266 |
+
iter=iteration,
|
267 |
+
losses=" ".join(
|
268 |
+
[
|
269 |
+
"{}: {:.4g}".format(
|
270 |
+
k, v.median(storage.count_samples(k, self._window_size))
|
271 |
+
)
|
272 |
+
for k, v in storage.histories().items()
|
273 |
+
if "loss" in k
|
274 |
+
]
|
275 |
+
),
|
276 |
+
non_losses=" ".join(
|
277 |
+
[
|
278 |
+
"{}: {:.4g}".format(
|
279 |
+
k, v.median(storage.count_samples(k, self._window_size))
|
280 |
+
)
|
281 |
+
for k, v in storage.histories().items()
|
282 |
+
if "[metric]" in k
|
283 |
+
]
|
284 |
+
),
|
285 |
+
avg_time="time: {:.4f} ".format(avg_iter_time)
|
286 |
+
if avg_iter_time is not None
|
287 |
+
else "",
|
288 |
+
last_time="last_time: {:.4f} ".format(last_iter_time)
|
289 |
+
if last_iter_time is not None
|
290 |
+
else "",
|
291 |
+
avg_data_time="data_time: {:.4f} ".format(avg_data_time)
|
292 |
+
if avg_data_time is not None
|
293 |
+
else "",
|
294 |
+
last_data_time="last_data_time: {:.4f} ".format(last_data_time)
|
295 |
+
if last_data_time is not None
|
296 |
+
else "",
|
297 |
+
lr=lr,
|
298 |
+
memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
|
299 |
+
)
|
300 |
+
)
|
301 |
+
|
302 |
+
|
303 |
+
class EventStorage:
|
304 |
+
"""
|
305 |
+
The user-facing class that provides metric storage functionalities.
|
306 |
+
|
307 |
+
In the future we may add support for storing / logging other types of data if needed.
|
308 |
+
"""
|
309 |
+
|
310 |
+
def __init__(self, start_iter=0):
|
311 |
+
"""
|
312 |
+
Args:
|
313 |
+
start_iter (int): the iteration number to start with
|
314 |
+
"""
|
315 |
+
self._history = defaultdict(HistoryBuffer)
|
316 |
+
self._smoothing_hints = {}
|
317 |
+
self._latest_scalars = {}
|
318 |
+
self._iter = start_iter
|
319 |
+
self._current_prefix = ""
|
320 |
+
self._vis_data = []
|
321 |
+
self._histograms = []
|
322 |
+
|
323 |
+
def put_image(self, img_name, img_tensor):
|
324 |
+
"""
|
325 |
+
Add an `img_tensor` associated with `img_name`, to be shown on
|
326 |
+
tensorboard.
|
327 |
+
|
328 |
+
Args:
|
329 |
+
img_name (str): The name of the image to put into tensorboard.
|
330 |
+
img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
|
331 |
+
Tensor of shape `[channel, height, width]` where `channel` is
|
332 |
+
3. The image format should be RGB. The elements in img_tensor
|
333 |
+
can either have values in [0, 1] (float32) or [0, 255] (uint8).
|
334 |
+
The `img_tensor` will be visualized in tensorboard.
|
335 |
+
"""
|
336 |
+
self._vis_data.append((img_name, img_tensor, self._iter))
|
337 |
+
|
338 |
+
def put_scalar(self, name, value, smoothing_hint=True):
|
339 |
+
"""
|
340 |
+
Add a scalar `value` to the `HistoryBuffer` associated with `name`.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
|
344 |
+
smoothed when logged. The hint will be accessible through
|
345 |
+
:meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
|
346 |
+
and apply custom smoothing rule.
|
347 |
+
|
348 |
+
It defaults to True because most scalars we save need to be smoothed to
|
349 |
+
provide any useful signal.
|
350 |
+
"""
|
351 |
+
name = self._current_prefix + name
|
352 |
+
history = self._history[name]
|
353 |
+
value = float(value)
|
354 |
+
history.update(value, self._iter)
|
355 |
+
self._latest_scalars[name] = (value, self._iter)
|
356 |
+
|
357 |
+
existing_hint = self._smoothing_hints.get(name)
|
358 |
+
if existing_hint is not None:
|
359 |
+
assert (
|
360 |
+
existing_hint == smoothing_hint
|
361 |
+
), "Scalar {} was put with a different smoothing_hint!".format(name)
|
362 |
+
else:
|
363 |
+
self._smoothing_hints[name] = smoothing_hint
|
364 |
+
|
365 |
+
def put_scalars(self, *, smoothing_hint=True, **kwargs):
|
366 |
+
"""
|
367 |
+
Put multiple scalars from keyword arguments.
|
368 |
+
|
369 |
+
Examples:
|
370 |
+
|
371 |
+
storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
|
372 |
+
"""
|
373 |
+
for k, v in kwargs.items():
|
374 |
+
self.put_scalar(k, v, smoothing_hint=smoothing_hint)
|
375 |
+
|
376 |
+
def put_histogram(self, hist_name, hist_tensor, bins=1000):
|
377 |
+
"""
|
378 |
+
Create a histogram from a tensor.
|
379 |
+
|
380 |
+
Args:
|
381 |
+
hist_name (str): The name of the histogram to put into tensorboard.
|
382 |
+
hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
|
383 |
+
into a histogram.
|
384 |
+
bins (int): Number of histogram bins.
|
385 |
+
"""
|
386 |
+
ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
|
387 |
+
|
388 |
+
# Create a histogram with PyTorch
|
389 |
+
hist_counts = torch.histc(hist_tensor, bins=bins)
|
390 |
+
hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
|
391 |
+
|
392 |
+
# Parameter for the add_histogram_raw function of SummaryWriter
|
393 |
+
hist_params = dict(
|
394 |
+
tag=hist_name,
|
395 |
+
min=ht_min,
|
396 |
+
max=ht_max,
|
397 |
+
num=len(hist_tensor),
|
398 |
+
sum=float(hist_tensor.sum()),
|
399 |
+
sum_squares=float(torch.sum(hist_tensor**2)),
|
400 |
+
bucket_limits=hist_edges[1:].tolist(),
|
401 |
+
bucket_counts=hist_counts.tolist(),
|
402 |
+
global_step=self._iter,
|
403 |
+
)
|
404 |
+
self._histograms.append(hist_params)
|
405 |
+
|
406 |
+
def history(self, name):
|
407 |
+
"""
|
408 |
+
Returns:
|
409 |
+
HistoryBuffer: the scalar history for name
|
410 |
+
"""
|
411 |
+
ret = self._history.get(name, None)
|
412 |
+
if ret is None:
|
413 |
+
raise KeyError("No history metric available for {}!".format(name))
|
414 |
+
return ret
|
415 |
+
|
416 |
+
def histories(self):
|
417 |
+
"""
|
418 |
+
Returns:
|
419 |
+
dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
|
420 |
+
"""
|
421 |
+
return self._history
|
422 |
+
|
423 |
+
def latest(self):
|
424 |
+
"""
|
425 |
+
Returns:
|
426 |
+
dict[str -> (float, int)]: mapping from the name of each scalar to the most
|
427 |
+
recent value and the iteration number its added.
|
428 |
+
"""
|
429 |
+
return self._latest_scalars
|
430 |
+
|
431 |
+
def latest_with_smoothing_hint(self, window_size=20):
|
432 |
+
"""
|
433 |
+
Similar to :meth:`latest`, but the returned values
|
434 |
+
are either the un-smoothed original latest value,
|
435 |
+
or a median of the given window_size,
|
436 |
+
depend on whether the smoothing_hint is True.
|
437 |
+
|
438 |
+
This provides a default behavior that other writers can use.
|
439 |
+
|
440 |
+
Note: All scalars saved in the past `window_size` iterations are used for smoothing.
|
441 |
+
This is different from the `window_size` definition in HistoryBuffer.
|
442 |
+
Use :meth:`get_history_window_size` to get the `window_size` used in HistoryBuffer.
|
443 |
+
"""
|
444 |
+
result = {}
|
445 |
+
for k, (v, itr) in self._latest_scalars.items():
|
446 |
+
result[k] = (
|
447 |
+
self._history[k].median(self.count_samples(k, window_size))
|
448 |
+
if self._smoothing_hints[k]
|
449 |
+
else v,
|
450 |
+
itr,
|
451 |
+
)
|
452 |
+
return result
|
453 |
+
|
454 |
+
def count_samples(self, name, window_size=20):
|
455 |
+
"""
|
456 |
+
Return the number of samples logged in the past `window_size` iterations.
|
457 |
+
"""
|
458 |
+
samples = 0
|
459 |
+
data = self._history[name].values()
|
460 |
+
for _, iter_ in reversed(data):
|
461 |
+
if iter_ > data[-1][1] - window_size:
|
462 |
+
samples += 1
|
463 |
+
else:
|
464 |
+
break
|
465 |
+
return samples
|
466 |
+
|
467 |
+
def smoothing_hints(self):
|
468 |
+
"""
|
469 |
+
Returns:
|
470 |
+
dict[name -> bool]: the user-provided hint on whether the scalar
|
471 |
+
is noisy and needs smoothing.
|
472 |
+
"""
|
473 |
+
return self._smoothing_hints
|
474 |
+
|
475 |
+
def step(self):
|
476 |
+
"""
|
477 |
+
User should either: (1) Call this function to increment storage.iter when needed. Or
|
478 |
+
(2) Set `storage.iter` to the correct iteration number before each iteration.
|
479 |
+
|
480 |
+
The storage will then be able to associate the new data with an iteration number.
|
481 |
+
"""
|
482 |
+
self._iter += 1
|
483 |
+
|
484 |
+
@property
|
485 |
+
def iter(self):
|
486 |
+
"""
|
487 |
+
Returns:
|
488 |
+
int: The current iteration number. When used together with a trainer,
|
489 |
+
this is ensured to be the same as trainer.iter.
|
490 |
+
"""
|
491 |
+
return self._iter
|
492 |
+
|
493 |
+
@iter.setter
|
494 |
+
def iter(self, val):
|
495 |
+
self._iter = int(val)
|
496 |
+
|
497 |
+
@property
|
498 |
+
def iteration(self):
|
499 |
+
# for backward compatibility
|
500 |
+
return self._iter
|
501 |
+
|
502 |
+
def __enter__(self):
|
503 |
+
_CURRENT_STORAGE_STACK.append(self)
|
504 |
+
return self
|
505 |
+
|
506 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
507 |
+
assert _CURRENT_STORAGE_STACK[-1] == self
|
508 |
+
_CURRENT_STORAGE_STACK.pop()
|
509 |
+
|
510 |
+
@contextmanager
|
511 |
+
def name_scope(self, name):
|
512 |
+
"""
|
513 |
+
Yields:
|
514 |
+
A context within which all the events added to this storage
|
515 |
+
will be prefixed by the name scope.
|
516 |
+
"""
|
517 |
+
old_prefix = self._current_prefix
|
518 |
+
self._current_prefix = name.rstrip("/") + "/"
|
519 |
+
yield
|
520 |
+
self._current_prefix = old_prefix
|
521 |
+
|
522 |
+
def clear_images(self):
|
523 |
+
"""
|
524 |
+
Delete all the stored images for visualization. This should be called
|
525 |
+
after images are written to tensorboard.
|
526 |
+
"""
|
527 |
+
self._vis_data = []
|
528 |
+
|
529 |
+
def clear_histograms(self):
|
530 |
+
"""
|
531 |
+
Delete all the stored histograms for visualization.
|
532 |
+
This should be called after histograms are written to tensorboard.
|
533 |
+
"""
|
534 |
+
self._histograms = []
|
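A minimal sketch of how EventStorage and JSONWriter above fit together; the metrics path and the placeholder loss are illustrative:

from annotator.oneformer.detectron2.utils.events import EventStorage, JSONWriter

writer = JSONWriter("./metrics.json")  # hypothetical output path
with EventStorage(start_iter=0) as storage:
    for it in range(100):
        storage.put_scalar("total_loss", 1.0 / (it + 1))  # placeholder scalar
        if (it + 1) % 20 == 0:
            writer.write()  # must run inside the EventStorage context
        storage.step()
writer.close()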
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/file_io.py
ADDED
@@ -0,0 +1,39 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler
from iopath.common.file_io import PathManager as PathManagerBase

__all__ = ["PathManager", "PathHandler"]


PathManager = PathManagerBase()
"""
This is a detectron2 project-specific PathManager.
We try to stay away from global PathManager in fvcore as it
introduces potential conflicts among other libraries.
"""


class Detectron2Handler(PathHandler):
    """
    Resolve anything that's hosted under detectron2's namespace.
    """

    PREFIX = "detectron2://"
    S3_DETECTRON2_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/"

    def _get_supported_prefixes(self):
        return [self.PREFIX]

    def _get_local_path(self, path, **kwargs):
        name = path[len(self.PREFIX) :]
        return PathManager.get_local_path(self.S3_DETECTRON2_PREFIX + name, **kwargs)

    def _open(self, path, mode="r", **kwargs):
        return PathManager.open(
            self.S3_DETECTRON2_PREFIX + path[len(self.PREFIX) :], mode, **kwargs
        )


PathManager.register_handler(HTTPURLHandler())
PathManager.register_handler(OneDrivePathHandler())
PathManager.register_handler(Detectron2Handler())
|
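A short sketch of resolving a detectron2:// URI through the PathManager configured above. The checkpoint name shown is a commonly used detectron2 model path and is an assumption here; any path under that namespace resolves the same way:

from annotator.oneformer.detectron2.utils.file_io import PathManager

# Downloads from https://dl.fbaipublicfiles.com/detectron2/... and caches it locally on first use.
local = PathManager.get_local_path("detectron2://ImageNetPretrained/MSRA/R-50.pkl")
with PathManager.open(local, "rb") as f:
    first_bytes = f.read(16)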
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/logger.py
ADDED
@@ -0,0 +1,237 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import atexit
|
3 |
+
import functools
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
from collections import Counter
|
9 |
+
import torch
|
10 |
+
from tabulate import tabulate
|
11 |
+
from termcolor import colored
|
12 |
+
|
13 |
+
from annotator.oneformer.detectron2.utils.file_io import PathManager
|
14 |
+
|
15 |
+
__all__ = ["setup_logger", "log_first_n", "log_every_n", "log_every_n_seconds"]
|
16 |
+
|
17 |
+
|
18 |
+
class _ColorfulFormatter(logging.Formatter):
|
19 |
+
def __init__(self, *args, **kwargs):
|
20 |
+
self._root_name = kwargs.pop("root_name") + "."
|
21 |
+
self._abbrev_name = kwargs.pop("abbrev_name", "")
|
22 |
+
if len(self._abbrev_name):
|
23 |
+
self._abbrev_name = self._abbrev_name + "."
|
24 |
+
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
|
25 |
+
|
26 |
+
def formatMessage(self, record):
|
27 |
+
record.name = record.name.replace(self._root_name, self._abbrev_name)
|
28 |
+
log = super(_ColorfulFormatter, self).formatMessage(record)
|
29 |
+
if record.levelno == logging.WARNING:
|
30 |
+
prefix = colored("WARNING", "red", attrs=["blink"])
|
31 |
+
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
|
32 |
+
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
|
33 |
+
else:
|
34 |
+
return log
|
35 |
+
return prefix + " " + log
|
36 |
+
|
37 |
+
|
38 |
+
@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers
|
39 |
+
def setup_logger(
|
40 |
+
output=None, distributed_rank=0, *, color=True, name="detectron2", abbrev_name=None
|
41 |
+
):
|
42 |
+
"""
|
43 |
+
Initialize the detectron2 logger and set its verbosity level to "DEBUG".
|
44 |
+
|
45 |
+
Args:
|
46 |
+
output (str): a file name or a directory to save log. If None, will not save log file.
|
47 |
+
If ends with ".txt" or ".log", assumed to be a file name.
|
48 |
+
Otherwise, logs will be saved to `output/log.txt`.
|
49 |
+
name (str): the root module name of this logger
|
50 |
+
abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
|
51 |
+
Set to "" to not log the root module in logs.
|
52 |
+
By default, will abbreviate "detectron2" to "d2" and leave other
|
53 |
+
modules unchanged.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
logging.Logger: a logger
|
57 |
+
"""
|
58 |
+
logger = logging.getLogger(name)
|
59 |
+
logger.setLevel(logging.DEBUG)
|
60 |
+
logger.propagate = False
|
61 |
+
|
62 |
+
if abbrev_name is None:
|
63 |
+
abbrev_name = "d2" if name == "detectron2" else name
|
64 |
+
|
65 |
+
plain_formatter = logging.Formatter(
|
66 |
+
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
|
67 |
+
)
|
68 |
+
# stdout logging: master only
|
69 |
+
if distributed_rank == 0:
|
70 |
+
ch = logging.StreamHandler(stream=sys.stdout)
|
71 |
+
ch.setLevel(logging.DEBUG)
|
72 |
+
if color:
|
73 |
+
formatter = _ColorfulFormatter(
|
74 |
+
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
|
75 |
+
datefmt="%m/%d %H:%M:%S",
|
76 |
+
root_name=name,
|
77 |
+
abbrev_name=str(abbrev_name),
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
formatter = plain_formatter
|
81 |
+
ch.setFormatter(formatter)
|
82 |
+
logger.addHandler(ch)
|
83 |
+
|
84 |
+
# file logging: all workers
|
85 |
+
if output is not None:
|
86 |
+
if output.endswith(".txt") or output.endswith(".log"):
|
87 |
+
filename = output
|
88 |
+
else:
|
89 |
+
filename = os.path.join(output, "log.txt")
|
90 |
+
if distributed_rank > 0:
|
91 |
+
filename = filename + ".rank{}".format(distributed_rank)
|
92 |
+
PathManager.mkdirs(os.path.dirname(filename))
|
93 |
+
|
94 |
+
fh = logging.StreamHandler(_cached_log_stream(filename))
|
95 |
+
fh.setLevel(logging.DEBUG)
|
96 |
+
fh.setFormatter(plain_formatter)
|
97 |
+
logger.addHandler(fh)
|
98 |
+
|
99 |
+
return logger
|
100 |
+
|
101 |
+
|
102 |
+
# cache the opened file object, so that different calls to `setup_logger`
|
103 |
+
# with the same file name can safely write to the same file.
|
104 |
+
@functools.lru_cache(maxsize=None)
|
105 |
+
def _cached_log_stream(filename):
|
106 |
+
# use 1K buffer if writing to cloud storage
|
107 |
+
io = PathManager.open(filename, "a", buffering=1024 if "://" in filename else -1)
|
108 |
+
    atexit.register(io.close)
    return io


"""
Below are some other convenient logging methods.
They are mainly adopted from
https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
"""


def _find_caller():
    """
    Returns:
        str: module name of the caller
        tuple: a hashable key to be used to identify different callers
    """
    frame = sys._getframe(2)
    while frame:
        code = frame.f_code
        if os.path.join("utils", "logger.") not in code.co_filename:
            mod_name = frame.f_globals["__name__"]
            if mod_name == "__main__":
                mod_name = "detectron2"
            return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
        frame = frame.f_back


_LOG_COUNTER = Counter()
_LOG_TIMER = {}


def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
    """
    Log only for the first n times.

    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
        key (str or tuple[str]): the string(s) can be one of "caller" or
            "message", which defines how to identify duplicated logs.
            For example, if called with `n=1, key="caller"`, this function
            will only log the first call from the same caller, regardless of
            the message content.
            If called with `n=1, key="message"`, this function will log the
            same content only once, even if they are called from different places.
            If called with `n=1, key=("caller", "message")`, this function
            will not log only if the same caller has logged the same message before.
    """
    if isinstance(key, str):
        key = (key,)
    assert len(key) > 0

    caller_module, caller_key = _find_caller()
    hash_key = ()
    if "caller" in key:
        hash_key = hash_key + caller_key
    if "message" in key:
        hash_key = hash_key + (msg,)

    _LOG_COUNTER[hash_key] += 1
    if _LOG_COUNTER[hash_key] <= n:
        logging.getLogger(name or caller_module).log(lvl, msg)


def log_every_n(lvl, msg, n=1, *, name=None):
    """
    Log once per n times.

    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    caller_module, key = _find_caller()
    _LOG_COUNTER[key] += 1
    if n == 1 or _LOG_COUNTER[key] % n == 1:
        logging.getLogger(name or caller_module).log(lvl, msg)


def log_every_n_seconds(lvl, msg, n=1, *, name=None):
    """
    Log no more than once per n seconds.

    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    caller_module, key = _find_caller()
    last_logged = _LOG_TIMER.get(key, None)
    current_time = time.time()
    if last_logged is None or current_time - last_logged >= n:
        logging.getLogger(name or caller_module).log(lvl, msg)
        _LOG_TIMER[key] = current_time


def create_small_table(small_dict):
    """
    Create a small table using the keys of small_dict as headers. This is only
    suitable for small dictionaries.

    Args:
        small_dict (dict): a result dictionary of only a few items.

    Returns:
        str: the table as a string.
    """
    keys, values = tuple(zip(*small_dict.items()))
    table = tabulate(
        [values],
        headers=keys,
        tablefmt="pipe",
        floatfmt=".3f",
        stralign="center",
        numalign="center",
    )
    return table


def _log_api_usage(identifier: str):
    """
    Internal function used to log the usage of different detectron2 components
    inside facebook's infra.
    """
    torch._C._log_api_usage_once("detectron2." + identifier)
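For context, a minimal usage sketch of the rate-limited logging helpers above (illustrative only, not part of the committed file; it assumes the detectron2 logger has already been configured via `setup_logger`):

import logging
from annotator.oneformer.detectron2.utils.logger import log_first_n, log_every_n_seconds

# Warn about a deprecated config key only once, keyed by message content.
log_first_n(logging.WARNING, "MODEL.FOO is deprecated, use MODEL.BAR", n=1, key="message")

# Report progress at most once every 5 seconds inside a tight loop.
for i in range(10000):
    log_every_n_seconds(logging.INFO, f"processed {i} items", n=5)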
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/memory.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import logging
from contextlib import contextmanager
from functools import wraps
import torch

__all__ = ["retry_if_cuda_oom"]


@contextmanager
def _ignore_torch_cuda_oom():
    """
    A context which ignores CUDA OOM exception from pytorch.
    """
    try:
        yield
    except RuntimeError as e:
        # NOTE: the string may change?
        if "CUDA out of memory. " in str(e):
            pass
        else:
            raise


def retry_if_cuda_oom(func):
    """
    Makes a function retry itself after encountering
    pytorch's CUDA OOM error.
    It will first retry after calling `torch.cuda.empty_cache()`.

    If that still fails, it will then retry by trying to convert inputs to CPUs.
    In this case, it expects the function to dispatch to CPU implementation.
    The return values may become CPU tensors as well and it's user's
    responsibility to convert it back to CUDA tensor if needed.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, it will only look at each argument and check
           if it has `.device` and `.to` for conversion. Nested structures of tensors
           are not supported.

        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def maybe_to_cpu(x):
        try:
            like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to")
        except AttributeError:
            like_gpu_tensor = False
        if like_gpu_tensor:
            return x.to(device="cpu")
        else:
            return x

    @wraps(func)
    def wrapped(*args, **kwargs):
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Clear cache and retry
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Try on CPU. This slows down the code significantly, therefore print a notice.
        logger = logging.getLogger(__name__)
        logger.info("Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func)))
        new_args = (maybe_to_cpu(x) for x in args)
        new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
        return func(*new_args, **new_kwargs)

    return wrapped
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/registry.py
ADDED
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.

from typing import Any
import pydoc
from fvcore.common.registry import Registry  # for backward compatibility.

"""
``Registry`` and `locate` provide ways to map a string (typically found
in config files) to callable objects.
"""

__all__ = ["Registry", "locate"]


def _convert_target_to_string(t: Any) -> str:
    """
    Inverse of ``locate()``.

    Args:
        t: any object with ``__module__`` and ``__qualname__``
    """
    module, qualname = t.__module__, t.__qualname__

    # Compress the path to this object, e.g. ``module.submodule._impl.class``
    # may become ``module.submodule.class``, if the latter also resolves to the same
    # object. This simplifies the string, and also is less affected by moving the
    # class implementation.
    module_parts = module.split(".")
    for k in range(1, len(module_parts)):
        prefix = ".".join(module_parts[:k])
        candidate = f"{prefix}.{qualname}"
        try:
            if locate(candidate) is t:
                return candidate
        except ImportError:
            pass
    return f"{module}.{qualname}"


def locate(name: str) -> Any:
    """
    Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
    such as "module.submodule.class_name".

    Raise Exception if it cannot be found.
    """
    obj = pydoc.locate(name)

    # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly
    # by pydoc.locate. Try a private function from hydra.
    if obj is None:
        try:
            # from hydra.utils import get_method - will print many errors
            from hydra.utils import _locate
        except ImportError as e:
            raise ImportError(f"Cannot dynamically locate object {name}!") from e
        else:
            obj = _locate(name)  # it raises if fails

    return obj
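A small round-trip sketch for `locate` and `_convert_target_to_string` (illustrative only, not from the commit; a stdlib class is used so pydoc can resolve it without the hydra fallback):

from collections import OrderedDict
from annotator.oneformer.detectron2.utils.registry import locate, _convert_target_to_string

# Resolve a dotted path back to the object it names.
assert locate("collections.OrderedDict") is OrderedDict
# The inverse compresses the dotted path when a shorter form resolves to the same object.
print(_convert_target_to_string(OrderedDict))  # -> "collections.OrderedDict"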
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/serialize.py
ADDED
@@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# import cloudpickle


class PicklableWrapper(object):
    """
    Wrap an object to make it more picklable, note that it uses
    heavy weight serialization libraries that are slower than pickle.
    It's best to use it only on closures (which are usually not picklable).

    This is a simplified version of
    https://github.com/joblib/joblib/blob/master/joblib/externals/loky/cloudpickle_wrapper.py
    """

    def __init__(self, obj):
        while isinstance(obj, PicklableWrapper):
            # Wrapping an object twice is no-op
            obj = obj._obj
        self._obj = obj

    # def __reduce__(self):
    #     s = cloudpickle.dumps(self._obj)
    #     return cloudpickle.loads, (s,)

    def __call__(self, *args, **kwargs):
        return self._obj(*args, **kwargs)

    def __getattr__(self, attr):
        # Ensure that the wrapped object can be used seamlessly as the previous object.
        if attr not in ["_obj"]:
            return getattr(self._obj, attr)
        return getattr(self, attr)
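A short sketch of the intended use of `PicklableWrapper`, wrapping a closure (illustrative only; note that in this build the cloudpickle-based `__reduce__` is commented out, so the wrapper mainly forwards calls and attributes):

from annotator.oneformer.detectron2.utils.serialize import PicklableWrapper

scale = 2.0
fn = PicklableWrapper(lambda x: x * scale)  # closures are normally not picklable
assert fn(3) == 6.0
# Double-wrapping is a no-op: the inner object is unwrapped first.
assert PicklableWrapper(fn)._obj is fn._obj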
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/testing.py
ADDED
@@ -0,0 +1,478 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import io
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import tempfile
|
7 |
+
import unittest
|
8 |
+
from typing import Callable
|
9 |
+
import torch
|
10 |
+
import torch.onnx.symbolic_helper as sym_help
|
11 |
+
from packaging import version
|
12 |
+
from torch._C import ListType
|
13 |
+
from torch.onnx import register_custom_op_symbolic
|
14 |
+
|
15 |
+
from annotator.oneformer.detectron2 import model_zoo
|
16 |
+
from annotator.oneformer.detectron2.config import CfgNode, LazyConfig, instantiate
|
17 |
+
from annotator.oneformer.detectron2.data import DatasetCatalog
|
18 |
+
from annotator.oneformer.detectron2.data.detection_utils import read_image
|
19 |
+
from annotator.oneformer.detectron2.modeling import build_model
|
20 |
+
from annotator.oneformer.detectron2.structures import Boxes, Instances, ROIMasks
|
21 |
+
from annotator.oneformer.detectron2.utils.file_io import PathManager
|
22 |
+
|
23 |
+
|
24 |
+
"""
|
25 |
+
Internal utilities for tests. Don't use except for writing tests.
|
26 |
+
"""
|
27 |
+
|
28 |
+
|
29 |
+
def get_model_no_weights(config_path):
|
30 |
+
"""
|
31 |
+
Like model_zoo.get, but do not load any weights (even pretrained)
|
32 |
+
"""
|
33 |
+
cfg = model_zoo.get_config(config_path)
|
34 |
+
if isinstance(cfg, CfgNode):
|
35 |
+
if not torch.cuda.is_available():
|
36 |
+
cfg.MODEL.DEVICE = "cpu"
|
37 |
+
return build_model(cfg)
|
38 |
+
else:
|
39 |
+
return instantiate(cfg.model)
|
40 |
+
|
41 |
+
|
42 |
+
def random_boxes(num_boxes, max_coord=100, device="cpu"):
|
43 |
+
"""
|
44 |
+
Create a random Nx4 boxes tensor, with coordinates < max_coord.
|
45 |
+
"""
|
46 |
+
boxes = torch.rand(num_boxes, 4, device=device) * (max_coord * 0.5)
|
47 |
+
boxes.clamp_(min=1.0) # tiny boxes cause numerical instability in box regression
|
48 |
+
# Note: the implementation of this function in torchvision is:
|
49 |
+
# boxes[:, 2:] += torch.rand(N, 2) * 100
|
50 |
+
# but it does not guarantee non-negative widths/heights constraints:
|
51 |
+
# boxes[:, 2] >= boxes[:, 0] and boxes[:, 3] >= boxes[:, 1]:
|
52 |
+
boxes[:, 2:] += boxes[:, :2]
|
53 |
+
return boxes
|
54 |
+
|
55 |
+
|
56 |
+
def get_sample_coco_image(tensor=True):
|
57 |
+
"""
|
58 |
+
Args:
|
59 |
+
tensor (bool): if True, returns 3xHxW tensor.
|
60 |
+
else, returns a HxWx3 numpy array.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
an image, in BGR color.
|
64 |
+
"""
|
65 |
+
try:
|
66 |
+
file_name = DatasetCatalog.get("coco_2017_val_100")[0]["file_name"]
|
67 |
+
if not PathManager.exists(file_name):
|
68 |
+
raise FileNotFoundError()
|
69 |
+
except IOError:
|
70 |
+
# for public CI to run
|
71 |
+
file_name = PathManager.get_local_path(
|
72 |
+
"http://images.cocodataset.org/train2017/000000000009.jpg"
|
73 |
+
)
|
74 |
+
ret = read_image(file_name, format="BGR")
|
75 |
+
if tensor:
|
76 |
+
ret = torch.from_numpy(np.ascontiguousarray(ret.transpose(2, 0, 1)))
|
77 |
+
return ret
|
78 |
+
|
79 |
+
|
80 |
+
def convert_scripted_instances(instances):
|
81 |
+
"""
|
82 |
+
Convert a scripted Instances object to a regular :class:`Instances` object
|
83 |
+
"""
|
84 |
+
assert hasattr(
|
85 |
+
instances, "image_size"
|
86 |
+
), f"Expect an Instances object, but got {type(instances)}!"
|
87 |
+
ret = Instances(instances.image_size)
|
88 |
+
for name in instances._field_names:
|
89 |
+
val = getattr(instances, "_" + name, None)
|
90 |
+
if val is not None:
|
91 |
+
ret.set(name, val)
|
92 |
+
return ret
|
93 |
+
|
94 |
+
|
95 |
+
def assert_instances_allclose(input, other, *, rtol=1e-5, msg="", size_as_tensor=False):
|
96 |
+
"""
|
97 |
+
Args:
|
98 |
+
input, other (Instances):
|
99 |
+
size_as_tensor: compare image_size of the Instances as tensors (instead of tuples).
|
100 |
+
Useful for comparing outputs of tracing.
|
101 |
+
"""
|
102 |
+
if not isinstance(input, Instances):
|
103 |
+
input = convert_scripted_instances(input)
|
104 |
+
if not isinstance(other, Instances):
|
105 |
+
other = convert_scripted_instances(other)
|
106 |
+
|
107 |
+
if not msg:
|
108 |
+
msg = "Two Instances are different! "
|
109 |
+
else:
|
110 |
+
msg = msg.rstrip() + " "
|
111 |
+
|
112 |
+
size_error_msg = msg + f"image_size is {input.image_size} vs. {other.image_size}!"
|
113 |
+
if size_as_tensor:
|
114 |
+
assert torch.equal(
|
115 |
+
torch.tensor(input.image_size), torch.tensor(other.image_size)
|
116 |
+
), size_error_msg
|
117 |
+
else:
|
118 |
+
assert input.image_size == other.image_size, size_error_msg
|
119 |
+
fields = sorted(input.get_fields().keys())
|
120 |
+
fields_other = sorted(other.get_fields().keys())
|
121 |
+
assert fields == fields_other, msg + f"Fields are {fields} vs {fields_other}!"
|
122 |
+
|
123 |
+
for f in fields:
|
124 |
+
val1, val2 = input.get(f), other.get(f)
|
125 |
+
if isinstance(val1, (Boxes, ROIMasks)):
|
126 |
+
# boxes in the range of O(100) and can have a larger tolerance
|
127 |
+
assert torch.allclose(val1.tensor, val2.tensor, atol=100 * rtol), (
|
128 |
+
msg + f"Field {f} differs too much!"
|
129 |
+
)
|
130 |
+
elif isinstance(val1, torch.Tensor):
|
131 |
+
if val1.dtype.is_floating_point:
|
132 |
+
mag = torch.abs(val1).max().cpu().item()
|
133 |
+
assert torch.allclose(val1, val2, atol=mag * rtol), (
|
134 |
+
msg + f"Field {f} differs too much!"
|
135 |
+
)
|
136 |
+
else:
|
137 |
+
assert torch.equal(val1, val2), msg + f"Field {f} is different!"
|
138 |
+
else:
|
139 |
+
raise ValueError(f"Don't know how to compare type {type(val1)}")
|
140 |
+
|
141 |
+
|
142 |
+
def reload_script_model(module):
|
143 |
+
"""
|
144 |
+
Save a jit module and load it back.
|
145 |
+
Similar to the `getExportImportCopy` function in torch/testing/
|
146 |
+
"""
|
147 |
+
buffer = io.BytesIO()
|
148 |
+
torch.jit.save(module, buffer)
|
149 |
+
buffer.seek(0)
|
150 |
+
return torch.jit.load(buffer)
|
151 |
+
|
152 |
+
|
153 |
+
def reload_lazy_config(cfg):
|
154 |
+
"""
|
155 |
+
Save an object by LazyConfig.save and load it back.
|
156 |
+
This is used to test that a config still works the same after
|
157 |
+
serialization/deserialization.
|
158 |
+
"""
|
159 |
+
with tempfile.TemporaryDirectory(prefix="detectron2") as d:
|
160 |
+
fname = os.path.join(d, "d2_cfg_test.yaml")
|
161 |
+
LazyConfig.save(cfg, fname)
|
162 |
+
return LazyConfig.load(fname)
|
163 |
+
|
164 |
+
|
165 |
+
def min_torch_version(min_version: str) -> bool:
|
166 |
+
"""
|
167 |
+
Returns True when torch's version is at least `min_version`.
|
168 |
+
"""
|
169 |
+
try:
|
170 |
+
import torch
|
171 |
+
except ImportError:
|
172 |
+
return False
|
173 |
+
|
174 |
+
installed_version = version.parse(torch.__version__.split("+")[0])
|
175 |
+
min_version = version.parse(min_version)
|
176 |
+
return installed_version >= min_version
|
177 |
+
|
178 |
+
|
179 |
+
def has_dynamic_axes(onnx_model):
|
180 |
+
"""
|
181 |
+
Return True when all ONNX input/output have only dynamic axes for all ranks
|
182 |
+
"""
|
183 |
+
return all(
|
184 |
+
not dim.dim_param.isnumeric()
|
185 |
+
for inp in onnx_model.graph.input
|
186 |
+
for dim in inp.type.tensor_type.shape.dim
|
187 |
+
) and all(
|
188 |
+
not dim.dim_param.isnumeric()
|
189 |
+
for out in onnx_model.graph.output
|
190 |
+
for dim in out.type.tensor_type.shape.dim
|
191 |
+
)
|
192 |
+
|
193 |
+
|
194 |
+
def register_custom_op_onnx_export(
|
195 |
+
opname: str, symbolic_fn: Callable, opset_version: int, min_version: str
|
196 |
+
) -> None:
|
197 |
+
"""
|
198 |
+
Register `symbolic_fn` as PyTorch's symbolic `opname`-`opset_version` for ONNX export.
|
199 |
+
The registration is performed only when current PyTorch's version is < `min_version.`
|
200 |
+
IMPORTANT: symbolic must be manually unregistered after the caller function returns
|
201 |
+
"""
|
202 |
+
if min_torch_version(min_version):
|
203 |
+
return
|
204 |
+
register_custom_op_symbolic(opname, symbolic_fn, opset_version)
|
205 |
+
print(f"_register_custom_op_onnx_export({opname}, {opset_version}) succeeded.")
|
206 |
+
|
207 |
+
|
208 |
+
def unregister_custom_op_onnx_export(opname: str, opset_version: int, min_version: str) -> None:
|
209 |
+
"""
|
210 |
+
Unregister PyTorch's symbolic `opname`-`opset_version` for ONNX export.
|
211 |
+
The un-registration is performed only when PyTorch's version is < `min_version`
|
212 |
+
IMPORTANT: The symbolic must have been manually registered by the caller, otherwise
|
213 |
+
the incorrect symbolic may be unregistered instead.
|
214 |
+
"""
|
215 |
+
|
216 |
+
# TODO: _unregister_custom_op_symbolic is introduced PyTorch>=1.10
|
217 |
+
# Remove after PyTorch 1.10+ is used by ALL detectron2's CI
|
218 |
+
try:
|
219 |
+
from torch.onnx import unregister_custom_op_symbolic as _unregister_custom_op_symbolic
|
220 |
+
except ImportError:
|
221 |
+
|
222 |
+
def _unregister_custom_op_symbolic(symbolic_name, opset_version):
|
223 |
+
import torch.onnx.symbolic_registry as sym_registry
|
224 |
+
from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets
|
225 |
+
|
226 |
+
def _get_ns_op_name_from_custom_op(symbolic_name):
|
227 |
+
try:
|
228 |
+
from torch.onnx.utils import get_ns_op_name_from_custom_op
|
229 |
+
|
230 |
+
ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
|
231 |
+
except ImportError as import_error:
|
232 |
+
if not bool(
|
233 |
+
re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)
|
234 |
+
):
|
235 |
+
raise ValueError(
|
236 |
+
f"Invalid symbolic name {symbolic_name}. Must be `domain::name`"
|
237 |
+
) from import_error
|
238 |
+
|
239 |
+
ns, op_name = symbolic_name.split("::")
|
240 |
+
if ns == "onnx":
|
241 |
+
raise ValueError(f"{ns} domain cannot be modified.") from import_error
|
242 |
+
|
243 |
+
if ns == "aten":
|
244 |
+
ns = ""
|
245 |
+
|
246 |
+
return ns, op_name
|
247 |
+
|
248 |
+
def _unregister_op(opname: str, domain: str, version: int):
|
249 |
+
try:
|
250 |
+
sym_registry.unregister_op(op_name, ns, ver)
|
251 |
+
except AttributeError as attribute_error:
|
252 |
+
if sym_registry.is_registered_op(opname, domain, version):
|
253 |
+
del sym_registry._registry[(domain, version)][opname]
|
254 |
+
if not sym_registry._registry[(domain, version)]:
|
255 |
+
del sym_registry._registry[(domain, version)]
|
256 |
+
else:
|
257 |
+
raise RuntimeError(
|
258 |
+
f"The opname {opname} is not registered."
|
259 |
+
) from attribute_error
|
260 |
+
|
261 |
+
ns, op_name = _get_ns_op_name_from_custom_op(symbolic_name)
|
262 |
+
for ver in _onnx_stable_opsets + [_onnx_main_opset]:
|
263 |
+
if ver >= opset_version:
|
264 |
+
_unregister_op(op_name, ns, ver)
|
265 |
+
|
266 |
+
if min_torch_version(min_version):
|
267 |
+
return
|
268 |
+
_unregister_custom_op_symbolic(opname, opset_version)
|
269 |
+
print(f"_unregister_custom_op_onnx_export({opname}, {opset_version}) succeeded.")
|
270 |
+
|
271 |
+
|
272 |
+
skipIfOnCPUCI = unittest.skipIf(
|
273 |
+
os.environ.get("CI") and not torch.cuda.is_available(),
|
274 |
+
"The test is too slow on CPUs and will be executed on CircleCI's GPU jobs.",
|
275 |
+
)
|
276 |
+
|
277 |
+
|
278 |
+
def skipIfUnsupportedMinOpsetVersion(min_opset_version, current_opset_version=None):
|
279 |
+
"""
|
280 |
+
Skips tests for ONNX Opset versions older than min_opset_version.
|
281 |
+
"""
|
282 |
+
|
283 |
+
def skip_dec(func):
|
284 |
+
def wrapper(self):
|
285 |
+
try:
|
286 |
+
opset_version = self.opset_version
|
287 |
+
except AttributeError:
|
288 |
+
opset_version = current_opset_version
|
289 |
+
if opset_version < min_opset_version:
|
290 |
+
raise unittest.SkipTest(
|
291 |
+
f"Unsupported opset_version {opset_version}"
|
292 |
+
f", required is {min_opset_version}"
|
293 |
+
)
|
294 |
+
return func(self)
|
295 |
+
|
296 |
+
return wrapper
|
297 |
+
|
298 |
+
return skip_dec
|
299 |
+
|
300 |
+
|
301 |
+
def skipIfUnsupportedMinTorchVersion(min_version):
|
302 |
+
"""
|
303 |
+
Skips tests for PyTorch versions older than min_version.
|
304 |
+
"""
|
305 |
+
reason = f"module 'torch' has __version__ {torch.__version__}" f", required is: {min_version}"
|
306 |
+
return unittest.skipIf(not min_torch_version(min_version), reason)
|
307 |
+
|
308 |
+
|
309 |
+
# TODO: Remove after PyTorch 1.11.1+ is used by detectron2's CI
|
310 |
+
def _pytorch1111_symbolic_opset9_to(g, self, *args):
|
311 |
+
"""aten::to() symbolic that must be used for testing with PyTorch < 1.11.1."""
|
312 |
+
|
313 |
+
def is_aten_to_device_only(args):
|
314 |
+
if len(args) == 4:
|
315 |
+
# aten::to(Tensor, Device, bool, bool, memory_format)
|
316 |
+
return (
|
317 |
+
args[0].node().kind() == "prim::device"
|
318 |
+
or args[0].type().isSubtypeOf(ListType.ofInts())
|
319 |
+
or (
|
320 |
+
sym_help._is_value(args[0])
|
321 |
+
and args[0].node().kind() == "onnx::Constant"
|
322 |
+
and isinstance(args[0].node()["value"], str)
|
323 |
+
)
|
324 |
+
)
|
325 |
+
elif len(args) == 5:
|
326 |
+
# aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
|
327 |
+
# When dtype is None, this is a aten::to(device) call
|
328 |
+
dtype = sym_help._get_const(args[1], "i", "dtype")
|
329 |
+
return dtype is None
|
330 |
+
elif len(args) in (6, 7):
|
331 |
+
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format)
|
332 |
+
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format)
|
333 |
+
# When dtype is None, this is a aten::to(device) call
|
334 |
+
dtype = sym_help._get_const(args[0], "i", "dtype")
|
335 |
+
return dtype is None
|
336 |
+
return False
|
337 |
+
|
338 |
+
# ONNX doesn't have a concept of a device, so we ignore device-only casts
|
339 |
+
if is_aten_to_device_only(args):
|
340 |
+
return self
|
341 |
+
|
342 |
+
if len(args) == 4:
|
343 |
+
# TestONNXRuntime::test_ones_bool shows args[0] of aten::to can be onnx::Constant[Tensor]
|
344 |
+
# In this case, the constant value is a tensor not int,
|
345 |
+
# so sym_help._maybe_get_const(args[0], 'i') would not work.
|
346 |
+
dtype = args[0]
|
347 |
+
if sym_help._is_value(args[0]) and args[0].node().kind() == "onnx::Constant":
|
348 |
+
tval = args[0].node()["value"]
|
349 |
+
if isinstance(tval, torch.Tensor):
|
350 |
+
if len(tval.shape) == 0:
|
351 |
+
tval = tval.item()
|
352 |
+
dtype = int(tval)
|
353 |
+
else:
|
354 |
+
dtype = tval
|
355 |
+
|
356 |
+
if sym_help._is_value(dtype) or isinstance(dtype, torch.Tensor):
|
357 |
+
# aten::to(Tensor, Tensor, bool, bool, memory_format)
|
358 |
+
dtype = args[0].type().scalarType()
|
359 |
+
return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[dtype])
|
360 |
+
else:
|
361 |
+
# aten::to(Tensor, ScalarType, bool, bool, memory_format)
|
362 |
+
# memory_format is ignored
|
363 |
+
return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
|
364 |
+
elif len(args) == 5:
|
365 |
+
# aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
|
366 |
+
dtype = sym_help._get_const(args[1], "i", "dtype")
|
367 |
+
# memory_format is ignored
|
368 |
+
return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
|
369 |
+
elif len(args) == 6:
|
370 |
+
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format)
|
371 |
+
dtype = sym_help._get_const(args[0], "i", "dtype")
|
372 |
+
# Layout, device and memory_format are ignored
|
373 |
+
return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
|
374 |
+
elif len(args) == 7:
|
375 |
+
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format)
|
376 |
+
dtype = sym_help._get_const(args[0], "i", "dtype")
|
377 |
+
# Layout, device and memory_format are ignored
|
378 |
+
return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
|
379 |
+
else:
|
380 |
+
return sym_help._onnx_unsupported("Unknown aten::to signature")
|
381 |
+
|
382 |
+
|
383 |
+
# TODO: Remove after PyTorch 1.11.1+ is used by detectron2's CI
|
384 |
+
def _pytorch1111_symbolic_opset9_repeat_interleave(g, self, repeats, dim=None, output_size=None):
|
385 |
+
|
386 |
+
# from torch.onnx.symbolic_helper import ScalarType
|
387 |
+
from torch.onnx.symbolic_opset9 import expand, unsqueeze
|
388 |
+
|
389 |
+
input = self
|
390 |
+
# if dim is None flatten
|
391 |
+
# By default, use the flattened input array, and return a flat output array
|
392 |
+
if sym_help._is_none(dim):
|
393 |
+
input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
|
394 |
+
dim = 0
|
395 |
+
else:
|
396 |
+
dim = sym_help._maybe_get_scalar(dim)
|
397 |
+
|
398 |
+
repeats_dim = sym_help._get_tensor_rank(repeats)
|
399 |
+
repeats_sizes = sym_help._get_tensor_sizes(repeats)
|
400 |
+
input_sizes = sym_help._get_tensor_sizes(input)
|
401 |
+
if repeats_dim is None:
|
402 |
+
raise RuntimeError(
|
403 |
+
"Unsupported: ONNX export of repeat_interleave for unknown " "repeats rank."
|
404 |
+
)
|
405 |
+
if repeats_sizes is None:
|
406 |
+
raise RuntimeError(
|
407 |
+
"Unsupported: ONNX export of repeat_interleave for unknown " "repeats size."
|
408 |
+
)
|
409 |
+
if input_sizes is None:
|
410 |
+
raise RuntimeError(
|
411 |
+
"Unsupported: ONNX export of repeat_interleave for unknown " "input size."
|
412 |
+
)
|
413 |
+
|
414 |
+
input_sizes_temp = input_sizes.copy()
|
415 |
+
for idx, input_size in enumerate(input_sizes):
|
416 |
+
if input_size is None:
|
417 |
+
input_sizes[idx], input_sizes_temp[idx] = 0, -1
|
418 |
+
|
419 |
+
# Cases where repeats is an int or single value tensor
|
420 |
+
if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
|
421 |
+
if not sym_help._is_tensor(repeats):
|
422 |
+
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
|
423 |
+
if input_sizes[dim] == 0:
|
424 |
+
return sym_help._onnx_opset_unsupported_detailed(
|
425 |
+
"repeat_interleave",
|
426 |
+
9,
|
427 |
+
13,
|
428 |
+
"Unsupported along dimension with unknown input size",
|
429 |
+
)
|
430 |
+
else:
|
431 |
+
reps = input_sizes[dim]
|
432 |
+
repeats = expand(g, repeats, g.op("Constant", value_t=torch.tensor([reps])), None)
|
433 |
+
|
434 |
+
# Cases where repeats is a 1 dim Tensor
|
435 |
+
elif repeats_dim == 1:
|
436 |
+
if input_sizes[dim] == 0:
|
437 |
+
return sym_help._onnx_opset_unsupported_detailed(
|
438 |
+
"repeat_interleave",
|
439 |
+
9,
|
440 |
+
13,
|
441 |
+
"Unsupported along dimension with unknown input size",
|
442 |
+
)
|
443 |
+
if repeats_sizes[0] is None:
|
444 |
+
return sym_help._onnx_opset_unsupported_detailed(
|
445 |
+
"repeat_interleave", 9, 13, "Unsupported for cases with dynamic repeats"
|
446 |
+
)
|
447 |
+
assert (
|
448 |
+
repeats_sizes[0] == input_sizes[dim]
|
449 |
+
), "repeats must have the same size as input along dim"
|
450 |
+
reps = repeats_sizes[0]
|
451 |
+
else:
|
452 |
+
raise RuntimeError("repeats must be 0-dim or 1-dim tensor")
|
453 |
+
|
454 |
+
final_splits = list()
|
455 |
+
r_splits = sym_help._repeat_interleave_split_helper(g, repeats, reps, 0)
|
456 |
+
if isinstance(r_splits, torch._C.Value):
|
457 |
+
r_splits = [r_splits]
|
458 |
+
i_splits = sym_help._repeat_interleave_split_helper(g, input, reps, dim)
|
459 |
+
if isinstance(i_splits, torch._C.Value):
|
460 |
+
i_splits = [i_splits]
|
461 |
+
input_sizes[dim], input_sizes_temp[dim] = -1, 1
|
462 |
+
for idx, r_split in enumerate(r_splits):
|
463 |
+
i_split = unsqueeze(g, i_splits[idx], dim + 1)
|
464 |
+
r_concat = [
|
465 |
+
g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])),
|
466 |
+
r_split,
|
467 |
+
g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])),
|
468 |
+
]
|
469 |
+
r_concat = g.op("Concat", *r_concat, axis_i=0)
|
470 |
+
i_split = expand(g, i_split, r_concat, None)
|
471 |
+
i_split = sym_help._reshape_helper(
|
472 |
+
g,
|
473 |
+
i_split,
|
474 |
+
g.op("Constant", value_t=torch.LongTensor(input_sizes)),
|
475 |
+
allowzero=0,
|
476 |
+
)
|
477 |
+
final_splits.append(i_split)
|
478 |
+
return g.op("Concat", *final_splits, axis_i=dim)
|
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/tracing.py
ADDED
@@ -0,0 +1,71 @@
import inspect
import torch

from annotator.oneformer.detectron2.utils.env import TORCH_VERSION

try:
    from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current

    tracing_current_exists = True
except ImportError:
    tracing_current_exists = False

try:
    from torch.fx._symbolic_trace import _orig_module_call

    tracing_legacy_exists = True
except ImportError:
    tracing_legacy_exists = False


@torch.jit.ignore
def is_fx_tracing_legacy() -> bool:
    """
    Returns a bool indicating whether torch.fx is currently symbolically tracing a module.
    Can be useful for gating module logic that is incompatible with symbolic tracing.
    """
    return torch.nn.Module.__call__ is not _orig_module_call


@torch.jit.ignore
def is_fx_tracing() -> bool:
    """Returns whether execution is currently in
    Torch FX tracing mode"""
    if TORCH_VERSION >= (1, 10) and tracing_current_exists:
        return is_fx_tracing_current()
    elif tracing_legacy_exists:
        return is_fx_tracing_legacy()
    else:
        # Can't find either current or legacy tracing indication code.
        # Enabling this assert_fx_safe() call regardless of tracing status.
        return False


@torch.jit.ignore
def assert_fx_safe(condition: bool, message: str) -> torch.Tensor:
    """An FX-tracing safe version of assert.
    Avoids erroneous type assertion triggering when types are masked inside
    an fx.proxy.Proxy object during tracing.
    Args: condition - either a boolean expression or a string representing
    the condition to test. If this assert triggers an exception when tracing
    due to dynamic control flow, try encasing the expression in quotation
    marks and supplying it as a string."""
    # Must return a concrete tensor for compatibility with PyTorch <=1.8.
    # If <=1.8 compatibility is not needed, return type can be converted to None
    if not is_fx_tracing():
        try:
            if isinstance(condition, str):
                caller_frame = inspect.currentframe().f_back
                torch._assert(
                    eval(condition, caller_frame.f_globals, caller_frame.f_locals), message
                )
                return torch.ones(1)
            else:
                torch._assert(condition, message)
                return torch.ones(1)
        except torch.fx.proxy.TraceError as e:
            print(
                "Found a non-FX compatible assertion. Skipping the check. Failure is shown below"
                + str(e)
            )
            return torch.zeros(1)
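A minimal sketch of `assert_fx_safe` guarding a shape check that would otherwise break torch.fx symbolic tracing (illustrative only, not part of the committed file):

import torch
from annotator.oneformer.detectron2.utils.tracing import assert_fx_safe, is_fx_tracing

def forward(x: torch.Tensor) -> torch.Tensor:
    # Passing the condition as a string defers evaluation to the caller's frame,
    # and the check is skipped entirely while FX tracing is active.
    assert_fx_safe("x.shape[1] == 4", "expected 4 channels")
    return x * 2

out = forward(torch.zeros(2, 4))
print(is_fx_tracing())  # False outside of torch.fx symbolic tracing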
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/video_visualizer.py
ADDED
@@ -0,0 +1,287 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import List
import annotator.oneformer.pycocotools.mask as mask_util

from annotator.oneformer.detectron2.structures import Instances
from annotator.oneformer.detectron2.utils.visualizer import (
    ColorMode,
    Visualizer,
    _create_text_labels,
    _PanopticPrediction,
)

from .colormap import random_color, random_colors


class _DetectedInstance:
    """
    Used to store data about detected objects in video frame,
    in order to transfer color to objects in the future frames.

    Attributes:
        label (int):
        bbox (tuple[float]):
        mask_rle (dict):
        color (tuple[float]): RGB colors in range (0, 1)
        ttl (int): time-to-live for the instance. For example, if ttl=2,
            the instance color can be transferred to objects in the next two frames.
    """

    __slots__ = ["label", "bbox", "mask_rle", "color", "ttl"]

    def __init__(self, label, bbox, mask_rle, color, ttl):
        self.label = label
        self.bbox = bbox
        self.mask_rle = mask_rle
        self.color = color
        self.ttl = ttl


class VideoVisualizer:
    def __init__(self, metadata, instance_mode=ColorMode.IMAGE):
        """
        Args:
            metadata (MetadataCatalog): image metadata.
        """
        self.metadata = metadata
        self._old_instances = []
        assert instance_mode in [
            ColorMode.IMAGE,
            ColorMode.IMAGE_BW,
        ], "Other mode not supported yet."
        self._instance_mode = instance_mode
        self._max_num_instances = self.metadata.get("max_num_instances", 74)
        self._assigned_colors = {}
        self._color_pool = random_colors(self._max_num_instances, rgb=True, maximum=1)
        self._color_idx_set = set(range(len(self._color_pool)))

    def draw_instance_predictions(self, frame, predictions):
        """
        Draw instance-level prediction results on an image.

        Args:
            frame (ndarray): an RGB image of shape (H, W, C), in the range [0, 255].
            predictions (Instances): the output of an instance detection/segmentation
                model. Following fields will be used to draw:
                "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").

        Returns:
            output (VisImage): image object with visualizations.
        """
        frame_visualizer = Visualizer(frame, self.metadata)
        num_instances = len(predictions)
        if num_instances == 0:
            return frame_visualizer.output

        boxes = predictions.pred_boxes.tensor.numpy() if predictions.has("pred_boxes") else None
        scores = predictions.scores if predictions.has("scores") else None
        classes = predictions.pred_classes.numpy() if predictions.has("pred_classes") else None
        keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
        colors = predictions.COLOR if predictions.has("COLOR") else [None] * len(predictions)
        periods = predictions.ID_period if predictions.has("ID_period") else None
        period_threshold = self.metadata.get("period_threshold", 0)
        visibilities = (
            [True] * len(predictions)
            if periods is None
            else [x > period_threshold for x in periods]
        )

        if predictions.has("pred_masks"):
            masks = predictions.pred_masks
            # mask IOU is not yet enabled
            # masks_rles = mask_util.encode(np.asarray(masks.permute(1, 2, 0), order="F"))
            # assert len(masks_rles) == num_instances
        else:
            masks = None

        if not predictions.has("COLOR"):
            if predictions.has("ID"):
                colors = self._assign_colors_by_id(predictions)
            else:
                # ToDo: clean old assign color method and use a default tracker to assign id
                detected = [
                    _DetectedInstance(classes[i], boxes[i], mask_rle=None, color=colors[i], ttl=8)
                    for i in range(num_instances)
                ]
                colors = self._assign_colors(detected)

        labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))

        if self._instance_mode == ColorMode.IMAGE_BW:
            # any() returns uint8 tensor
            frame_visualizer.output.reset_image(
                frame_visualizer._create_grayscale_image(
                    (masks.any(dim=0) > 0).numpy() if masks is not None else None
                )
            )
            alpha = 0.3
        else:
            alpha = 0.5

        labels = (
            None
            if labels is None
            else [y[0] for y in filter(lambda x: x[1], zip(labels, visibilities))]
        )  # noqa
        assigned_colors = (
            None
            if colors is None
            else [y[0] for y in filter(lambda x: x[1], zip(colors, visibilities))]
        )  # noqa
        frame_visualizer.overlay_instances(
            boxes=None if masks is not None else boxes[visibilities],  # boxes are a bit distracting
            masks=None if masks is None else masks[visibilities],
            labels=labels,
            keypoints=None if keypoints is None else keypoints[visibilities],
            assigned_colors=assigned_colors,
            alpha=alpha,
        )

        return frame_visualizer.output

    def draw_sem_seg(self, frame, sem_seg, area_threshold=None):
        """
        Args:
            sem_seg (ndarray or Tensor): semantic segmentation of shape (H, W),
                each value is the integer label.
            area_threshold (Optional[int]): only draw segmentations larger than the threshold
        """
        # don't need to do anything special
        frame_visualizer = Visualizer(frame, self.metadata)
        frame_visualizer.draw_sem_seg(sem_seg, area_threshold=None)
        return frame_visualizer.output

    def draw_panoptic_seg_predictions(
        self, frame, panoptic_seg, segments_info, area_threshold=None, alpha=0.5
    ):
        frame_visualizer = Visualizer(frame, self.metadata)
        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)

        if self._instance_mode == ColorMode.IMAGE_BW:
            frame_visualizer.output.reset_image(
                frame_visualizer._create_grayscale_image(pred.non_empty_mask())
            )

        # draw mask for all semantic segments first i.e. "stuff"
        for mask, sinfo in pred.semantic_masks():
            category_idx = sinfo["category_id"]
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
            except AttributeError:
                mask_color = None

            frame_visualizer.draw_binary_mask(
                mask,
                color=mask_color,
                text=self.metadata.stuff_classes[category_idx],
                alpha=alpha,
                area_threshold=area_threshold,
            )

        all_instances = list(pred.instance_masks())
        if len(all_instances) == 0:
            return frame_visualizer.output
        # draw mask for all instances second
        masks, sinfo = list(zip(*all_instances))
        num_instances = len(masks)
        masks_rles = mask_util.encode(
            np.asarray(np.asarray(masks).transpose(1, 2, 0), dtype=np.uint8, order="F")
        )
        assert len(masks_rles) == num_instances

        category_ids = [x["category_id"] for x in sinfo]
        detected = [
            _DetectedInstance(category_ids[i], bbox=None, mask_rle=masks_rles[i], color=None, ttl=8)
            for i in range(num_instances)
        ]
        colors = self._assign_colors(detected)
        labels = [self.metadata.thing_classes[k] for k in category_ids]

        frame_visualizer.overlay_instances(
            boxes=None,
            masks=masks,
            labels=labels,
            keypoints=None,
            assigned_colors=colors,
            alpha=alpha,
        )
        return frame_visualizer.output

    def _assign_colors(self, instances):
        """
        Naive tracking heuristics to assign same color to the same instance,
        will update the internal state of tracked instances.

        Returns:
            list[tuple[float]]: list of colors.
        """

        # Compute iou with either boxes or masks:
        is_crowd = np.zeros((len(instances),), dtype=bool)
        if instances[0].bbox is None:
            assert instances[0].mask_rle is not None
            # use mask iou only when box iou is None
            # because box seems good enough
            rles_old = [x.mask_rle for x in self._old_instances]
            rles_new = [x.mask_rle for x in instances]
            ious = mask_util.iou(rles_old, rles_new, is_crowd)
            threshold = 0.5
        else:
            boxes_old = [x.bbox for x in self._old_instances]
            boxes_new = [x.bbox for x in instances]
            ious = mask_util.iou(boxes_old, boxes_new, is_crowd)
            threshold = 0.6
        if len(ious) == 0:
            ious = np.zeros((len(self._old_instances), len(instances)), dtype="float32")

        # Only allow matching instances of the same label:
        for old_idx, old in enumerate(self._old_instances):
            for new_idx, new in enumerate(instances):
                if old.label != new.label:
                    ious[old_idx, new_idx] = 0

        matched_new_per_old = np.asarray(ious).argmax(axis=1)
        max_iou_per_old = np.asarray(ious).max(axis=1)

        # Try to find match for each old instance:
        extra_instances = []
        for idx, inst in enumerate(self._old_instances):
            if max_iou_per_old[idx] > threshold:
                newidx = matched_new_per_old[idx]
                if instances[newidx].color is None:
                    instances[newidx].color = inst.color
                continue
            # If an old instance does not match any new instances,
            # keep it for the next frame in case it is just missed by the detector
            inst.ttl -= 1
            if inst.ttl > 0:
                extra_instances.append(inst)

        # Assign random color to newly-detected instances:
        for inst in instances:
            if inst.color is None:
                inst.color = random_color(rgb=True, maximum=1)
        self._old_instances = instances[:] + extra_instances
        return [d.color for d in instances]

    def _assign_colors_by_id(self, instances: Instances) -> List:
        colors = []
        untracked_ids = set(self._assigned_colors.keys())
        for id in instances.ID:
            if id in self._assigned_colors:
                colors.append(self._color_pool[self._assigned_colors[id]])
                untracked_ids.remove(id)
            else:
                assert (
                    len(self._color_idx_set) >= 1
                ), f"Number of id exceeded maximum, \
                    max = {self._max_num_instances}"
                idx = self._color_idx_set.pop()
                color = self._color_pool[idx]
                self._assigned_colors[id] = idx
                colors.append(color)
        for id in untracked_ids:
            self._color_idx_set.add(self._assigned_colors[id])
            del self._assigned_colors[id]
        return colors
extensions/microsoftexcel-controlnet/annotator/oneformer/detectron2/utils/visualizer.py
ADDED
@@ -0,0 +1,1267 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import colorsys
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
from enum import Enum, unique
|
7 |
+
import cv2
|
8 |
+
import matplotlib as mpl
|
9 |
+
import matplotlib.colors as mplc
|
10 |
+
import matplotlib.figure as mplfigure
|
11 |
+
import annotator.oneformer.pycocotools.mask as mask_util
|
12 |
+
import torch
|
13 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from annotator.oneformer.detectron2.data import MetadataCatalog
|
17 |
+
from annotator.oneformer.detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
|
18 |
+
from annotator.oneformer.detectron2.utils.file_io import PathManager
|
19 |
+
|
20 |
+
from .colormap import random_color
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
__all__ = ["ColorMode", "VisImage", "Visualizer"]
|
25 |
+
|
26 |
+
|
27 |
+
_SMALL_OBJECT_AREA_THRESH = 1000
|
28 |
+
_LARGE_MASK_AREA_THRESH = 120000
|
29 |
+
_OFF_WHITE = (1.0, 1.0, 240.0 / 255)
|
30 |
+
_BLACK = (0, 0, 0)
|
31 |
+
_RED = (1.0, 0, 0)
|
32 |
+
|
33 |
+
_KEYPOINT_THRESHOLD = 0.05
|
34 |
+
|
35 |
+
|
36 |
+
@unique
|
37 |
+
class ColorMode(Enum):
|
38 |
+
"""
|
39 |
+
Enum of different color modes to use for instance visualizations.
|
40 |
+
"""
|
41 |
+
|
42 |
+
IMAGE = 0
|
43 |
+
"""
|
44 |
+
Picks a random color for every instance and overlay segmentations with low opacity.
|
45 |
+
"""
|
46 |
+
SEGMENTATION = 1
|
47 |
+
"""
|
48 |
+
Let instances of the same category have similar colors
|
49 |
+
(from metadata.thing_colors), and overlay them with
|
50 |
+
high opacity. This provides more attention on the quality of segmentation.
|
51 |
+
"""
|
52 |
+
IMAGE_BW = 2
|
53 |
+
"""
|
54 |
+
Same as IMAGE, but convert all areas without masks to gray-scale.
|
55 |
+
Only available for drawing per-instance mask predictions.
|
56 |
+
"""
|
57 |
+
|
58 |
+
|
59 |
+
class GenericMask:
|
60 |
+
"""
|
61 |
+
Attribute:
|
62 |
+
polygons (list[ndarray]): list[ndarray]: polygons for this mask.
|
63 |
+
Each ndarray has format [x, y, x, y, ...]
|
64 |
+
mask (ndarray): a binary mask
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, mask_or_polygons, height, width):
|
68 |
+
self._mask = self._polygons = self._has_holes = None
|
69 |
+
self.height = height
|
70 |
+
self.width = width
|
71 |
+
|
72 |
+
m = mask_or_polygons
|
73 |
+
if isinstance(m, dict):
|
74 |
+
# RLEs
|
75 |
+
assert "counts" in m and "size" in m
|
76 |
+
if isinstance(m["counts"], list): # uncompressed RLEs
|
77 |
+
h, w = m["size"]
|
78 |
+
assert h == height and w == width
|
79 |
+
m = mask_util.frPyObjects(m, h, w)
|
80 |
+
self._mask = mask_util.decode(m)[:, :]
|
81 |
+
return
|
82 |
+
|
83 |
+
if isinstance(m, list): # list[ndarray]
|
84 |
+
self._polygons = [np.asarray(x).reshape(-1) for x in m]
|
85 |
+
return
|
86 |
+
|
87 |
+
if isinstance(m, np.ndarray): # assumed to be a binary mask
|
88 |
+
assert m.shape[1] != 2, m.shape
|
89 |
+
assert m.shape == (
|
90 |
+
height,
|
91 |
+
width,
|
92 |
+
), f"mask shape: {m.shape}, target dims: {height}, {width}"
|
93 |
+
self._mask = m.astype("uint8")
|
94 |
+
return
|
95 |
+
|
96 |
+
raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m)))
|
97 |
+
|
98 |
+
@property
|
99 |
+
def mask(self):
|
100 |
+
if self._mask is None:
|
101 |
+
self._mask = self.polygons_to_mask(self._polygons)
|
102 |
+
return self._mask
|
103 |
+
|
104 |
+
@property
|
105 |
+
def polygons(self):
|
106 |
+
if self._polygons is None:
|
107 |
+
self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
|
108 |
+
return self._polygons
|
109 |
+
|
110 |
+
@property
|
111 |
+
def has_holes(self):
|
112 |
+
if self._has_holes is None:
|
113 |
+
if self._mask is not None:
|
114 |
+
self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
|
115 |
+
else:
|
116 |
+
self._has_holes = False # if original format is polygon, does not have holes
|
117 |
+
return self._has_holes
|
118 |
+
|
119 |
+
def mask_to_polygons(self, mask):
|
120 |
+
# cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
|
121 |
+
# hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
|
122 |
+
# Internal contours (holes) are placed in hierarchy-2.
|
123 |
+
# cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
|
124 |
+
mask = np.ascontiguousarray(mask) # some versions of cv2 does not support incontiguous arr
|
125 |
+
res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
|
126 |
+
hierarchy = res[-1]
|
127 |
+
if hierarchy is None: # empty mask
|
128 |
+
return [], False
|
129 |
+
has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
|
130 |
+
res = res[-2]
|
131 |
+
res = [x.flatten() for x in res]
|
132 |
+
# These coordinates from OpenCV are integers in range [0, W-1 or H-1].
|
133 |
+
# We add 0.5 to turn them into real-value coordinate space. A better solution
|
134 |
+
# would be to first +0.5 and then dilate the returned polygon by 0.5.
|
135 |
+
res = [x + 0.5 for x in res if len(x) >= 6]
|
136 |
+
return res, has_holes
|
137 |
+
|
138 |
+
def polygons_to_mask(self, polygons):
|
139 |
+
rle = mask_util.frPyObjects(polygons, self.height, self.width)
|
140 |
+
rle = mask_util.merge(rle)
|
141 |
+
return mask_util.decode(rle)[:, :]
|
142 |
+
|
143 |
+
def area(self):
|
144 |
+
return self.mask.sum()
|
145 |
+
|
146 |
+
def bbox(self):
|
147 |
+
p = mask_util.frPyObjects(self.polygons, self.height, self.width)
|
148 |
+
p = mask_util.merge(p)
|
149 |
+
bbox = mask_util.toBbox(p)
|
150 |
+
bbox[2] += bbox[0]
|
151 |
+
bbox[3] += bbox[1]
|
152 |
+
return bbox
|
153 |
+
|
154 |
+
|
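For orientation, a minimal sketch of two of the input formats GenericMask accepts; the import path and the coordinates below are illustrative assumptions, not part of the diff:

import numpy as np
from annotator.oneformer.detectron2.utils.visualizer import GenericMask  # assumed module path

# polygon format: a list of flat [x, y, x, y, ...] arrays (one per polygon)
gm_poly = GenericMask([[10, 10, 60, 10, 60, 40, 10, 40]], height=100, width=100)
# binary-mask format: a (H, W) ndarray of 0/1 values
gm_bin = GenericMask(np.zeros((100, 100), dtype=np.uint8), height=100, width=100)
print(gm_poly.area(), gm_poly.bbox(), gm_bin.has_holes)  # conversions between formats happen lazily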
155 |
+
class _PanopticPrediction:
|
156 |
+
"""
|
157 |
+
Unify different panoptic annotation/prediction formats
|
158 |
+
"""
|
159 |
+
|
160 |
+
def __init__(self, panoptic_seg, segments_info, metadata=None):
|
161 |
+
if segments_info is None:
|
162 |
+
assert metadata is not None
|
163 |
+
# If "segments_info" is None, we assume "panoptic_img" is a
|
164 |
+
# H*W int32 image storing the panoptic_id in the format of
|
165 |
+
# category_id * label_divisor + instance_id. We reserve -1 for
|
166 |
+
# VOID label.
|
167 |
+
label_divisor = metadata.label_divisor
|
168 |
+
segments_info = []
|
169 |
+
for panoptic_label in np.unique(panoptic_seg.numpy()):
|
170 |
+
if panoptic_label == -1:
|
171 |
+
# VOID region.
|
172 |
+
continue
|
173 |
+
pred_class = panoptic_label // label_divisor
|
174 |
+
isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
|
175 |
+
segments_info.append(
|
176 |
+
{
|
177 |
+
"id": int(panoptic_label),
|
178 |
+
"category_id": int(pred_class),
|
179 |
+
"isthing": bool(isthing),
|
180 |
+
}
|
181 |
+
)
|
182 |
+
del metadata
|
183 |
+
|
184 |
+
self._seg = panoptic_seg
|
185 |
+
|
186 |
+
self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info
|
187 |
+
segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
|
188 |
+
areas = areas.numpy()
|
189 |
+
sorted_idxs = np.argsort(-areas)
|
190 |
+
self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
|
191 |
+
self._seg_ids = self._seg_ids.tolist()
|
192 |
+
for sid, area in zip(self._seg_ids, self._seg_areas):
|
193 |
+
if sid in self._sinfo:
|
194 |
+
self._sinfo[sid]["area"] = float(area)
|
195 |
+
|
196 |
+
def non_empty_mask(self):
|
197 |
+
"""
|
198 |
+
Returns:
|
199 |
+
(H, W) array, a mask for all pixels that have a prediction
|
200 |
+
"""
|
201 |
+
empty_ids = []
|
202 |
+
for id in self._seg_ids:
|
203 |
+
if id not in self._sinfo:
|
204 |
+
empty_ids.append(id)
|
205 |
+
if len(empty_ids) == 0:
|
206 |
+
return np.zeros(self._seg.shape, dtype=np.uint8)
|
207 |
+
assert (
|
208 |
+
len(empty_ids) == 1
|
209 |
+
), ">1 ids corresponds to no labels. This is currently not supported"
|
210 |
+
return (self._seg != empty_ids[0]).numpy().astype(bool)
|
211 |
+
|
212 |
+
def semantic_masks(self):
|
213 |
+
for sid in self._seg_ids:
|
214 |
+
sinfo = self._sinfo.get(sid)
|
215 |
+
if sinfo is None or sinfo["isthing"]:
|
216 |
+
# Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
|
217 |
+
continue
|
218 |
+
yield (self._seg == sid).numpy().astype(bool), sinfo
|
219 |
+
|
220 |
+
def instance_masks(self):
|
221 |
+
for sid in self._seg_ids:
|
222 |
+
sinfo = self._sinfo.get(sid)
|
223 |
+
if sinfo is None or not sinfo["isthing"]:
|
224 |
+
continue
|
225 |
+
mask = (self._seg == sid).numpy().astype(bool)
|
226 |
+
if mask.sum() > 0:
|
227 |
+
yield mask, sinfo
|
228 |
+
|
229 |
+
|
230 |
+
def _create_text_labels(classes, scores, class_names, is_crowd=None):
|
231 |
+
"""
|
232 |
+
Args:
|
233 |
+
classes (list[int] or None):
|
234 |
+
scores (list[float] or None):
|
235 |
+
class_names (list[str] or None):
|
236 |
+
is_crowd (list[bool] or None):
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
list[str] or None
|
240 |
+
"""
|
241 |
+
labels = None
|
242 |
+
if classes is not None:
|
243 |
+
if class_names is not None and len(class_names) > 0:
|
244 |
+
labels = [class_names[i] for i in classes]
|
245 |
+
else:
|
246 |
+
labels = [str(i) for i in classes]
|
247 |
+
if scores is not None:
|
248 |
+
if labels is None:
|
249 |
+
labels = ["{:.0f}%".format(s * 100) for s in scores]
|
250 |
+
else:
|
251 |
+
labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
|
252 |
+
if labels is not None and is_crowd is not None:
|
253 |
+
labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
|
254 |
+
return labels
|
255 |
+
|
256 |
+
|
257 |
+
class VisImage:
|
258 |
+
def __init__(self, img, scale=1.0):
|
259 |
+
"""
|
260 |
+
Args:
|
261 |
+
img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
|
262 |
+
scale (float): scale the input image
|
263 |
+
"""
|
264 |
+
self.img = img
|
265 |
+
self.scale = scale
|
266 |
+
self.width, self.height = img.shape[1], img.shape[0]
|
267 |
+
self._setup_figure(img)
|
268 |
+
|
269 |
+
def _setup_figure(self, img):
|
270 |
+
"""
|
271 |
+
Args:
|
272 |
+
Same as in :meth:`__init__()`.
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
|
276 |
+
ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
|
277 |
+
"""
|
278 |
+
fig = mplfigure.Figure(frameon=False)
|
279 |
+
self.dpi = fig.get_dpi()
|
280 |
+
# add a small 1e-2 to avoid precision lost due to matplotlib's truncation
|
281 |
+
# (https://github.com/matplotlib/matplotlib/issues/15363)
|
282 |
+
fig.set_size_inches(
|
283 |
+
(self.width * self.scale + 1e-2) / self.dpi,
|
284 |
+
(self.height * self.scale + 1e-2) / self.dpi,
|
285 |
+
)
|
286 |
+
self.canvas = FigureCanvasAgg(fig)
|
287 |
+
# self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
|
288 |
+
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
|
289 |
+
ax.axis("off")
|
290 |
+
self.fig = fig
|
291 |
+
self.ax = ax
|
292 |
+
self.reset_image(img)
|
293 |
+
|
294 |
+
def reset_image(self, img):
|
295 |
+
"""
|
296 |
+
Args:
|
297 |
+
img: same as in __init__
|
298 |
+
"""
|
299 |
+
img = img.astype("uint8")
|
300 |
+
self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
|
301 |
+
|
302 |
+
def save(self, filepath):
|
303 |
+
"""
|
304 |
+
Args:
|
305 |
+
filepath (str): a string that contains the absolute path, including the file name, where
|
306 |
+
the visualized image will be saved.
|
307 |
+
"""
|
308 |
+
self.fig.savefig(filepath)
|
309 |
+
|
310 |
+
def get_image(self):
|
311 |
+
"""
|
312 |
+
Returns:
|
313 |
+
ndarray:
|
314 |
+
the visualized image of shape (H, W, 3) (RGB) in uint8 type.
|
315 |
+
The shape is scaled w.r.t the input image using the given `scale` argument.
|
316 |
+
"""
|
317 |
+
canvas = self.canvas
|
318 |
+
s, (width, height) = canvas.print_to_buffer()
|
319 |
+
# buf = io.BytesIO() # works for cairo backend
|
320 |
+
# canvas.print_rgba(buf)
|
321 |
+
# width, height = self.width, self.height
|
322 |
+
# s = buf.getvalue()
|
323 |
+
|
324 |
+
buffer = np.frombuffer(s, dtype="uint8")
|
325 |
+
|
326 |
+
img_rgba = buffer.reshape(height, width, 4)
|
327 |
+
rgb, alpha = np.split(img_rgba, [3], axis=2)
|
328 |
+
return rgb.astype("uint8")
|
329 |
+
|
330 |
+
|
331 |
+
class Visualizer:
|
332 |
+
"""
|
333 |
+
Visualizer that draws data about detection/segmentation on images.
|
334 |
+
|
335 |
+
It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
|
336 |
+
that draw primitive objects to images, as well as high-level wrappers like
|
337 |
+
`draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
|
338 |
+
that draw composite data in some pre-defined style.
|
339 |
+
|
340 |
+
Note that the exact visualization style for the high-level wrappers are subject to change.
|
341 |
+
Style such as color, opacity, label contents, visibility of labels, or even the visibility
|
342 |
+
of objects themselves (e.g. when the object is too small) may change according
|
343 |
+
to different heuristics, as long as the results still look visually reasonable.
|
344 |
+
|
345 |
+
To obtain a consistent style, you can implement custom drawing functions with the
|
346 |
+
above-mentioned primitive methods instead. If you need more customized visualization
|
347 |
+
styles, you can process the data yourself following their format documented in
|
348 |
+
tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
|
349 |
+
intend to satisfy everyone's preference on drawing styles.
|
350 |
+
|
351 |
+
This visualizer focuses on high rendering quality rather than performance. It is not
|
352 |
+
designed to be used for real-time applications.
|
353 |
+
"""
|
354 |
+
|
355 |
+
# TODO implement a fast, rasterized version using OpenCV
|
356 |
+
|
357 |
+
def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):
|
358 |
+
"""
|
359 |
+
Args:
|
360 |
+
img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
|
361 |
+
the height and width of the image respectively. C is the number of
|
362 |
+
color channels. The image is required to be in RGB format since that
|
363 |
+
is a requirement of the Matplotlib library. The image is also expected
|
364 |
+
to be in the range [0, 255].
|
365 |
+
metadata (Metadata): dataset metadata (e.g. class names and colors)
|
366 |
+
instance_mode (ColorMode): defines one of the pre-defined style for drawing
|
367 |
+
instances on an image.
|
368 |
+
"""
|
369 |
+
self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
|
370 |
+
if metadata is None:
|
371 |
+
metadata = MetadataCatalog.get("__nonexist__")
|
372 |
+
self.metadata = metadata
|
373 |
+
self.output = VisImage(self.img, scale=scale)
|
374 |
+
self.cpu_device = torch.device("cpu")
|
375 |
+
|
376 |
+
# too small texts are useless, therefore clamp to 10
|
377 |
+
self._default_font_size = max(
|
378 |
+
np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
|
379 |
+
)
|
380 |
+
self._instance_mode = instance_mode
|
381 |
+
self.keypoint_threshold = _KEYPOINT_THRESHOLD
|
382 |
+
|
383 |
+
def draw_instance_predictions(self, predictions):
|
384 |
+
"""
|
385 |
+
Draw instance-level prediction results on an image.
|
386 |
+
|
387 |
+
Args:
|
388 |
+
predictions (Instances): the output of an instance detection/segmentation
|
389 |
+
model. Following fields will be used to draw:
|
390 |
+
"pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
|
391 |
+
|
392 |
+
Returns:
|
393 |
+
output (VisImage): image object with visualizations.
|
394 |
+
"""
|
395 |
+
boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
|
396 |
+
scores = predictions.scores if predictions.has("scores") else None
|
397 |
+
classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
|
398 |
+
labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
|
399 |
+
keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
|
400 |
+
|
401 |
+
if predictions.has("pred_masks"):
|
402 |
+
masks = np.asarray(predictions.pred_masks)
|
403 |
+
masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
|
404 |
+
else:
|
405 |
+
masks = None
|
406 |
+
|
407 |
+
if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
|
408 |
+
colors = [
|
409 |
+
self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
|
410 |
+
]
|
411 |
+
alpha = 0.8
|
412 |
+
else:
|
413 |
+
colors = None
|
414 |
+
alpha = 0.5
|
415 |
+
|
416 |
+
if self._instance_mode == ColorMode.IMAGE_BW:
|
417 |
+
self.output.reset_image(
|
418 |
+
self._create_grayscale_image(
|
419 |
+
(predictions.pred_masks.any(dim=0) > 0).numpy()
|
420 |
+
if predictions.has("pred_masks")
|
421 |
+
else None
|
422 |
+
)
|
423 |
+
)
|
424 |
+
alpha = 0.3
|
425 |
+
|
426 |
+
self.overlay_instances(
|
427 |
+
masks=masks,
|
428 |
+
boxes=boxes,
|
429 |
+
labels=labels,
|
430 |
+
keypoints=keypoints,
|
431 |
+
assigned_colors=colors,
|
432 |
+
alpha=alpha,
|
433 |
+
)
|
434 |
+
return self.output
|
435 |
+
|
436 |
+
def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8):
|
437 |
+
"""
|
438 |
+
Draw semantic segmentation predictions/labels.
|
439 |
+
|
440 |
+
Args:
|
441 |
+
sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
|
442 |
+
Each value is the integer label of the pixel.
|
443 |
+
area_threshold (int): segments with less than `area_threshold` are not drawn.
|
444 |
+
alpha (float): the larger it is, the more opaque the segmentations are.
|
445 |
+
|
446 |
+
Returns:
|
447 |
+
output (VisImage): image object with visualizations.
|
448 |
+
"""
|
449 |
+
if isinstance(sem_seg, torch.Tensor):
|
450 |
+
sem_seg = sem_seg.numpy()
|
451 |
+
labels, areas = np.unique(sem_seg, return_counts=True)
|
452 |
+
sorted_idxs = np.argsort(-areas).tolist()
|
453 |
+
labels = labels[sorted_idxs]
|
454 |
+
for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
|
455 |
+
try:
|
456 |
+
mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
|
457 |
+
except (AttributeError, IndexError):
|
458 |
+
mask_color = None
|
459 |
+
|
460 |
+
binary_mask = (sem_seg == label).astype(np.uint8)
|
461 |
+
text = self.metadata.stuff_classes[label]
|
462 |
+
self.draw_binary_mask(
|
463 |
+
binary_mask,
|
464 |
+
color=mask_color,
|
465 |
+
edge_color=_OFF_WHITE,
|
466 |
+
text=text,
|
467 |
+
alpha=alpha,
|
468 |
+
area_threshold=area_threshold,
|
469 |
+
)
|
470 |
+
return self.output
|
471 |
+
|
472 |
+
def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7):
|
473 |
+
"""
|
474 |
+
Draw panoptic prediction annotations or results.
|
475 |
+
|
476 |
+
Args:
|
477 |
+
panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
|
478 |
+
segment.
|
479 |
+
segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
|
480 |
+
If it is a ``list[dict]``, each dict contains keys "id", "category_id".
|
481 |
+
If None, category id of each pixel is computed by
|
482 |
+
``pixel // metadata.label_divisor``.
|
483 |
+
area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
|
484 |
+
|
485 |
+
Returns:
|
486 |
+
output (VisImage): image object with visualizations.
|
487 |
+
"""
|
488 |
+
pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
|
489 |
+
|
490 |
+
if self._instance_mode == ColorMode.IMAGE_BW:
|
491 |
+
self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
|
492 |
+
|
493 |
+
# draw mask for all semantic segments first i.e. "stuff"
|
494 |
+
for mask, sinfo in pred.semantic_masks():
|
495 |
+
category_idx = sinfo["category_id"]
|
496 |
+
try:
|
497 |
+
mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
|
498 |
+
except AttributeError:
|
499 |
+
mask_color = None
|
500 |
+
|
501 |
+
text = self.metadata.stuff_classes[category_idx]
|
502 |
+
self.draw_binary_mask(
|
503 |
+
mask,
|
504 |
+
color=mask_color,
|
505 |
+
edge_color=_OFF_WHITE,
|
506 |
+
text=text,
|
507 |
+
alpha=alpha,
|
508 |
+
area_threshold=area_threshold,
|
509 |
+
)
|
510 |
+
|
511 |
+
# draw mask for all instances second
|
512 |
+
all_instances = list(pred.instance_masks())
|
513 |
+
if len(all_instances) == 0:
|
514 |
+
return self.output
|
515 |
+
masks, sinfo = list(zip(*all_instances))
|
516 |
+
category_ids = [x["category_id"] for x in sinfo]
|
517 |
+
|
518 |
+
try:
|
519 |
+
scores = [x["score"] for x in sinfo]
|
520 |
+
except KeyError:
|
521 |
+
scores = None
|
522 |
+
labels = _create_text_labels(
|
523 |
+
category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo]
|
524 |
+
)
|
525 |
+
|
526 |
+
try:
|
527 |
+
colors = [
|
528 |
+
self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids
|
529 |
+
]
|
530 |
+
except AttributeError:
|
531 |
+
colors = None
|
532 |
+
self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)
|
533 |
+
|
534 |
+
return self.output
|
535 |
+
|
536 |
+
draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility
|
537 |
+
|
538 |
+
def draw_dataset_dict(self, dic):
|
539 |
+
"""
|
540 |
+
Draw annotations/segmentations in Detectron2 Dataset format.
|
541 |
+
|
542 |
+
Args:
|
543 |
+
dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.
|
544 |
+
|
545 |
+
Returns:
|
546 |
+
output (VisImage): image object with visualizations.
|
547 |
+
"""
|
548 |
+
annos = dic.get("annotations", None)
|
549 |
+
if annos:
|
550 |
+
if "segmentation" in annos[0]:
|
551 |
+
masks = [x["segmentation"] for x in annos]
|
552 |
+
else:
|
553 |
+
masks = None
|
554 |
+
if "keypoints" in annos[0]:
|
555 |
+
keypts = [x["keypoints"] for x in annos]
|
556 |
+
keypts = np.array(keypts).reshape(len(annos), -1, 3)
|
557 |
+
else:
|
558 |
+
keypts = None
|
559 |
+
|
560 |
+
boxes = [
|
561 |
+
BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
|
562 |
+
if len(x["bbox"]) == 4
|
563 |
+
else x["bbox"]
|
564 |
+
for x in annos
|
565 |
+
]
|
566 |
+
|
567 |
+
colors = None
|
568 |
+
category_ids = [x["category_id"] for x in annos]
|
569 |
+
if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
|
570 |
+
colors = [
|
571 |
+
self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
|
572 |
+
for c in category_ids
|
573 |
+
]
|
574 |
+
names = self.metadata.get("thing_classes", None)
|
575 |
+
labels = _create_text_labels(
|
576 |
+
category_ids,
|
577 |
+
scores=None,
|
578 |
+
class_names=names,
|
579 |
+
is_crowd=[x.get("iscrowd", 0) for x in annos],
|
580 |
+
)
|
581 |
+
self.overlay_instances(
|
582 |
+
labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
|
583 |
+
)
|
584 |
+
|
585 |
+
sem_seg = dic.get("sem_seg", None)
|
586 |
+
if sem_seg is None and "sem_seg_file_name" in dic:
|
587 |
+
with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
|
588 |
+
sem_seg = Image.open(f)
|
589 |
+
sem_seg = np.asarray(sem_seg, dtype="uint8")
|
590 |
+
if sem_seg is not None:
|
591 |
+
self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)
|
592 |
+
|
593 |
+
pan_seg = dic.get("pan_seg", None)
|
594 |
+
if pan_seg is None and "pan_seg_file_name" in dic:
|
595 |
+
with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
|
596 |
+
pan_seg = Image.open(f)
|
597 |
+
pan_seg = np.asarray(pan_seg)
|
598 |
+
from panopticapi.utils import rgb2id
|
599 |
+
|
600 |
+
pan_seg = rgb2id(pan_seg)
|
601 |
+
if pan_seg is not None:
|
602 |
+
segments_info = dic["segments_info"]
|
603 |
+
pan_seg = torch.tensor(pan_seg)
|
604 |
+
self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5)
|
605 |
+
return self.output
|
606 |
+
|
607 |
+
def overlay_instances(
|
608 |
+
self,
|
609 |
+
*,
|
610 |
+
boxes=None,
|
611 |
+
labels=None,
|
612 |
+
masks=None,
|
613 |
+
keypoints=None,
|
614 |
+
assigned_colors=None,
|
615 |
+
alpha=0.5,
|
616 |
+
):
|
617 |
+
"""
|
618 |
+
Args:
|
619 |
+
boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
|
620 |
+
or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
|
621 |
+
or a :class:`RotatedBoxes`,
|
622 |
+
or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
|
623 |
+
for the N objects in a single image,
|
624 |
+
labels (list[str]): the text to be displayed for each instance.
|
625 |
+
masks (masks-like object): Supported types are:
|
626 |
+
|
627 |
+
* :class:`detectron2.structures.PolygonMasks`,
|
628 |
+
:class:`detectron2.structures.BitMasks`.
|
629 |
+
* list[list[ndarray]]: contains the segmentation masks for all objects in one image.
|
630 |
+
The first level of the list corresponds to individual instances. The second
|
631 |
+
level to all the polygons that compose the instance, and the third level
|
632 |
+
to the polygon coordinates. The third level should have the format of
|
633 |
+
[x0, y0, x1, y1, ..., xn, yn] (n >= 3).
|
634 |
+
* list[ndarray]: each ndarray is a binary mask of shape (H, W).
|
635 |
+
* list[dict]: each dict is a COCO-style RLE.
|
636 |
+
keypoints (Keypoints or array-like): an array-like object of shape (N, K, 3),
|
637 |
+
where N is the number of instances and K is the number of keypoints.
|
638 |
+
The last dimension corresponds to (x, y, visibility or score).
|
639 |
+
assigned_colors (list[matplotlib.colors]): a list of colors, where each color
|
640 |
+
corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
|
641 |
+
for full list of formats that the colors are accepted in.
|
642 |
+
Returns:
|
643 |
+
output (VisImage): image object with visualizations.
|
644 |
+
"""
|
645 |
+
num_instances = 0
|
646 |
+
if boxes is not None:
|
647 |
+
boxes = self._convert_boxes(boxes)
|
648 |
+
num_instances = len(boxes)
|
649 |
+
if masks is not None:
|
650 |
+
masks = self._convert_masks(masks)
|
651 |
+
if num_instances:
|
652 |
+
assert len(masks) == num_instances
|
653 |
+
else:
|
654 |
+
num_instances = len(masks)
|
655 |
+
if keypoints is not None:
|
656 |
+
if num_instances:
|
657 |
+
assert len(keypoints) == num_instances
|
658 |
+
else:
|
659 |
+
num_instances = len(keypoints)
|
660 |
+
keypoints = self._convert_keypoints(keypoints)
|
661 |
+
if labels is not None:
|
662 |
+
assert len(labels) == num_instances
|
663 |
+
if assigned_colors is None:
|
664 |
+
assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
|
665 |
+
if num_instances == 0:
|
666 |
+
return self.output
|
667 |
+
if boxes is not None and boxes.shape[1] == 5:
|
668 |
+
return self.overlay_rotated_instances(
|
669 |
+
boxes=boxes, labels=labels, assigned_colors=assigned_colors
|
670 |
+
)
|
671 |
+
|
672 |
+
# Display in largest to smallest order to reduce occlusion.
|
673 |
+
areas = None
|
674 |
+
if boxes is not None:
|
675 |
+
areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
|
676 |
+
elif masks is not None:
|
677 |
+
areas = np.asarray([x.area() for x in masks])
|
678 |
+
|
679 |
+
if areas is not None:
|
680 |
+
sorted_idxs = np.argsort(-areas).tolist()
|
681 |
+
# Re-order overlapped instances in descending order.
|
682 |
+
boxes = boxes[sorted_idxs] if boxes is not None else None
|
683 |
+
labels = [labels[k] for k in sorted_idxs] if labels is not None else None
|
684 |
+
masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
|
685 |
+
assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
|
686 |
+
keypoints = keypoints[sorted_idxs] if keypoints is not None else None
|
687 |
+
|
688 |
+
for i in range(num_instances):
|
689 |
+
color = assigned_colors[i]
|
690 |
+
if boxes is not None:
|
691 |
+
self.draw_box(boxes[i], edge_color=color)
|
692 |
+
|
693 |
+
if masks is not None:
|
694 |
+
for segment in masks[i].polygons:
|
695 |
+
self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)
|
696 |
+
|
697 |
+
if labels is not None:
|
698 |
+
# first get a box
|
699 |
+
if boxes is not None:
|
700 |
+
x0, y0, x1, y1 = boxes[i]
|
701 |
+
text_pos = (x0, y0) # if drawing boxes, put text on the box corner.
|
702 |
+
horiz_align = "left"
|
703 |
+
elif masks is not None:
|
704 |
+
# skip small mask without polygon
|
705 |
+
if len(masks[i].polygons) == 0:
|
706 |
+
continue
|
707 |
+
|
708 |
+
x0, y0, x1, y1 = masks[i].bbox()
|
709 |
+
|
710 |
+
# draw text in the center (defined by median) when box is not drawn
|
711 |
+
# median is less sensitive to outliers.
|
712 |
+
text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
|
713 |
+
horiz_align = "center"
|
714 |
+
else:
|
715 |
+
continue # drawing the box confidence for keypoints isn't very useful.
|
716 |
+
# for small objects, draw text at the side to avoid occlusion
|
717 |
+
instance_area = (y1 - y0) * (x1 - x0)
|
718 |
+
if (
|
719 |
+
instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
|
720 |
+
or y1 - y0 < 40 * self.output.scale
|
721 |
+
):
|
722 |
+
if y1 >= self.output.height - 5:
|
723 |
+
text_pos = (x1, y0)
|
724 |
+
else:
|
725 |
+
text_pos = (x0, y1)
|
726 |
+
|
727 |
+
height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
|
728 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
729 |
+
font_size = (
|
730 |
+
np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
|
731 |
+
* 0.5
|
732 |
+
* self._default_font_size
|
733 |
+
)
|
734 |
+
self.draw_text(
|
735 |
+
labels[i],
|
736 |
+
text_pos,
|
737 |
+
color=lighter_color,
|
738 |
+
horizontal_alignment=horiz_align,
|
739 |
+
font_size=font_size,
|
740 |
+
)
|
741 |
+
|
742 |
+
# draw keypoints
|
743 |
+
if keypoints is not None:
|
744 |
+
for keypoints_per_instance in keypoints:
|
745 |
+
self.draw_and_connect_keypoints(keypoints_per_instance)
|
746 |
+
|
747 |
+
return self.output
|
748 |
+
|
749 |
+
def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
|
750 |
+
"""
|
751 |
+
Args:
|
752 |
+
boxes (ndarray): an Nx5 numpy array of
|
753 |
+
(x_center, y_center, width, height, angle_degrees) format
|
754 |
+
for the N objects in a single image.
|
755 |
+
labels (list[str]): the text to be displayed for each instance.
|
756 |
+
assigned_colors (list[matplotlib.colors]): a list of colors, where each color
|
757 |
+
corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
|
758 |
+
for full list of formats that the colors are accepted in.
|
759 |
+
|
760 |
+
Returns:
|
761 |
+
output (VisImage): image object with visualizations.
|
762 |
+
"""
|
763 |
+
num_instances = len(boxes)
|
764 |
+
|
765 |
+
if assigned_colors is None:
|
766 |
+
assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
|
767 |
+
if num_instances == 0:
|
768 |
+
return self.output
|
769 |
+
|
770 |
+
# Display in largest to smallest order to reduce occlusion.
|
771 |
+
if boxes is not None:
|
772 |
+
areas = boxes[:, 2] * boxes[:, 3]
|
773 |
+
|
774 |
+
sorted_idxs = np.argsort(-areas).tolist()
|
775 |
+
# Re-order overlapped instances in descending order.
|
776 |
+
boxes = boxes[sorted_idxs]
|
777 |
+
labels = [labels[k] for k in sorted_idxs] if labels is not None else None
|
778 |
+
colors = [assigned_colors[idx] for idx in sorted_idxs]
|
779 |
+
|
780 |
+
for i in range(num_instances):
|
781 |
+
self.draw_rotated_box_with_label(
|
782 |
+
boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
|
783 |
+
)
|
784 |
+
|
785 |
+
return self.output
|
786 |
+
|
787 |
+
def draw_and_connect_keypoints(self, keypoints):
|
788 |
+
"""
|
789 |
+
Draws keypoints of an instance and follows the rules for keypoint connections
|
790 |
+
to draw lines between appropriate keypoints. This follows color heuristics for
|
791 |
+
line color.
|
792 |
+
|
793 |
+
Args:
|
794 |
+
keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
|
795 |
+
and the last dimension corresponds to (x, y, probability).
|
796 |
+
|
797 |
+
Returns:
|
798 |
+
output (VisImage): image object with visualizations.
|
799 |
+
"""
|
800 |
+
visible = {}
|
801 |
+
keypoint_names = self.metadata.get("keypoint_names")
|
802 |
+
for idx, keypoint in enumerate(keypoints):
|
803 |
+
|
804 |
+
# draw keypoint
|
805 |
+
x, y, prob = keypoint
|
806 |
+
if prob > self.keypoint_threshold:
|
807 |
+
self.draw_circle((x, y), color=_RED)
|
808 |
+
if keypoint_names:
|
809 |
+
keypoint_name = keypoint_names[idx]
|
810 |
+
visible[keypoint_name] = (x, y)
|
811 |
+
|
812 |
+
if self.metadata.get("keypoint_connection_rules"):
|
813 |
+
for kp0, kp1, color in self.metadata.keypoint_connection_rules:
|
814 |
+
if kp0 in visible and kp1 in visible:
|
815 |
+
x0, y0 = visible[kp0]
|
816 |
+
x1, y1 = visible[kp1]
|
817 |
+
color = tuple(x / 255.0 for x in color)
|
818 |
+
self.draw_line([x0, x1], [y0, y1], color=color)
|
819 |
+
|
820 |
+
# draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
|
821 |
+
# Note that this strategy is specific to person keypoints.
|
822 |
+
# For other keypoints, it should just do nothing
|
823 |
+
try:
|
824 |
+
ls_x, ls_y = visible["left_shoulder"]
|
825 |
+
rs_x, rs_y = visible["right_shoulder"]
|
826 |
+
mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
|
827 |
+
except KeyError:
|
828 |
+
pass
|
829 |
+
else:
|
830 |
+
# draw line from nose to mid-shoulder
|
831 |
+
nose_x, nose_y = visible.get("nose", (None, None))
|
832 |
+
if nose_x is not None:
|
833 |
+
self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)
|
834 |
+
|
835 |
+
try:
|
836 |
+
# draw line from mid-shoulder to mid-hip
|
837 |
+
lh_x, lh_y = visible["left_hip"]
|
838 |
+
rh_x, rh_y = visible["right_hip"]
|
839 |
+
except KeyError:
|
840 |
+
pass
|
841 |
+
else:
|
842 |
+
mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
|
843 |
+
self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
|
844 |
+
return self.output
|
845 |
+
|
846 |
+
"""
|
847 |
+
Primitive drawing functions:
|
848 |
+
"""
|
849 |
+
|
850 |
+
def draw_text(
|
851 |
+
self,
|
852 |
+
text,
|
853 |
+
position,
|
854 |
+
*,
|
855 |
+
font_size=None,
|
856 |
+
color="g",
|
857 |
+
horizontal_alignment="center",
|
858 |
+
rotation=0,
|
859 |
+
):
|
860 |
+
"""
|
861 |
+
Args:
|
862 |
+
text (str): class label
|
863 |
+
position (tuple): a tuple of the x and y coordinates to place text on image.
|
864 |
+
font_size (int, optional): font size of the text. If not provided, a font size
|
865 |
+
proportional to the image width is calculated and used.
|
866 |
+
color: color of the text. Refer to `matplotlib.colors` for full list
|
867 |
+
of formats that are accepted.
|
868 |
+
horizontal_alignment (str): see `matplotlib.text.Text`
|
869 |
+
rotation: rotation angle in degrees CCW
|
870 |
+
|
871 |
+
Returns:
|
872 |
+
output (VisImage): image object with text drawn.
|
873 |
+
"""
|
874 |
+
if not font_size:
|
875 |
+
font_size = self._default_font_size
|
876 |
+
|
877 |
+
# since the text background is dark, we don't want the text to be dark
|
878 |
+
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
|
879 |
+
color[np.argmax(color)] = max(0.8, np.max(color))
|
880 |
+
|
881 |
+
x, y = position
|
882 |
+
self.output.ax.text(
|
883 |
+
x,
|
884 |
+
y,
|
885 |
+
text,
|
886 |
+
size=font_size * self.output.scale,
|
887 |
+
family="sans-serif",
|
888 |
+
bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
|
889 |
+
verticalalignment="top",
|
890 |
+
horizontalalignment=horizontal_alignment,
|
891 |
+
color=color,
|
892 |
+
zorder=10,
|
893 |
+
rotation=rotation,
|
894 |
+
)
|
895 |
+
return self.output
|
896 |
+
|
897 |
+
def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
|
898 |
+
"""
|
899 |
+
Args:
|
900 |
+
box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
|
901 |
+
are the coordinates of the box's top left corner. x1 and y1 are the
|
902 |
+
coordinates of the box's bottom right corner.
|
903 |
+
alpha (float): blending coefficient. Smaller values lead to more transparent masks.
|
904 |
+
edge_color: color of the outline of the box. Refer to `matplotlib.colors`
|
905 |
+
for full list of formats that are accepted.
|
906 |
+
line_style (string): the string to use to create the outline of the boxes.
|
907 |
+
|
908 |
+
Returns:
|
909 |
+
output (VisImage): image object with box drawn.
|
910 |
+
"""
|
911 |
+
x0, y0, x1, y1 = box_coord
|
912 |
+
width = x1 - x0
|
913 |
+
height = y1 - y0
|
914 |
+
|
915 |
+
linewidth = max(self._default_font_size / 4, 1)
|
916 |
+
|
917 |
+
self.output.ax.add_patch(
|
918 |
+
mpl.patches.Rectangle(
|
919 |
+
(x0, y0),
|
920 |
+
width,
|
921 |
+
height,
|
922 |
+
fill=False,
|
923 |
+
edgecolor=edge_color,
|
924 |
+
linewidth=linewidth * self.output.scale,
|
925 |
+
alpha=alpha,
|
926 |
+
linestyle=line_style,
|
927 |
+
)
|
928 |
+
)
|
929 |
+
return self.output
|
930 |
+
|
931 |
+
def draw_rotated_box_with_label(
|
932 |
+
self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
|
933 |
+
):
|
934 |
+
"""
|
935 |
+
Draw a rotated box with label on its top-left corner.
|
936 |
+
|
937 |
+
Args:
|
938 |
+
rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
|
939 |
+
where cnt_x and cnt_y are the center coordinates of the box.
|
940 |
+
w and h are the width and height of the box. angle represents how
|
941 |
+
many degrees the box is rotated CCW with regard to the 0-degree box.
|
942 |
+
alpha (float): blending coefficient. Smaller values lead to more transparent masks.
|
943 |
+
edge_color: color of the outline of the box. Refer to `matplotlib.colors`
|
944 |
+
for full list of formats that are accepted.
|
945 |
+
line_style (string): the string to use to create the outline of the boxes.
|
946 |
+
label (string): label for rotated box. It will not be rendered when set to None.
|
947 |
+
|
948 |
+
Returns:
|
949 |
+
output (VisImage): image object with box drawn.
|
950 |
+
"""
|
951 |
+
cnt_x, cnt_y, w, h, angle = rotated_box
|
952 |
+
area = w * h
|
953 |
+
# use thinner lines when the box is small
|
954 |
+
linewidth = self._default_font_size / (
|
955 |
+
6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
|
956 |
+
)
|
957 |
+
|
958 |
+
theta = angle * math.pi / 180.0
|
959 |
+
c = math.cos(theta)
|
960 |
+
s = math.sin(theta)
|
961 |
+
rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
|
962 |
+
# x: left->right ; y: top->down
|
963 |
+
rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
|
964 |
+
for k in range(4):
|
965 |
+
j = (k + 1) % 4
|
966 |
+
self.draw_line(
|
967 |
+
[rotated_rect[k][0], rotated_rect[j][0]],
|
968 |
+
[rotated_rect[k][1], rotated_rect[j][1]],
|
969 |
+
color=edge_color,
|
970 |
+
linestyle="--" if k == 1 else line_style,
|
971 |
+
linewidth=linewidth,
|
972 |
+
)
|
973 |
+
|
974 |
+
if label is not None:
|
975 |
+
text_pos = rotated_rect[1] # topleft corner
|
976 |
+
|
977 |
+
height_ratio = h / np.sqrt(self.output.height * self.output.width)
|
978 |
+
label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
|
979 |
+
font_size = (
|
980 |
+
np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size
|
981 |
+
)
|
982 |
+
self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)
|
983 |
+
|
984 |
+
return self.output
|
985 |
+
|
986 |
+
def draw_circle(self, circle_coord, color, radius=3):
|
987 |
+
"""
|
988 |
+
Args:
|
989 |
+
circle_coord (list(int) or tuple(int)): contains the x and y coordinates
|
990 |
+
of the center of the circle.
|
991 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
992 |
+
formats that are accepted.
|
993 |
+
radius (int): radius of the circle.
|
994 |
+
|
995 |
+
Returns:
|
996 |
+
output (VisImage): image object with box drawn.
|
997 |
+
"""
|
998 |
+
x, y = circle_coord
|
999 |
+
self.output.ax.add_patch(
|
1000 |
+
mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
|
1001 |
+
)
|
1002 |
+
return self.output
|
1003 |
+
|
1004 |
+
def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
|
1005 |
+
"""
|
1006 |
+
Args:
|
1007 |
+
x_data (list[int]): a list containing x values of all the points being drawn.
|
1008 |
+
Length of list should match the length of y_data.
|
1009 |
+
y_data (list[int]): a list containing y values of all the points being drawn.
|
1010 |
+
Length of list should match the length of x_data.
|
1011 |
+
color: color of the line. Refer to `matplotlib.colors` for a full list of
|
1012 |
+
formats that are accepted.
|
1013 |
+
linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
|
1014 |
+
for a full list of formats that are accepted.
|
1015 |
+
linewidth (float or None): width of the line. When it's None,
|
1016 |
+
a default value will be computed and used.
|
1017 |
+
|
1018 |
+
Returns:
|
1019 |
+
output (VisImage): image object with line drawn.
|
1020 |
+
"""
|
1021 |
+
if linewidth is None:
|
1022 |
+
linewidth = self._default_font_size / 3
|
1023 |
+
linewidth = max(linewidth, 1)
|
1024 |
+
self.output.ax.add_line(
|
1025 |
+
mpl.lines.Line2D(
|
1026 |
+
x_data,
|
1027 |
+
y_data,
|
1028 |
+
linewidth=linewidth * self.output.scale,
|
1029 |
+
color=color,
|
1030 |
+
linestyle=linestyle,
|
1031 |
+
)
|
1032 |
+
)
|
1033 |
+
return self.output
|
1034 |
+
|
1035 |
+
def draw_binary_mask(
|
1036 |
+
self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=10
|
1037 |
+
):
|
1038 |
+
"""
|
1039 |
+
Args:
|
1040 |
+
binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
|
1041 |
+
W is the image width. Each value in the array is either a 0 or 1 value of uint8
|
1042 |
+
type.
|
1043 |
+
color: color of the mask. Refer to `matplotlib.colors` for a full list of
|
1044 |
+
formats that are accepted. If None, will pick a random color.
|
1045 |
+
edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
|
1046 |
+
full list of formats that are accepted.
|
1047 |
+
text (str): if not None, the text will be drawn on the object
|
1048 |
+
alpha (float): blending coefficient. Smaller values lead to more transparent masks.
|
1049 |
+
area_threshold (float): a connected component smaller than this area will not be shown.
|
1050 |
+
|
1051 |
+
Returns:
|
1052 |
+
output (VisImage): image object with mask drawn.
|
1053 |
+
"""
|
1054 |
+
if color is None:
|
1055 |
+
color = random_color(rgb=True, maximum=1)
|
1056 |
+
color = mplc.to_rgb(color)
|
1057 |
+
|
1058 |
+
has_valid_segment = False
|
1059 |
+
binary_mask = binary_mask.astype("uint8") # opencv needs uint8
|
1060 |
+
mask = GenericMask(binary_mask, self.output.height, self.output.width)
|
1061 |
+
shape2d = (binary_mask.shape[0], binary_mask.shape[1])
|
1062 |
+
|
1063 |
+
if not mask.has_holes:
|
1064 |
+
# draw polygons for regular masks
|
1065 |
+
for segment in mask.polygons:
|
1066 |
+
area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
|
1067 |
+
if area < (area_threshold or 0):
|
1068 |
+
continue
|
1069 |
+
has_valid_segment = True
|
1070 |
+
segment = segment.reshape(-1, 2)
|
1071 |
+
self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)
|
1072 |
+
else:
|
1073 |
+
# TODO: Use Path/PathPatch to draw vector graphics:
|
1074 |
+
# https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
|
1075 |
+
rgba = np.zeros(shape2d + (4,), dtype="float32")
|
1076 |
+
rgba[:, :, :3] = color
|
1077 |
+
rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
|
1078 |
+
has_valid_segment = True
|
1079 |
+
self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
|
1080 |
+
|
1081 |
+
if text is not None and has_valid_segment:
|
1082 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
1083 |
+
self._draw_text_in_mask(binary_mask, text, lighter_color)
|
1084 |
+
return self.output
|
1085 |
+
|
1086 |
+
def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
|
1087 |
+
"""
|
1088 |
+
Args:
|
1089 |
+
soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
|
1090 |
+
color: color of the mask. Refer to `matplotlib.colors` for a full list of
|
1091 |
+
formats that are accepted. If None, will pick a random color.
|
1092 |
+
text (str): if not None, the text will be drawn on the object
|
1093 |
+
alpha (float): blending coefficient. Smaller values lead to more transparent masks.
|
1094 |
+
|
1095 |
+
Returns:
|
1096 |
+
output (VisImage): image object with mask drawn.
|
1097 |
+
"""
|
1098 |
+
if color is None:
|
1099 |
+
color = random_color(rgb=True, maximum=1)
|
1100 |
+
color = mplc.to_rgb(color)
|
1101 |
+
|
1102 |
+
shape2d = (soft_mask.shape[0], soft_mask.shape[1])
|
1103 |
+
rgba = np.zeros(shape2d + (4,), dtype="float32")
|
1104 |
+
rgba[:, :, :3] = color
|
1105 |
+
rgba[:, :, 3] = soft_mask * alpha
|
1106 |
+
self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
|
1107 |
+
|
1108 |
+
if text is not None:
|
1109 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
1110 |
+
binary_mask = (soft_mask > 0.5).astype("uint8")
|
1111 |
+
self._draw_text_in_mask(binary_mask, text, lighter_color)
|
1112 |
+
return self.output
|
1113 |
+
|
1114 |
+
def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
|
1115 |
+
"""
|
1116 |
+
Args:
|
1117 |
+
segment: numpy array of shape Nx2, containing all the points in the polygon.
|
1118 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
1119 |
+
formats that are accepted.
|
1120 |
+
edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
|
1121 |
+
full list of formats that are accepted. If not provided, a darker shade
|
1122 |
+
of the polygon color will be used instead.
|
1123 |
+
alpha (float): blending coefficient. Smaller values lead to more transparent masks.
|
1124 |
+
|
1125 |
+
Returns:
|
1126 |
+
output (VisImage): image object with polygon drawn.
|
1127 |
+
"""
|
1128 |
+
if edge_color is None:
|
1129 |
+
# make edge color darker than the polygon color
|
1130 |
+
if alpha > 0.8:
|
1131 |
+
edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
|
1132 |
+
else:
|
1133 |
+
edge_color = color
|
1134 |
+
edge_color = mplc.to_rgb(edge_color) + (1,)
|
1135 |
+
|
1136 |
+
polygon = mpl.patches.Polygon(
|
1137 |
+
segment,
|
1138 |
+
fill=True,
|
1139 |
+
facecolor=mplc.to_rgb(color) + (alpha,),
|
1140 |
+
edgecolor=edge_color,
|
1141 |
+
linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
|
1142 |
+
)
|
1143 |
+
self.output.ax.add_patch(polygon)
|
1144 |
+
return self.output
|
1145 |
+
|
1146 |
+
"""
|
1147 |
+
Internal methods:
|
1148 |
+
"""
|
1149 |
+
|
1150 |
+
def _jitter(self, color):
|
1151 |
+
"""
|
1152 |
+
Randomly modifies the given color to produce a slightly different color.
|
1153 |
+
|
1154 |
+
Args:
|
1155 |
+
color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
|
1156 |
+
picked. The values in the list are in the [0.0, 1.0] range.
|
1157 |
+
|
1158 |
+
Returns:
|
1159 |
+
jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
|
1160 |
+
color after being jittered. The values in the list are in the [0.0, 1.0] range.
|
1161 |
+
"""
|
1162 |
+
color = mplc.to_rgb(color)
|
1163 |
+
vec = np.random.rand(3)
|
1164 |
+
# better to do it in another color space
|
1165 |
+
vec = vec / np.linalg.norm(vec) * 0.5
|
1166 |
+
res = np.clip(vec + color, 0, 1)
|
1167 |
+
return tuple(res)
|
1168 |
+
|
1169 |
+
def _create_grayscale_image(self, mask=None):
|
1170 |
+
"""
|
1171 |
+
Create a grayscale version of the original image.
|
1172 |
+
The colors in masked area, if given, will be kept.
|
1173 |
+
"""
|
1174 |
+
img_bw = self.img.astype("f4").mean(axis=2)
|
1175 |
+
img_bw = np.stack([img_bw] * 3, axis=2)
|
1176 |
+
if mask is not None:
|
1177 |
+
img_bw[mask] = self.img[mask]
|
1178 |
+
return img_bw
|
1179 |
+
|
1180 |
+
def _change_color_brightness(self, color, brightness_factor):
|
1181 |
+
"""
|
1182 |
+
Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
|
1183 |
+
less or more lightness than the original color.
|
1184 |
+
|
1185 |
+
Args:
|
1186 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
1187 |
+
formats that are accepted.
|
1188 |
+
brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
|
1189 |
+
0 will correspond to no change, a factor in [-1.0, 0) range will result in
|
1190 |
+
a darker color and a factor in (0, 1.0] range will result in a lighter color.
|
1191 |
+
|
1192 |
+
Returns:
|
1193 |
+
modified_color (tuple[double]): a tuple containing the RGB values of the
|
1194 |
+
modified color. Each value in the tuple is in the [0.0, 1.0] range.
|
1195 |
+
"""
|
1196 |
+
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
|
1197 |
+
color = mplc.to_rgb(color)
|
1198 |
+
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
|
1199 |
+
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
|
1200 |
+
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
|
1201 |
+
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
|
1202 |
+
modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
|
1203 |
+
return tuple(np.clip(modified_color, 0.0, 1.0))
|
1204 |
+
|
1205 |
+
def _convert_boxes(self, boxes):
|
1206 |
+
"""
|
1207 |
+
Convert different formats of boxes to an NxB array, where B = 4 or 5 is the box dimension.
|
1208 |
+
"""
|
1209 |
+
if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
|
1210 |
+
return boxes.tensor.detach().numpy()
|
1211 |
+
else:
|
1212 |
+
return np.asarray(boxes)
|
1213 |
+
|
1214 |
+
def _convert_masks(self, masks_or_polygons):
|
1215 |
+
"""
|
1216 |
+
Convert different formats of masks or polygons to a list of GenericMask objects.
|
1217 |
+
|
1218 |
+
Returns:
|
1219 |
+
list[GenericMask]:
|
1220 |
+
"""
|
1221 |
+
|
1222 |
+
m = masks_or_polygons
|
1223 |
+
if isinstance(m, PolygonMasks):
|
1224 |
+
m = m.polygons
|
1225 |
+
if isinstance(m, BitMasks):
|
1226 |
+
m = m.tensor.numpy()
|
1227 |
+
if isinstance(m, torch.Tensor):
|
1228 |
+
m = m.numpy()
|
1229 |
+
ret = []
|
1230 |
+
for x in m:
|
1231 |
+
if isinstance(x, GenericMask):
|
1232 |
+
ret.append(x)
|
1233 |
+
else:
|
1234 |
+
ret.append(GenericMask(x, self.output.height, self.output.width))
|
1235 |
+
return ret
|
1236 |
+
|
1237 |
+
def _draw_text_in_mask(self, binary_mask, text, color):
|
1238 |
+
"""
|
1239 |
+
Find proper places to draw text given a binary mask.
|
1240 |
+
"""
|
1241 |
+
# TODO sometimes drawn on wrong objects. the heuristics here can improve.
|
1242 |
+
_num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
|
1243 |
+
if stats[1:, -1].size == 0:
|
1244 |
+
return
|
1245 |
+
largest_component_id = np.argmax(stats[1:, -1]) + 1
|
1246 |
+
|
1247 |
+
# draw text on the largest component, as well as other very large components.
|
1248 |
+
for cid in range(1, _num_cc):
|
1249 |
+
if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
|
1250 |
+
# median is more stable than centroid
|
1251 |
+
# center = centroids[largest_component_id]
|
1252 |
+
center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
|
1253 |
+
self.draw_text(text, center, color=color)
|
1254 |
+
|
1255 |
+
def _convert_keypoints(self, keypoints):
|
1256 |
+
if isinstance(keypoints, Keypoints):
|
1257 |
+
keypoints = keypoints.tensor
|
1258 |
+
keypoints = np.asarray(keypoints)
|
1259 |
+
return keypoints
|
1260 |
+
|
1261 |
+
def get_output(self):
|
1262 |
+
"""
|
1263 |
+
Returns:
|
1264 |
+
output (VisImage): the image output containing the visualizations added
|
1265 |
+
to the image.
|
1266 |
+
"""
|
1267 |
+
return self.output
|
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/__init__.py
ADDED
@@ -0,0 +1,9 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from . import data # register all new datasets
|
3 |
+
from . import modeling
|
4 |
+
|
5 |
+
# config
|
6 |
+
from .config import *
|
7 |
+
|
8 |
+
# models
|
9 |
+
from .oneformer_model import OneFormer
|
extensions/microsoftexcel-controlnet/annotator/oneformer/oneformer/config.py
ADDED
@@ -0,0 +1,239 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
from annotator.oneformer.detectron2.config import CfgNode as CN
|
4 |
+
|
5 |
+
__all__ = ["add_common_config", "add_oneformer_config", "add_swin_config",
|
6 |
+
"add_dinat_config", "add_beit_adapter_config", "add_convnext_config"]
|
7 |
+
|
8 |
+
def add_common_config(cfg):
|
9 |
+
"""
|
10 |
+
Add config for common configuration
|
11 |
+
"""
|
12 |
+
# data config
|
13 |
+
# select the dataset mapper
|
14 |
+
cfg.INPUT.DATASET_MAPPER_NAME = "oneformer_unified"
|
15 |
+
# Color augmentation
|
16 |
+
cfg.INPUT.COLOR_AUG_SSD = False
|
17 |
+
# We retry random cropping until no single category in semantic segmentation GT occupies more
|
18 |
+
# than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
|
19 |
+
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
|
20 |
+
# Pad image and segmentation GT in dataset mapper.
|
21 |
+
cfg.INPUT.SIZE_DIVISIBILITY = -1
|
22 |
+
|
23 |
+
cfg.INPUT.TASK_SEQ_LEN = 77
|
24 |
+
cfg.INPUT.MAX_SEQ_LEN = 77
|
25 |
+
|
26 |
+
cfg.INPUT.TASK_PROB = CN()
|
27 |
+
cfg.INPUT.TASK_PROB.SEMANTIC = 0.33
|
28 |
+
cfg.INPUT.TASK_PROB.INSTANCE = 0.66
|
29 |
+
|
30 |
+
# test dataset
|
31 |
+
cfg.DATASETS.TEST_PANOPTIC = ("",)
|
32 |
+
cfg.DATASETS.TEST_INSTANCE = ("",)
|
33 |
+
cfg.DATASETS.TEST_SEMANTIC = ("",)
|
34 |
+
|
35 |
+
# solver config
|
36 |
+
# weight decay on embedding
|
37 |
+
cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
|
38 |
+
# optimizer
|
39 |
+
cfg.SOLVER.OPTIMIZER = "ADAMW"
|
40 |
+
cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
|
41 |
+
|
42 |
+
# wandb
|
43 |
+
cfg.WANDB = CN()
|
44 |
+
cfg.WANDB.PROJECT = "unified_dense_recognition"
|
45 |
+
cfg.WANDB.NAME = None
|
46 |
+
|
47 |
+
cfg.MODEL.IS_TRAIN = False
|
48 |
+
cfg.MODEL.IS_DEMO = True
|
49 |
+
|
50 |
+
# text encoder config
|
51 |
+
cfg.MODEL.TEXT_ENCODER = CN()
|
52 |
+
|
53 |
+
cfg.MODEL.TEXT_ENCODER.WIDTH = 256
|
54 |
+
cfg.MODEL.TEXT_ENCODER.CONTEXT_LENGTH = 77
|
55 |
+
cfg.MODEL.TEXT_ENCODER.NUM_LAYERS = 12
|
56 |
+
cfg.MODEL.TEXT_ENCODER.VOCAB_SIZE = 49408
|
57 |
+
cfg.MODEL.TEXT_ENCODER.PROJ_NUM_LAYERS = 2
|
58 |
+
cfg.MODEL.TEXT_ENCODER.N_CTX = 16
|
59 |
+
|
60 |
+
# mask_former inference config
|
61 |
+
cfg.MODEL.TEST = CN()
|
62 |
+
cfg.MODEL.TEST.SEMANTIC_ON = True
|
63 |
+
cfg.MODEL.TEST.INSTANCE_ON = False
|
64 |
+
cfg.MODEL.TEST.PANOPTIC_ON = False
|
65 |
+
cfg.MODEL.TEST.DETECTION_ON = False
|
66 |
+
cfg.MODEL.TEST.OBJECT_MASK_THRESHOLD = 0.0
|
67 |
+
cfg.MODEL.TEST.OVERLAP_THRESHOLD = 0.0
|
68 |
+
cfg.MODEL.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
|
69 |
+
cfg.MODEL.TEST.TASK = "panoptic"
|
70 |
+
|
71 |
+
# TEST AUG Slide
|
72 |
+
cfg.TEST.AUG.IS_SLIDE = False
|
73 |
+
cfg.TEST.AUG.CROP_SIZE = (640, 640)
|
74 |
+
cfg.TEST.AUG.STRIDE = (426, 426)
|
75 |
+
cfg.TEST.AUG.SCALE = (2048, 640)
|
76 |
+
cfg.TEST.AUG.SETR_MULTI_SCALE = True
|
77 |
+
cfg.TEST.AUG.KEEP_RATIO = True
|
78 |
+
cfg.TEST.AUG.SIZE_DIVISOR = 32
|
79 |
+
|
80 |
+
# pixel decoder config
|
81 |
+
cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
|
82 |
+
# adding transformer in pixel decoder
|
83 |
+
cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
|
84 |
+
# pixel decoder
|
85 |
+
cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
|
86 |
+
cfg.MODEL.SEM_SEG_HEAD.SEM_EMBED_DIM = 256
|
87 |
+
cfg.MODEL.SEM_SEG_HEAD.INST_EMBED_DIM = 256
|
88 |
+
|
89 |
+
# LSJ aug
|
90 |
+
cfg.INPUT.IMAGE_SIZE = 1024
|
91 |
+
cfg.INPUT.MIN_SCALE = 0.1
|
92 |
+
cfg.INPUT.MAX_SCALE = 2.0
|
93 |
+
|
94 |
+
# MSDeformAttn encoder configs
|
95 |
+
cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"]
|
96 |
+
cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4
|
97 |
+
cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8
|
98 |
+
|
99 |
+
def add_oneformer_config(cfg):
|
100 |
+
"""
|
101 |
+
Add config for ONE_FORMER.
|
102 |
+
"""
|
103 |
+
|
104 |
+
# mask_former model config
|
105 |
+
cfg.MODEL.ONE_FORMER = CN()
|
106 |
+
|
107 |
+
# loss
|
108 |
+
cfg.MODEL.ONE_FORMER.DEEP_SUPERVISION = True
|
109 |
+
cfg.MODEL.ONE_FORMER.NO_OBJECT_WEIGHT = 0.1
|
110 |
+
cfg.MODEL.ONE_FORMER.CLASS_WEIGHT = 1.0
|
111 |
+
cfg.MODEL.ONE_FORMER.DICE_WEIGHT = 1.0
|
112 |
+
cfg.MODEL.ONE_FORMER.MASK_WEIGHT = 20.0
|
113 |
+
cfg.MODEL.ONE_FORMER.CONTRASTIVE_WEIGHT = 0.5
|
114 |
+
cfg.MODEL.ONE_FORMER.CONTRASTIVE_TEMPERATURE = 0.07
|
115 |
+
|
116 |
+
# transformer config
|
117 |
+
cfg.MODEL.ONE_FORMER.NHEADS = 8
|
118 |
+
cfg.MODEL.ONE_FORMER.DROPOUT = 0.1
|
119 |
+
cfg.MODEL.ONE_FORMER.DIM_FEEDFORWARD = 2048
|
120 |
+
cfg.MODEL.ONE_FORMER.ENC_LAYERS = 0
|
121 |
+
cfg.MODEL.ONE_FORMER.CLASS_DEC_LAYERS = 2
|
122 |
+
cfg.MODEL.ONE_FORMER.DEC_LAYERS = 6
|
123 |
+
cfg.MODEL.ONE_FORMER.PRE_NORM = False
|
124 |
+
|
125 |
+
cfg.MODEL.ONE_FORMER.HIDDEN_DIM = 256
|
126 |
+
cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES = 120
|
127 |
+
cfg.MODEL.ONE_FORMER.NUM_OBJECT_CTX = 16
|
128 |
+
cfg.MODEL.ONE_FORMER.USE_TASK_NORM = True
|
129 |
+
|
130 |
+
cfg.MODEL.ONE_FORMER.TRANSFORMER_IN_FEATURE = "res5"
|
131 |
+
cfg.MODEL.ONE_FORMER.ENFORCE_INPUT_PROJ = False
|
132 |
+
|
133 |
+
# Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
|
134 |
+
# you can use this config to override
|
135 |
+
cfg.MODEL.ONE_FORMER.SIZE_DIVISIBILITY = 32
|
136 |
+
|
137 |
+
# transformer module
|
138 |
+
cfg.MODEL.ONE_FORMER.TRANSFORMER_DECODER_NAME = "ContrastiveMultiScaleMaskedTransformerDecoder"
|
139 |
+
|
140 |
+
# point loss configs
|
141 |
+
# Number of points sampled during training for a mask point head.
|
142 |
+
cfg.MODEL.ONE_FORMER.TRAIN_NUM_POINTS = 112 * 112
|
143 |
+
# Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
|
144 |
+
# original paper.
|
145 |
+
cfg.MODEL.ONE_FORMER.OVERSAMPLE_RATIO = 3.0
|
146 |
+
# Importance sampling parameter for PointRend point sampling during training. Parameter `beta` in
|
147 |
+
# the original paper.
|
148 |
+
cfg.MODEL.ONE_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
|
149 |
+
|
150 |
+
def add_swin_config(cfg):
|
151 |
+
"""
|
152 |
+
Add config for SWIN Backbone.
|
153 |
+
"""
|
154 |
+
|
155 |
+
# swin transformer backbone
|
156 |
+
cfg.MODEL.SWIN = CN()
|
157 |
+
cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
|
158 |
+
cfg.MODEL.SWIN.PATCH_SIZE = 4
|
159 |
+
cfg.MODEL.SWIN.EMBED_DIM = 96
|
160 |
+
cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
|
161 |
+
cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
|
162 |
+
cfg.MODEL.SWIN.WINDOW_SIZE = 7
|
163 |
+
cfg.MODEL.SWIN.MLP_RATIO = 4.0
|
164 |
+
cfg.MODEL.SWIN.QKV_BIAS = True
|
165 |
+
cfg.MODEL.SWIN.QK_SCALE = None
|
166 |
+
cfg.MODEL.SWIN.DROP_RATE = 0.0
|
167 |
+
cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
|
168 |
+
cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
|
169 |
+
cfg.MODEL.SWIN.APE = False
|
170 |
+
cfg.MODEL.SWIN.PATCH_NORM = True
|
171 |
+
cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
|
172 |
+
cfg.MODEL.SWIN.USE_CHECKPOINT = False
|
173 |
+
## Semask additions
|
174 |
+
cfg.MODEL.SWIN.SEM_WINDOW_SIZE = 7
|
175 |
+
cfg.MODEL.SWIN.NUM_SEM_BLOCKS = 1
|
176 |
+
|
177 |
+
def add_dinat_config(cfg):
|
178 |
+
"""
|
179 |
+
Add config for DiNAT Backbone.
|
180 |
+
"""
|
181 |
+
|
182 |
+
# DINAT transformer backbone
|
183 |
+
cfg.MODEL.DiNAT = CN()
|
184 |
+
cfg.MODEL.DiNAT.DEPTHS = [3, 4, 18, 5]
|
185 |
+
cfg.MODEL.DiNAT.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
|
186 |
+
cfg.MODEL.DiNAT.EMBED_DIM = 64
|
187 |
+
cfg.MODEL.DiNAT.MLP_RATIO = 3.0
|
188 |
+
cfg.MODEL.DiNAT.NUM_HEADS = [2, 4, 8, 16]
|
189 |
+
cfg.MODEL.DiNAT.DROP_PATH_RATE = 0.2
|
190 |
+
cfg.MODEL.DiNAT.KERNEL_SIZE = 7
|
191 |
+
cfg.MODEL.DiNAT.DILATIONS = [[1, 16, 1], [1, 4, 1, 8], [1, 2, 1, 3, 1, 4], [1, 2, 1, 2, 1]]
|
192 |
+
cfg.MODEL.DiNAT.OUT_INDICES = (0, 1, 2, 3)
|
193 |
+
cfg.MODEL.DiNAT.QKV_BIAS = True
|
194 |
+
cfg.MODEL.DiNAT.QK_SCALE = None
|
195 |
+
cfg.MODEL.DiNAT.DROP_RATE = 0
|
196 |
+
cfg.MODEL.DiNAT.ATTN_DROP_RATE = 0.
|
197 |
+
cfg.MODEL.DiNAT.IN_PATCH_SIZE = 4
|
198 |
+
|
199 |
+
def add_convnext_config(cfg):
|
200 |
+
"""
|
201 |
+
Add config for ConvNeXt Backbone.
|
202 |
+
"""
|
203 |
+
|
204 |
+
# ConvNeXt backbone
|
205 |
+
cfg.MODEL.CONVNEXT = CN()
|
206 |
+
cfg.MODEL.CONVNEXT.IN_CHANNELS = 3
|
207 |
+
cfg.MODEL.CONVNEXT.DEPTHS = [3, 3, 27, 3]
|
208 |
+
cfg.MODEL.CONVNEXT.DIMS = [192, 384, 768, 1536]
|
209 |
+
cfg.MODEL.CONVNEXT.DROP_PATH_RATE = 0.4
|
210 |
+
cfg.MODEL.CONVNEXT.LSIT = 1.0
|
211 |
+
cfg.MODEL.CONVNEXT.OUT_INDICES = [0, 1, 2, 3]
|
212 |
+
cfg.MODEL.CONVNEXT.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
|
213 |
+
|
214 |
+
def add_beit_adapter_config(cfg):
|
215 |
+
"""
|
216 |
+
Add config for BEiT Adapter Backbone.
|
217 |
+
"""
|
218 |
+
|
219 |
+
# beit adapter backbone
|
220 |
+
cfg.MODEL.BEiTAdapter = CN()
|
221 |
+
cfg.MODEL.BEiTAdapter.IMG_SIZE = 640
|
222 |
+
cfg.MODEL.BEiTAdapter.PATCH_SIZE = 16
|
223 |
+
cfg.MODEL.BEiTAdapter.EMBED_DIM = 1024
|
224 |
+
cfg.MODEL.BEiTAdapter.DEPTH = 24
|
225 |
+
cfg.MODEL.BEiTAdapter.NUM_HEADS = 16
|
226 |
+
cfg.MODEL.BEiTAdapter.MLP_RATIO = 4
|
227 |
+
cfg.MODEL.BEiTAdapter.QKV_BIAS = True
|
228 |
+
cfg.MODEL.BEiTAdapter.USE_ABS_POS_EMB = False
|
229 |
+
cfg.MODEL.BEiTAdapter.USE_REL_POS_BIAS = True
|
230 |
+
cfg.MODEL.BEiTAdapter.INIT_VALUES = 1e-6
|
231 |
+
cfg.MODEL.BEiTAdapter.DROP_PATH_RATE = 0.3
|
232 |
+
cfg.MODEL.BEiTAdapter.CONV_INPLANE = 64
|
233 |
+
cfg.MODEL.BEiTAdapter.N_POINTS = 4
|
234 |
+
cfg.MODEL.BEiTAdapter.DEFORM_NUM_HEADS = 16
|
235 |
+
cfg.MODEL.BEiTAdapter.CFFN_RATIO = 0.25
|
236 |
+
cfg.MODEL.BEiTAdapter.DEFORM_RATIO = 0.5
|
237 |
+
cfg.MODEL.BEiTAdapter.WITH_CP = True
|
238 |
+
cfg.MODEL.BEiTAdapter.INTERACTION_INDEXES=[[0, 5], [6, 11], [12, 17], [18, 23]]
|
239 |
+
cfg.MODEL.BEiTAdapter.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
|