File size: 42,115 Bytes
fcd8cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import math
from typing import List
import torch
import torch.nn.functional as F
from fvcore.nn import sigmoid_focal_loss_star_jit, smooth_l1_loss
from torch import nn

from detectron2.layers import ShapeSpec, batched_nms, cat, paste_masks_in_image
from detectron2.modeling.anchor_generator import DefaultAnchorGenerator
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.retinanet import permute_to_N_HWA_K
from detectron2.structures import Boxes, ImageList, Instances

from tensormask.layers import SwapAlign2Nat

__all__ = ["TensorMask"]


def permute_all_cls_and_box_to_N_HWA_K_and_concat(pred_logits, pred_anchor_deltas, num_classes=80):
    """
    Flatten the per-level head outputs into two 2-D tensors with one row per anchor.

    Each level output, i.e. a Tensor of shape (N, A x K, Hi, Wi), is passed through
    `permute_to_N_HWA_K` (K=`num_classes` for the logits, K=4 for the deltas).
    The per-level results are then joined along dim 1, which matches the way the
    labels are generated (all feature maps concatenated), and finally reshaped so
    that K is the only trailing dimension.

    Args:
        pred_logits (list[Tensor]): #lvl classification outputs.
        pred_anchor_deltas (list[Tensor]): #lvl box regression outputs.
        num_classes (int): K used for the classification outputs.

    Returns:
        (Tensor, Tensor): logits of shape (N x sum(Hi x Wi x A), num_classes) and
        anchor deltas of shape (N x sum(Hi x Wi x A), 4).
    """

    def _flatten_levels(per_level, last_dim):
        # Bring every level to the label layout, then merge the levels on dim 1.
        flat = [permute_to_N_HWA_K(level_out, last_dim) for level_out in per_level]
        return cat(flat, dim=1).view(-1, last_dim)

    return _flatten_levels(pred_logits, num_classes), _flatten_levels(pred_anchor_deltas, 4)


def _assignment_rule(
    gt_boxes,
    anchor_boxes,
    unit_lengths,
    min_anchor_size,
    scale_thresh=2.0,
    spatial_thresh=1.0,
    uniqueness_on=True,
):
    """
    Given two lists of boxes of N ground truth boxes and M anchor boxes,
    compute the assignment between the two, following the assignment rules in
    https://arxiv.org/abs/1903.12174.
    The box order must be (xmin, ymin, xmax, ymax), so please make sure to convert
    to BoxMode.XYXY_ABS before calling this function.

    Args:
        gt_boxes, anchor_boxes (Boxes): two Boxes. Contains N & M boxes/anchors, respectively.
        unit_lengths (Tensor): Contains the unit lengths of M anchor boxes.
        min_anchor_size (float): Minimum size of the anchor, in pixels
        scale_thresh (float): The `scale` threshold: the maximum size of the anchor
                              should not be greater than scale_thresh x max(h, w) of
                              the ground truth box.
        spatial_thresh (float): The `spatial` threshold: the l2 distance between the
                              center of the anchor and the ground truth box should not
                              be greater than spatial_thresh x u where u is the unit length.

    Returns:
        matches (Tensor[int64]): a vector of length M, where matches[i] is a matched
                ground-truth index in [0, N)
        match_labels (Tensor[int8]): a vector of length M, where pred_labels[i] indicates
            whether a prediction is a true or false positive or ignored
    """
    gt_boxes, anchor_boxes = gt_boxes.tensor, anchor_boxes.tensor
    N = gt_boxes.shape[0]
    M = anchor_boxes.shape[0]
    if N == 0 or M == 0:
        return (
            gt_boxes.new_full((N,), 0, dtype=torch.int64),
            gt_boxes.new_full((N,), -1, dtype=torch.int8),
        )

    # Containment rule
    lt = torch.min(gt_boxes[:, None, :2], anchor_boxes[:, :2])  # [N,M,2]
    rb = torch.max(gt_boxes[:, None, 2:], anchor_boxes[:, 2:])  # [N,M,2]
    union = cat([lt, rb], dim=2)  # [N,M,4]

    dummy_gt_boxes = torch.zeros_like(gt_boxes)
    anchor = dummy_gt_boxes[:, None, :] + anchor_boxes[:, :]  # [N,M,4]

    contain_matrix = torch.all(union == anchor, dim=2)  # [N,M]

    # Centrality rule, scale
    gt_size_lower = torch.max(gt_boxes[:, 2:] - gt_boxes[:, :2], dim=1)[0]  # [N]
    gt_size_upper = gt_size_lower * scale_thresh  # [N]
    # Fall back for small objects
    gt_size_upper[gt_size_upper < min_anchor_size] = min_anchor_size
    # Due to sampling of locations, the anchor sizes are deducted with sampling strides
    anchor_size = (
        torch.max(anchor_boxes[:, 2:] - anchor_boxes[:, :2], dim=1)[0] - unit_lengths
    )  # [M]

    size_diff_upper = gt_size_upper[:, None] - anchor_size  # [N,M]
    scale_matrix = size_diff_upper >= 0  # [N,M]

    # Centrality rule, spatial
    gt_center = (gt_boxes[:, 2:] + gt_boxes[:, :2]) / 2  # [N,2]
    anchor_center = (anchor_boxes[:, 2:] + anchor_boxes[:, :2]) / 2  # [M,2]
    offset_center = gt_center[:, None, :] - anchor_center[:, :]  # [N,M,2]
    offset_center /= unit_lengths[:, None]  # [N,M,2]
    spatial_square = spatial_thresh * spatial_thresh
    spatial_matrix = torch.sum(offset_center * offset_center, dim=2) <= spatial_square

    assign_matrix = (contain_matrix & scale_matrix & spatial_matrix).int()

    # assign_matrix is N (gt) x M (predicted)
    # Max over gt elements (dim 0) to find best gt candidate for each prediction
    matched_vals, matches = assign_matrix.max(dim=0)
    match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)

    match_labels[matched_vals == 0] = 0
    match_labels[matched_vals == 1] = 1

    # find all the elements that match to ground truths multiple times
    not_unique_idxs = assign_matrix.sum(dim=0) > 1
    if uniqueness_on:
        match_labels[not_unique_idxs] = 0
    else:
        match_labels[not_unique_idxs] = -1

    return matches, match_labels


# TODO make the paste_mask function in d2 core support mask list
def _paste_mask_lists_in_image(masks, boxes, image_shape, threshold=0.5):
    """
    Paste a list of masks that are of various resolutions (e.g., 28 x 28) into an image.
    The location, height, and width for pasting each mask is determined by their
    corresponding bounding boxes in boxes.

    Args:
        masks (list(Tensor)): A list of Tensor of shape (1, Hmask_i, Wmask_i).
                            Values are in [0, 1]. The list length, Bimg, is the
                            number of detected object instances in the image.
        boxes (Boxes): A Boxes of length Bimg. boxes.tensor[i] and masks[i] correspond
                            to the same object instance.
        image_shape (tuple): height, width
        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
            binary masks.

    Returns:
        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
        number of detected object instances and Himage, Wimage are the image width
        and height. img_masks[i] is a binary mask for object instance i.
    """
    if len(masks) == 0:
        return torch.empty((0, 1) + image_shape, dtype=torch.uint8)

    # Loop over masks groups. Each group has the same mask prediction size.
    img_masks = []
    ind_masks = []
    mask_sizes = torch.tensor([m.shape[-1] for m in masks])
    unique_sizes = torch.unique(mask_sizes)
    for msize in unique_sizes.tolist():
        cur_ind = torch.where(mask_sizes == msize)[0]
        ind_masks.append(cur_ind)

        cur_masks = cat([masks[i] for i in cur_ind])
        cur_boxes = boxes[cur_ind]
        img_masks.append(paste_masks_in_image(cur_masks, cur_boxes, image_shape, threshold))

    img_masks = cat(img_masks)
    ind_masks = cat(ind_masks)

    img_masks_out = torch.empty_like(img_masks)
    img_masks_out[ind_masks, :, :] = img_masks

    return img_masks_out


def _postprocess(results, result_mask_info, output_height, output_width, mask_threshold=0.5):
    """
    Rescale raw TensorMask detections to the requested output resolution.

    Detectors usually see a resized version of the image, so the boxes (and the
    anchors used to paste the masks) have to be mapped from the input resolution
    stored in `results.image_size` to (`output_height`, `output_width`).
    Boxes are scaled, clipped to the output size, and empty ones are dropped
    together with their masks.

    Args:
        results (Instances): the raw outputs from the detector.
            `results.image_size` contains the input image resolution the detector sees.
            The tensor of `results.pred_boxes` is scaled in place, so this object
            might be modified. It does not contain the field `pred_masks`; masks
            are provided through `result_mask_info`.
        result_mask_info (list[Tensor], Boxes): a pair of two items for mask related results.
            The first item is a list of #detection tensors, each is the predicted masks.
            The second item is the anchors corresponding to the predicted masks; its
            tensor is scaled in place as well.
        output_height, output_width: the desired output resolution.
        mask_threshold (float): threshold forwarded to `_paste_mask_lists_in_image`.

    Returns:
        Instances: the postprocessed output from the model, based on the output resolution
    """
    scale_x = output_width / results.image_size[1]
    scale_y = output_height / results.image_size[0]
    rescaled = Instances((output_height, output_width), **results.get_fields())

    # x coordinates live in the even columns, y coordinates in the odd ones.
    boxes = rescaled.pred_boxes
    boxes.tensor[:, 0::2] *= scale_x
    boxes.tensor[:, 1::2] *= scale_y
    boxes.clip(rescaled.image_size)

    keep = boxes.nonempty()
    rescaled = rescaled[keep]

    mask_list, mask_anchors = result_mask_info
    if not mask_list:
        # No masks were predicted: leave `pred_masks` unset.
        return rescaled

    mask_anchors.tensor[:, 0::2] *= scale_x
    mask_anchors.tensor[:, 1::2] *= scale_y
    # Drop the masks whose boxes became empty, keeping masks and anchors aligned.
    kept_masks = [mask for flag, mask in zip(keep.tolist(), mask_list) if flag]
    rescaled.pred_masks = _paste_mask_lists_in_image(
        kept_masks,
        mask_anchors[keep],
        rescaled.image_size,
        threshold=mask_threshold,
    )
    return rescaled


class TensorMaskAnchorGenerator(DefaultAnchorGenerator):
    """
    Anchor generator for TensorMask.
    Besides the anchors themselves, it reports for every anchor its unit length
    (the stride of its feature level) and a 5-D index locating it in the outputs.
    """

    def grid_anchors_with_unit_lengths_and_indexes(self, grid_sizes):
        """
        Build anchors, unit lengths and indexes for every feature level.

        Args:
            grid_sizes (list): one (height, width) pair per feature level.

        Returns:
            list[Tensor]: per level, anchors of shape (HWA, 4).
            list[Tensor]: per level, a float vector of length HWA filled with the stride.
            list[Tensor]: per level, an int64 tensor of shape (HWA, 5) whose columns
                are (level, image, h, w, anchor); the image column is all zeros here.
        """
        all_anchors, all_unit_lengths, all_indexes = [], [], []
        per_level = zip(grid_sizes, self.strides, self.cell_anchors)
        for level, ((height, width), stride, cell_anchors) in enumerate(per_level):
            dev = cell_anchors.device

            # Top-left offsets of every cell of the grid, in input pixels.
            xs = torch.arange(0, width * stride, step=stride, dtype=torch.float32, device=dev)
            ys = torch.arange(0, height * stride, step=stride, dtype=torch.float32, device=dev)
            grid_y, grid_x = torch.meshgrid(ys, xs)
            offsets = torch.stack((grid_x, grid_y, grid_x, grid_y), dim=2)

            # Shift the cell anchors to every location; result is (HWA, 4).
            level_anchors = offsets[:, :, None, :] + cell_anchors.view(1, 1, -1, 4)
            level_anchors = level_anchors.view(-1, 4)
            all_anchors.append(level_anchors)

            num_level_anchors = level_anchors.shape[0]
            all_unit_lengths.append(
                torch.full((num_level_anchors,), stride, dtype=torch.float32, device=dev)
            )

            # One axis per index component; the mesh grid enumerates them in the
            # same (H, W, A) order used for the anchors above.
            axes = (
                torch.full((1,), level, dtype=torch.int64, device=dev),
                torch.zeros((1,), dtype=torch.int64, device=dev),
                torch.arange(0, height, dtype=torch.int64, device=dev),
                torch.arange(0, width, dtype=torch.int64, device=dev),
                torch.arange(0, cell_anchors.shape[0], dtype=torch.int64, device=dev),
            )
            index_grids = torch.meshgrid(*axes)
            all_indexes.append(torch.stack(index_grids, dim=5).view(-1, 5))

        return all_anchors, all_unit_lengths, all_indexes

    def forward(self, features):
        """
        Args:
            features (list[Tensor]): one feature map per level; the last two
                dimensions of each give the grid size and the length of the first
                map gives the number of images.

        Returns:
            list[list[Boxes]]: a list of #image elements. Each is a list of #feature level Boxes.
                The Boxes contains anchors of this image on the specific feature level.
            list[list[Tensor]]: a list of #image elements. Each is a list of #feature level tensors.
                The tensor contains strides, or unit lengths for the anchors.
            list[list[Tensor]]: a list of #image elements. Each is a list of #feature level tensors.
                The Tensor contains indexes for the anchors, with the last dimension meaning
                (L, I, H, W, A), where L is level, I is image (not set yet), H is height,
                W is width, and A is anchor.
        """
        batch_size = len(features[0])
        grid_sizes = [fmap.shape[-2:] for fmap in features]
        (
            level_anchors,
            level_unit_lengths,
            level_indexes,
        ) = self.grid_anchors_with_unit_lengths_and_indexes(grid_sizes)

        # Convert anchors from Tensor to Boxes
        level_boxes = [Boxes(anchor_tensor) for anchor_tensor in level_anchors]

        # TODO it can be simplified to not return duplicated information for
        # each image, just like detectron2's own AnchorGenerator
        anchors, unit_lengths, indexes = [], [], []
        for _ in range(batch_size):
            # Every image receives its own independent copy of the per-level data.
            anchors.append(copy.deepcopy(level_boxes))
        for _ in range(batch_size):
            unit_lengths.append(copy.deepcopy(level_unit_lengths))
        for _ in range(batch_size):
            indexes.append(copy.deepcopy(level_indexes))

        return anchors, unit_lengths, indexes


@META_ARCH_REGISTRY.register()
class TensorMask(nn.Module):
    """
    TensorMask model. Creates FPN backbone, anchors and a head for classification
    and box regression. Calculates and applies proper losses to class, box, and
    masks.
    """

    def __init__(self, cfg):
        """
        Build the backbone, anchor generator, head and box transform from `cfg`.

        Args:
            cfg: config node. The `MODEL.TENSOR_MASK.*`, `MODEL.ANCHOR_GENERATOR.SIZES`,
                `MODEL.MASK_ON`, `MODEL.PIXEL_MEAN`, `MODEL.PIXEL_STD` and
                `TEST.DETECTIONS_PER_IMAGE` entries are read here; `cfg` is also
                forwarded to `build_backbone`, `TensorMaskAnchorGenerator` and
                `TensorMaskHead`.
        """
        super().__init__()

        # fmt: off
        self.num_classes              = cfg.MODEL.TENSOR_MASK.NUM_CLASSES
        self.in_features              = cfg.MODEL.TENSOR_MASK.IN_FEATURES
        self.anchor_sizes             = cfg.MODEL.ANCHOR_GENERATOR.SIZES
        # One entry of ANCHOR_GENERATOR.SIZES per feature level.
        self.num_levels               = len(cfg.MODEL.ANCHOR_GENERATOR.SIZES)
        # Loss parameters:
        self.focal_loss_alpha         = cfg.MODEL.TENSOR_MASK.FOCAL_LOSS_ALPHA
        self.focal_loss_gamma         = cfg.MODEL.TENSOR_MASK.FOCAL_LOSS_GAMMA
        # Inference parameters:
        self.score_threshold          = cfg.MODEL.TENSOR_MASK.SCORE_THRESH_TEST
        self.topk_candidates          = cfg.MODEL.TENSOR_MASK.TOPK_CANDIDATES_TEST
        self.nms_threshold            = cfg.MODEL.TENSOR_MASK.NMS_THRESH_TEST
        self.detections_im            = cfg.TEST.DETECTIONS_PER_IMAGE
        # Mask parameters:
        self.mask_on                  = cfg.MODEL.MASK_ON
        self.mask_loss_weight         = cfg.MODEL.TENSOR_MASK.MASK_LOSS_WEIGHT
        # Stored as a plain tensor attribute (not a registered buffer); it is
        # later passed as `pos_weight` to the mask loss.
        self.mask_pos_weight          = torch.tensor(cfg.MODEL.TENSOR_MASK.POSITIVE_WEIGHT,
                                                     dtype=torch.float32)
        self.bipyramid_on             = cfg.MODEL.TENSOR_MASK.BIPYRAMID_ON
        # fmt: on

        # build the backbone
        self.backbone = build_backbone(cfg)

        # Shapes and strides of the backbone outputs selected by `in_features`,
        # kept in the same order as `in_features`.
        backbone_shape = self.backbone.output_shape()
        feature_shapes = [backbone_shape[f] for f in self.in_features]
        feature_strides = [x.stride for x in feature_shapes]
        # build anchors
        self.anchor_generator = TensorMaskAnchorGenerator(cfg, feature_shapes)
        # Anchor count is taken from the first level only.
        # NOTE(review): presumably every level has the same count - confirm.
        self.num_anchors = self.anchor_generator.num_cell_anchors[0]
        # Anchor sizes of the first (index 0) feature level.
        anchors_min_level = cfg.MODEL.ANCHOR_GENERATOR.SIZES[0]
        # Mask side length per anchor: the anchor size expressed in cells of the
        # first feature level.
        self.mask_sizes = [size // feature_strides[0] for size in anchors_min_level]
        # Smallest first-level anchor reduced by one stride; used as the lower
        # bound (`min_anchor_size`) in `_assignment_rule`.
        self.min_anchor_size = min(anchors_min_level) - feature_strides[0]

        # head of the TensorMask
        self.head = TensorMaskHead(
            cfg, self.num_levels, self.num_anchors, self.mask_sizes, feature_shapes
        )
        # box transform
        self.box2box_transform = Box2BoxTransform(weights=cfg.MODEL.TENSOR_MASK.BBOX_REG_WEIGHTS)
        # Pixel statistics reshaped to (C, 1, 1). The third positional argument of
        # `register_buffer` is `persistent`; False leaves them out of the state dict.
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1), False)

    @property
    def device(self):
        """Device of the registered `pixel_mean` buffer."""
        mean_buffer = self.pixel_mean
        return mean_buffer.device

    def forward(self, batched_inputs):
        """
        Run the model on a batch and return losses (training) or detections (inference).

        Args:
            batched_inputs: a list, batched outputs of :class:`DetectionTransform` .
                Each item in the list contains the inputs for one image.
            For now, each item in the list is a dict that contains:
                image: Tensor, image in (C, H, W) format.
                instances: Instances
                Other information that's included in the original dicts, such as:
                    "height", "width" (int): the output resolution of the model, used in inference.
                        When missing, the size of the preprocessed image is used instead.

        Returns:
            In training mode, losses (dict[str: Tensor]): mapping from a named loss
                to a tensor storing the loss.
            Otherwise, a list with one dict per image, each holding the
                postprocessed detections under the key "instances".
        """
        images = self.preprocess_image(batched_inputs)

        # Ground truth is only present when the first input carries "instances".
        gt_instances = None
        if "instances" in batched_inputs[0]:
            gt_instances = [item["instances"].to(self.device) for item in batched_inputs]

        backbone_out = self.backbone(images.tensor)
        features = [backbone_out[name] for name in self.in_features]
        # apply the TensorMask head
        pred_logits, pred_deltas, pred_masks = self.head(features)
        # generate anchors based on features, is it image specific?
        anchors, unit_lengths, indexes = self.anchor_generator(features)

        if self.training:
            # Label every anchor, then compute the losses against those labels.
            targets = self.get_ground_truth(anchors, unit_lengths, indexes, gt_instances)
            gt_class_info, gt_delta_info, gt_mask_info, num_fg = targets
            return self.losses(
                gt_class_info,
                gt_delta_info,
                gt_mask_info,
                num_fg,
                pred_logits,
                pred_deltas,
                pred_masks,
            )

        # Inference: decode the predictions, then map each image's result to the
        # requested output resolution.
        detections = self.inference(
            pred_logits, pred_deltas, pred_masks, anchors, indexes, images
        )
        outputs = []
        per_image = zip(detections, batched_inputs, images.image_sizes)
        for (det_boxes, det_masks), input_im, image_size in per_image:
            out_height = input_im.get("height", image_size[0])
            out_width = input_im.get("width", image_size[1])
            instances = _postprocess(det_boxes, det_masks, out_height, out_width)
            outputs.append({"instances": instances})
        return outputs

    def losses(
        self,
        gt_class_info,
        gt_delta_info,
        gt_mask_info,
        num_fg,
        pred_logits,
        pred_deltas,
        pred_masks,
    ):
        """
        Compute the classification, box regression and (optionally) mask losses.

        Args:
            For `gt_class_info`, `gt_delta_info`, `gt_mask_info` and `num_fg` parameters, see
                :meth:`TensorMask.get_ground_truth`.
            For `pred_logits`, `pred_deltas` and `pred_masks`, see
                :meth:`TensorMaskHead.forward`.

        Returns:
            losses (dict[str: Tensor]): mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The potential dict keys are:
                "loss_cls", "loss_box_reg" and "loss_mask" (the last one only when
                masks are enabled). Every loss is divided by max(1, num_fg).
        """
        cls_targets, valid_inds = gt_class_info
        box_targets, fg_inds = gt_delta_info
        mask_targets, mask_inds = gt_mask_info
        normalizer = torch.tensor(max(1, num_fg), dtype=torch.float32, device=self.device)

        # classification and regression
        flat_logits, flat_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_logits, pred_deltas, self.num_classes
        )
        cls_sum = sigmoid_focal_loss_star_jit(
            flat_logits[valid_inds],
            cls_targets[valid_inds],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )
        losses = {"loss_cls": cls_sum / normalizer}

        if num_fg == 0:
            # No foreground anchor: a zero that still depends on the predictions.
            losses["loss_box_reg"] = flat_deltas.sum() * 0
        else:
            box_sum = smooth_l1_loss(
                flat_deltas[fg_inds], box_targets, beta=0.0, reduction="sum"
            )
            losses["loss_box_reg"] = box_sum / normalizer

        if not self.mask_on:
            return losses

        # mask prediction
        mask_total = 0
        for lvl in range(self.num_levels):
            level_factor = 2**lvl if self.bipyramid_on else 1
            for anc in range(self.num_anchors):
                locations = mask_inds[lvl][anc]
                level_pred = pred_masks[lvl][anc]
                if locations is None:
                    # Nothing to supervise here; add a zero tied to the prediction.
                    mask_total += level_pred[0, 0, 0, 0] * 0
                    continue

                side = self.mask_sizes[anc] * level_factor
                # TODO maybe there are numerical issues when mask sizes are large
                # Per-element weight: the mask loss weight spread over the mask area.
                area_weight = torch.tensor(
                    self.mask_loss_weight / (side**2),
                    dtype=torch.float32,
                    device=self.device,
                )

                # Pick the predictions at the (N, H, W) locations of the targets,
                # keeping the whole V x U channel dimension.
                picked = level_pred[
                    locations[:, 0],  # N
                    :,  # V x U
                    locations[:, 1],  # H
                    locations[:, 2],  # W
                ]

                mask_total += F.binary_cross_entropy_with_logits(
                    picked.view(-1, side, side),  # V, U
                    mask_targets[lvl][anc].to(dtype=torch.float32),
                    reduction="sum",
                    weight=area_weight,
                    pos_weight=self.mask_pos_weight,
                )
        losses["loss_mask"] = mask_total / normalizer
        return losses

    @torch.no_grad()
    def get_ground_truth(self, anchors, unit_lengths, indexes, targets):
        """
        Args:
            anchors (list[list[Boxes]]): a list of N=#image elements. Each is a
                list of #feature level Boxes. The Boxes contains anchors of
                this image on the specific feature level.
            unit_lengths (list[list[Tensor]]): a list of N=#image elements. Each is a
                list of #feature level Tensor. The tensor contains unit lengths for anchors of
                this image on the specific feature level.
            indexes (list[list[Tensor]]): a list of N=#image elements. Each is a
                list of #feature level Tensor. The tensor contains the 5D index of
                each anchor, the second dimension means (L, N, H, W, A), where L
                is level, I is image, H is height, W is width, and A is anchor.
            targets (list[Instances]): a list of N `Instances`s. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image.  Specify `targets` during training only.

        Returns:
            gt_class_info (Tensor, Tensor): A pair of two tensors for classification.
                The first one is an integer tensor of shape (R, #classes) storing ground-truth
                labels for each anchor. R is the total number of anchors in the batch.
                The second one is an integer tensor of shape (R,), to indicate which
                anchors are valid for loss computation, which anchors are not.
            gt_delta_info (Tensor, Tensor): A pair of two tensors for boxes.
                The first one, of shape (F, 4). F=#foreground anchors.
                The last dimension represents ground-truth box2box transform
                targets (dx, dy, dw, dh) that map each anchor to its matched ground-truth box.
                Only foreground anchors have values in this tensor. Could be `None` if F=0.
                The second one, of shape (R,), is an integer tensor indicating which anchors
                are foreground ones used for box regression. Could be `None` if F=0.
            gt_mask_info (list[list[Tensor]], list[list[Tensor]]): A pair of two lists for masks.
                The first one is a list of P=#feature level elements. Each is a
                list of A=#anchor tensors. Each tensor contains the ground truth
                masks of the same size and for the same feature level. Could be `None`.
                The second one is a list of P=#feature level elements. Each is a
                list of A=#anchor tensors. Each tensor contains the location of the ground truth
                masks of the same size and for the same feature level. The second dimension means
                (N, H, W), where N is image, H is height, and W is width. Could be `None`.
            num_fg (int): F=#foreground anchors, used later for loss normalization.
        """
        gt_classes = []
        gt_deltas = []
        gt_masks = [[[] for _ in range(self.num_anchors)] for _ in range(self.num_levels)]
        gt_mask_inds = [[[] for _ in range(self.num_anchors)] for _ in range(self.num_levels)]

        anchors = [Boxes.cat(anchors_i) for anchors_i in anchors]
        unit_lengths = [cat(unit_lengths_i) for unit_lengths_i in unit_lengths]
        indexes = [cat(indexes_i) for indexes_i in indexes]

        num_fg = 0
        for i, (anchors_im, unit_lengths_im, indexes_im, targets_im) in enumerate(
            zip(anchors, unit_lengths, indexes, targets)
        ):
            # Initialize all
            gt_classes_i = torch.full_like(
                unit_lengths_im, self.num_classes, dtype=torch.int64, device=self.device
            )
            # Ground truth classes
            has_gt = len(targets_im) > 0
            if has_gt:
                # Compute the pairwise matrix
                gt_matched_inds, anchor_labels = _assignment_rule(
                    targets_im.gt_boxes, anchors_im, unit_lengths_im, self.min_anchor_size
                )
                # Find the foreground instances
                fg_inds = anchor_labels == 1
                fg_anchors = anchors_im[fg_inds]
                num_fg += len(fg_anchors)
                # Find the ground truths for foreground instances
                gt_fg_matched_inds = gt_matched_inds[fg_inds]
                # Assign labels for foreground instances
                gt_classes_i[fg_inds] = targets_im.gt_classes[gt_fg_matched_inds]
                # Anchors with label -1 are ignored, others are left as negative
                gt_classes_i[anchor_labels == -1] = -1

                # Boxes
                # Ground truth box regression, only for foregrounds
                matched_gt_boxes = targets_im[gt_fg_matched_inds].gt_boxes
                # Compute box regression offsets for foregrounds only
                gt_deltas_i = self.box2box_transform.get_deltas(
                    fg_anchors.tensor, matched_gt_boxes.tensor
                )
                gt_deltas.append(gt_deltas_i)

                # Masks
                if self.mask_on:
                    # Compute masks for each level and each anchor
                    matched_indexes = indexes_im[fg_inds, :]
                    for lvl in range(self.num_levels):
                        ids_lvl = matched_indexes[:, 0] == lvl
                        if torch.any(ids_lvl):
                            cur_level_factor = 2**lvl if self.bipyramid_on else 1
                            for anc in range(self.num_anchors):
                                ids_lvl_anchor = ids_lvl & (matched_indexes[:, 4] == anc)
                                if torch.any(ids_lvl_anchor):
                                    gt_masks[lvl][anc].append(
                                        targets_im[
                                            gt_fg_matched_inds[ids_lvl_anchor]
                                        ].gt_masks.crop_and_resize(
                                            fg_anchors[ids_lvl_anchor].tensor,
                                            self.mask_sizes[anc] * cur_level_factor,
                                        )
                                    )
                                    # Select (N, H, W) dimensions
                                    gt_mask_inds_lvl_anc = matched_indexes[ids_lvl_anchor, 1:4]
                                    # Set the image index to the current image
                                    gt_mask_inds_lvl_anc[:, 0] = i
                                    gt_mask_inds[lvl][anc].append(gt_mask_inds_lvl_anc)
            gt_classes.append(gt_classes_i)

        # Classes and boxes
        gt_classes = cat(gt_classes)
        gt_valid_inds = gt_classes >= 0
        gt_fg_inds = gt_valid_inds & (gt_classes < self.num_classes)
        gt_classes_target = torch.zeros(
            (gt_classes.shape[0], self.num_classes), dtype=torch.float32, device=self.device
        )
        gt_classes_target[gt_fg_inds, gt_classes[gt_fg_inds]] = 1
        gt_deltas = cat(gt_deltas) if gt_deltas else None

        # Masks
        gt_masks = [[cat(mla) if mla else None for mla in ml] for ml in gt_masks]
        gt_mask_inds = [[cat(ila) if ila else None for ila in il] for il in gt_mask_inds]
        return (
            (gt_classes_target, gt_valid_inds),
            (gt_deltas, gt_fg_inds),
            (gt_masks, gt_mask_inds),
            num_fg,
        )

    def inference(self, pred_logits, pred_deltas, pred_masks, anchors, indexes, images):
        """
        Run inference image by image on the batched outputs of the head.

        Arguments:
            pred_logits, pred_deltas, pred_masks: Same as the output of
                :meth:`TensorMaskHead.forward`.
            anchors, indexes: Same as the input of :meth:`TensorMask.get_ground_truth`.
                They are iterated per image; the pieces of every image are
                concatenated here before being handed to the single-image routine.
            images (ImageList): the input images. Only `image_sizes` is read.

        Returns:
            list: one entry per image. Each entry is whatever
                :meth:`inference_single_image` returns, i.e. a tuple
                `(Instances, (result_masks, result_anchors))`.
        """
        assert len(anchors) == len(images)

        # Bring every level into the layout produced by `permute_to_N_HWA_K` and
        # join the levels along dim 1, so one image is a single slice along dim 0.
        all_logits = cat(
            [permute_to_N_HWA_K(level, self.num_classes) for level in pred_logits], dim=1
        )
        all_deltas = cat([permute_to_N_HWA_K(level, 4) for level in pred_deltas], dim=1)

        detections = []
        for img_idx, (img_anchors, img_indexes) in enumerate(zip(anchors, indexes)):
            if self.mask_on:
                # Keep the [level][anchor] nesting, but select the current image
                img_masks = [
                    [per_anchor[img_idx] for per_anchor in per_level] for per_level in pred_masks
                ]
            else:
                # Placeholder with one `None` per feature level
                img_masks = [None] * self.num_levels

            detections.append(
                self.inference_single_image(
                    all_logits[img_idx],
                    all_deltas[img_idx],
                    img_masks,
                    Boxes.cat(img_anchors),
                    cat(img_indexes),
                    tuple(images.image_sizes[img_idx]),
                )
            )
        return detections

    def inference_single_image(
        self, pred_logits, pred_deltas, pred_masks, anchors, indexes, image_size
    ):
        """
        Single-image inference. Scores are thresholded, the best candidates are kept,
        their boxes are decoded and filtered with `batched_nms` (grouped by the
        predicted class), and the masks of the survivors are gathered if enabled.

        Arguments:
            pred_logits (Tensor): classification logits of one image. The tensor is
                flattened and each flat position is decoded as
                `anchor_index * num_classes + class_index`.
                NOTE: the sigmoid is applied in place on the flattened tensor, so the
                caller's tensor is modified whenever `flatten()` returns a view.
            pred_deltas (Tensor): box regression deltas, indexed by anchor index
                along the first dimension.
            pred_masks (list[list[Tensor]]): mask predictions, read as
                `pred_masks[lvl][anc][:, h, w]`. Only used when `self.mask_on` is set.
            anchors (Boxes): all anchors of the image, indexed by anchor index.
            indexes (Tensor): one row per anchor, unpacked as `(lvl, _, h, w, anc)`.
                Only used when `self.mask_on` is set.
            image_size (tuple): passed as-is to `Instances`.

        Returns:
            tuple: `(results, (result_masks, result_anchors))`. `results` is an
                `Instances` with `pred_boxes`, `scores` and `pred_classes`.
                `result_masks` is a list with one sigmoid-activated tensor of shape
                (1, S, S) per detection and `result_anchors` holds the matching
                anchors; they are `[]` and `None` when `self.mask_on` is False.
        """
        scores = pred_logits.flatten().sigmoid_()

        # Candidates are selected jointly over all levels: first everything above
        # the score threshold, then the highest `topk_candidates` of those.
        above_thresh = torch.where(scores > self.score_threshold)[0]
        num_keep = min(self.topk_candidates, above_thresh.shape[0])
        sorted_scores, order = scores[above_thresh].sort(descending=True)
        pred_prob = sorted_scores[:num_keep]
        flat_idxs = above_thresh[order[:num_keep]]

        # Split the flat position into its class part and its anchor (HWA) part
        cls_idxs = flat_idxs % self.num_classes
        anchor_idxs = flat_idxs // self.num_classes

        # Decode boxes for the candidates and suppress duplicates
        pred_boxes = self.box2box_transform.apply_deltas(
            pred_deltas[anchor_idxs], anchors[anchor_idxs].tensor
        )
        keep = batched_nms(pred_boxes, pred_prob, cls_idxs, self.nms_threshold)
        # Cap the number of detections per image
        keep = keep[: self.detections_im]

        results = Instances(image_size)
        results.pred_boxes = Boxes(pred_boxes[keep])
        results.scores = pred_prob[keep]
        results.pred_classes = cls_idxs[keep]

        if not self.mask_on:
            return results, ([], None)

        # Location bookkeeping and anchors of the surviving detections
        kept_indexes = indexes[anchor_idxs][keep]
        kept_anchors = anchors[anchor_idxs][keep]

        result_masks = []
        for lvl, _, h, w, anc in kept_indexes.tolist():
            side = self.mask_sizes[anc]
            if self.bipyramid_on:
                # The mask side length doubles with every level
                side = side * 2**lvl
            mask_logits = pred_masks[lvl][anc][:, h, w].view(1, side, side)
            result_masks.append(torch.sigmoid(mask_logits))

        return results, (result_masks, kept_anchors)

    def preprocess_image(self, batched_inputs):
        """
        Normalize, pad and batch the input images.

        Arguments:
            batched_inputs: iterable of dicts; the tensor stored under the key
                "image" of every dict is used.

        Returns:
            The result of `ImageList.from_tensors` on the normalized images, built
            with the size divisibility of the backbone.
        """
        normalized = []
        for sample in batched_inputs:
            # Move to the model device first, then standardize with the pixel statistics
            image = sample["image"].to(self.device)
            normalized.append((image - self.pixel_mean) / self.pixel_std)
        return ImageList.from_tensors(normalized, self.backbone.size_divisibility)


class TensorMaskHead(nn.Module):
    """
    TensorMask head: predicts class logits, box deltas and (optionally) dense masks.

    The same layers are applied to every feature level passed to :meth:`forward`.
    """

    def __init__(self, cfg, num_levels, num_anchors, mask_sizes, input_shape: List[ShapeSpec]):
        """
        Build the classification, box and (optional) mask branches.

        Arguments:
            cfg: config node; `cfg.MODEL.MASK_ON` and the entries of
                `cfg.MODEL.TENSOR_MASK` listed below are read.
            num_levels (int): number of feature levels. Only used to create one
                `align2nat_XX` module per level when both ALIGNED_ON and BIPYRAMID_ON
                are set.
            num_anchors (int): number of anchors per location; sizes the output
                channels of the class and box predictors.
            mask_sizes (list[int]): side length of the mask predicted for each anchor.
                One 1x1 predictor named `mask_pred_<size>` is registered per entry.
            input_shape (list[ShapeSpec]): shapes of the input features; only the
                channel count of the first entry is used.

        Raises:
            ValueError: if masks are enabled with BIPYRAMID_ON but without ALIGNED_ON.
                `forward` needs the fusing layer whenever BIPYRAMID_ON is set, yet
                that layer is only built together with the aligned modules, so this
                combination used to fail with an `AttributeError` on the first
                forward pass.
        """
        super().__init__()
        # fmt: off
        self.in_features        = cfg.MODEL.TENSOR_MASK.IN_FEATURES
        in_channels             = input_shape[0].channels
        num_classes             = cfg.MODEL.TENSOR_MASK.NUM_CLASSES
        cls_channels            = cfg.MODEL.TENSOR_MASK.CLS_CHANNELS
        num_convs               = cfg.MODEL.TENSOR_MASK.NUM_CONVS
        # box parameters
        bbox_channels           = cfg.MODEL.TENSOR_MASK.BBOX_CHANNELS
        # mask parameters
        self.mask_on            = cfg.MODEL.MASK_ON
        self.mask_sizes         = mask_sizes
        mask_channels           = cfg.MODEL.TENSOR_MASK.MASK_CHANNELS
        self.align_on           = cfg.MODEL.TENSOR_MASK.ALIGNED_ON
        self.bipyramid_on       = cfg.MODEL.TENSOR_MASK.BIPYRAMID_ON
        # fmt: on

        # Fail early instead of crashing inside forward() (see "Raises" above).
        if self.mask_on and self.bipyramid_on and not self.align_on:
            raise ValueError(
                "TensorMaskHead: MODEL.TENSOR_MASK.BIPYRAMID_ON requires "
                "MODEL.TENSOR_MASK.ALIGNED_ON when MODEL.MASK_ON is enabled."
            )

        # class subnet
        self.cls_subnet, cur_channels = self._make_subnet(in_channels, cls_channels, num_convs)
        self.cls_score = nn.Conv2d(
            cur_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1
        )
        modules_list = [self.cls_subnet, self.cls_score]

        # box subnet
        self.bbox_subnet, cur_channels = self._make_subnet(in_channels, bbox_channels, num_convs)
        self.bbox_pred = nn.Conv2d(
            cur_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1
        )
        modules_list.extend([self.bbox_subnet, self.bbox_pred])

        # mask subnet
        if self.mask_on:
            self.mask_subnet, cur_channels = self._make_subnet(
                in_channels, mask_channels, num_convs
            )
            modules_list.append(self.mask_subnet)
            # One 1x1 predictor per mask size, looked up by name in forward()
            for mask_size in self.mask_sizes:
                cur_mask_module = "mask_pred_%02d" % mask_size
                self.add_module(
                    cur_mask_module,
                    nn.Conv2d(
                        cur_channels, mask_size * mask_size, kernel_size=1, stride=1, padding=0
                    ),
                )
                modules_list.append(getattr(self, cur_mask_module))
            if self.align_on:
                if self.bipyramid_on:
                    # One alignment module per level, with a factor doubling per level
                    for lvl in range(num_levels):
                        cur_mask_module = "align2nat_%02d" % lvl
                        lambda_val = 2**lvl
                        setattr(self, cur_mask_module, SwapAlign2Nat(lambda_val))
                    # Also the fusing layer, stay at the same channel size
                    self.mask_fuse = nn.Sequential(
                        nn.Conv2d(cur_channels, cur_channels, kernel_size=3, stride=1, padding=1),
                        nn.ReLU(),
                    )
                    modules_list.append(self.mask_fuse)
                else:
                    self.align2nat = SwapAlign2Nat(1)

        # Initialization
        for modules in modules_list:
            for layer in modules.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
                    torch.nn.init.constant_(layer.bias, 0)

        # Use prior in model initialization to improve stability
        bias_value = -(math.log((1 - 0.01) / 0.01))
        torch.nn.init.constant_(self.cls_score.bias, bias_value)

    @staticmethod
    def _make_subnet(in_channels, out_channels, num_convs):
        """
        Stack `num_convs` pairs of (3x3 conv, ReLU).

        Returns:
            (nn.Sequential, int): the stack and its number of output channels, which
                equals `in_channels` when `num_convs` is 0.
        """
        layers = []
        cur_channels = in_channels
        for _ in range(num_convs):
            layers.append(
                nn.Conv2d(cur_channels, out_channels, kernel_size=3, stride=1, padding=1)
            )
            layers.append(nn.ReLU())
            cur_channels = out_channels
        return nn.Sequential(*layers), cur_channels

    def forward(self, features):
        """
        Arguments:
            features (list[Tensor]): FPN feature map tensors in high to low resolution.
                Each tensor in the list correspond to different feature levels.

        Returns:
            pred_logits (list[Tensor]): #lvl tensors, each has shape (N, AxK, Hi, Wi).
                The tensor predicts the classification probability
                at each spatial position for each of the A anchors and K object
                classes.
            pred_deltas (list[Tensor]): #lvl tensors, each has shape (N, Ax4, Hi, Wi).
                The tensor predicts 4-vector (dx,dy,dw,dh) box
                regression values for every anchor. These values are the
                relative offset between the anchor and the ground truth box.
            pred_masks (list[list[Tensor]]): #lvl lists, each holding one tensor per
                entry of `mask_sizes`. Without alignment each tensor is the raw
                predictor output with `mask_size * mask_size` channels; with alignment
                it is the output of the corresponding align2nat module.
                `None` if masks are disabled.
        """
        pred_logits = [self.cls_score(self.cls_subnet(x)) for x in features]
        pred_deltas = [self.bbox_pred(self.bbox_subnet(x)) for x in features]

        pred_masks = None
        if self.mask_on:
            mask_feats = [self.mask_subnet(x) for x in features]

            if self.bipyramid_on:
                # Bring every level to the resolution of the first (finest) level,
                # add the finest features and fuse the sum.
                mask_feat_high_res = mask_feats[0]
                H, W = mask_feat_high_res.shape[-2:]
                mask_feats_up = []
                for lvl, mask_feat in enumerate(mask_feats):
                    lambda_val = 2.0**lvl
                    mask_feat_up = mask_feat
                    if lvl > 0:
                        mask_feat_up = F.interpolate(
                            mask_feat, scale_factor=lambda_val, mode="bilinear", align_corners=False
                        )
                    # Crop in case the upsampled map is larger than the finest level
                    mask_feats_up.append(
                        self.mask_fuse(mask_feat_up[:, :, :H, :W] + mask_feat_high_res)
                    )
                mask_feats = mask_feats_up

            pred_masks = []
            for lvl, mask_feat in enumerate(mask_feats):
                cur_masks = []
                for mask_size in self.mask_sizes:
                    cur_mask_module = getattr(self, "mask_pred_%02d" % mask_size)
                    cur_mask = cur_mask_module(mask_feat)
                    if self.align_on:
                        if self.bipyramid_on:
                            cur_mask_module = getattr(self, "align2nat_%02d" % lvl)
                            cur_mask = cur_mask_module(cur_mask)
                        else:
                            cur_mask = self.align2nat(cur_mask)
                    cur_masks.append(cur_mask)
                pred_masks.append(cur_masks)
        return pred_logits, pred_deltas, pred_masks